diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 049eb4183..122c146ea 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -21,7 +21,7 @@ from google.cloud.firestore_v1 import gapic_version as package_version -__version__ = package_version.__version__ +__version__: str = package_version.__version__ from typing import List diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index eb08f92b2..399bdb066 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -13,7 +13,7 @@ # limitations under the License. """Common helpers shared across Google Cloud Firestore modules.""" - +from __future__ import annotations import datetime import json from typing import ( @@ -22,14 +22,17 @@ Generator, Iterator, List, - NoReturn, Optional, + Sequence, Tuple, Union, + cast, + TYPE_CHECKING, ) import grpc # type: ignore from google.api_core import gapic_v1 +from google.api_core import retry as retries from google.api_core.datetime_helpers import DatetimeWithNanoseconds from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore from google.protobuf import struct_pb2 @@ -44,6 +47,9 @@ from google.cloud.firestore_v1.types.write import DocumentTransform from google.cloud.firestore_v1.vector import Vector +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1 import DocumentSnapshot + _EmptyDict: transforms.Sentinel _GRPC_ERROR_MAPPING: dict @@ -234,7 +240,9 @@ def encode_dict(values_dict) -> dict: return {key: encode_value(value) for key, value in values_dict.items()} -def document_snapshot_to_protobuf(snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot") -> Optional["google.cloud.firestore_v1.types.Document"]: # type: ignore +def document_snapshot_to_protobuf( + snapshot: "DocumentSnapshot", +) -> Optional["google.cloud.firestore_v1.types.Document"]: from google.cloud.firestore_v1.types import Document if not snapshot.exists: @@ -405,7 +413,8 @@ def decode_dict(value_fields, client) -> Union[dict, Vector]: if res.get("__type__", None) == "__vector__": # Vector data type is represented as mapping. # {"__type__":"__vector__", "value": [1.0, 2.0, 3.0]}. - return Vector(res["value"]) + values = cast(Sequence[float], res["value"]) + return Vector(values) return res @@ -504,7 +513,7 @@ def __init__(self, document_data) -> None: self.increments = {} self.minimums = {} self.maximums = {} - self.set_fields = {} + self.set_fields: dict = {} self.empty_document = False prefix_path = FieldPath() @@ -566,7 +575,9 @@ def transform_paths(self): + list(self.minimums) ) - def _get_update_mask(self, allow_empty_mask=False) -> None: + def _get_update_mask( + self, allow_empty_mask=False + ) -> Optional[types.common.DocumentMask]: return None def get_update_pb( @@ -730,9 +741,9 @@ class DocumentExtractorForMerge(DocumentExtractor): def __init__(self, document_data) -> None: super(DocumentExtractorForMerge, self).__init__(document_data) - self.data_merge = [] - self.transform_merge = [] - self.merge = [] + self.data_merge: list = [] + self.transform_merge: list = [] + self.merge: list = [] def _apply_merge_all(self) -> None: self.data_merge = sorted(self.field_paths + self.deleted_fields) @@ -786,7 +797,7 @@ def _apply_merge_paths(self, merge) -> None: self.data_merge.append(field_path) # Clear out data for fields not merged. 
- merged_set_fields = {} + merged_set_fields: dict = {} for field_path in self.data_merge: value = get_field_value(self.document_data, field_path) set_field_value(merged_set_fields, field_path, value) @@ -1019,7 +1030,7 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]: class WriteOption(object): """Option used to assert a condition on a write operation.""" - def modify_write(self, write, no_create_msg=None) -> NoReturn: + def modify_write(self, write, no_create_msg=None) -> None: """Modify a ``Write`` protobuf based on the state of this write option. This is a virtual method intended to be implemented by subclasses. @@ -1059,7 +1070,7 @@ def __eq__(self, other): return NotImplemented return self._last_update_time == other._last_update_time - def modify_write(self, write, **unused_kwargs) -> None: + def modify_write(self, write, *unused_args, **unused_kwargs) -> None: """Modify a ``Write`` protobuf based on the state of this write option. The ``last_update_time`` is added to ``write_pb`` as an "update time" @@ -1096,7 +1107,7 @@ def __eq__(self, other): return NotImplemented return self._exists == other._exists - def modify_write(self, write, **unused_kwargs) -> None: + def modify_write(self, write, *unused_args, **unused_kwargs) -> None: """Modify a ``Write`` protobuf based on the state of this write option. If: @@ -1115,7 +1126,9 @@ def modify_write(self, write, **unused_kwargs) -> None: write._pb.current_document.CopyFrom(current_doc._pb) -def make_retry_timeout_kwargs(retry, timeout) -> dict: +def make_retry_timeout_kwargs( + retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None +) -> dict: """Helper fo API methods which take optional 'retry' / 'timeout' args.""" kwargs = {} @@ -1152,8 +1165,8 @@ def compare_timestamps( def deserialize_bundle( serialized: Union[str, bytes], - client: "google.cloud.firestore_v1.client.BaseClient", # type: ignore -) -> "google.cloud.firestore_bundle.FirestoreBundle": # type: ignore + client: "google.cloud.firestore_v1.client.BaseClient", +) -> "google.cloud.firestore_bundle.FirestoreBundle": """Inverse operation to a `FirestoreBundle` instance's `build()` method. Args: @@ -1211,7 +1224,7 @@ def deserialize_bundle( # Create and add our BundleElement bundle_element: BundleElement try: - bundle_element: BundleElement = BundleElement.from_json(json.dumps(data)) # type: ignore + bundle_element = BundleElement.from_json(json.dumps(data)) except AttributeError as e: # Some bad serialization formats cannot be universally deserialized. if e.args[0] == "'dict' object has no attribute 'find'": # pragma: NO COVER @@ -1235,18 +1248,22 @@ def deserialize_bundle( if "__end__" not in allowed_next_element_types: raise ValueError("Unexpected end to serialized FirestoreBundle") - + # state machine guarantees bundle and metadata have been populated + bundle = cast(FirestoreBundle, bundle) + metadata_bundle_element = cast(BundleElement, metadata_bundle_element) # Now, finally add the metadata element bundle._add_bundle_element( metadata_bundle_element, client=client, - type="metadata", # type: ignore + type="metadata", ) return bundle -def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict, None, None]: # type: ignore +def _parse_bundle_elements_data( + serialized: Union[str, bytes] +) -> Generator[Dict, None, None]: """Reads through a serialized FirestoreBundle and yields JSON chunks that were created via `BundleElement.to_json(bundle_element)`. 
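Annotation: the `cast(...)` calls introduced above (for the Vector payload and for the bundle state machine) rely on `typing.cast` being a runtime no-op that only narrows the type for the checker. A minimal, self-contained sketch of the idiom, with hypothetical names rather than the Firestore API:

from typing import Any, Dict, Sequence, cast

def vector_values(decoded: Dict[str, Any]) -> Sequence[float]:
    # The decoded mapping is known by construction to hold a list of
    # floats under "value", but the checker only sees Any, so the
    # invariant is asserted with cast(); it simply returns its argument.
    return cast(Sequence[float], decoded["value"])

print(vector_values({"__type__": "__vector__", "value": [1.0, 2.0, 3.0]}))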
@@ -1290,7 +1307,7 @@ def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict def _get_documents_from_bundle( bundle, *, query_name: Optional[str] = None -) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]: # type: ignore +) -> Generator["DocumentSnapshot", None, None]: from google.cloud.firestore_bundle.bundle import _BundledDocument bundled_doc: _BundledDocument @@ -1304,7 +1321,9 @@ def _get_document_from_bundle( bundle, *, document_id: str, -) -> Optional["google.cloud.firestore.DocumentSnapshot"]: # type: ignore +) -> Optional["DocumentSnapshot"]: bundled_doc = bundle.documents.get(document_id) if bundled_doc: return bundled_doc.snapshot + else: + return None diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index f0e3f94ba..ec0fbc189 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -52,9 +52,7 @@ def __init__( def get( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, + retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT, timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, @@ -131,9 +129,7 @@ def _retry_query_after_exception(self, exc, retry, transaction): def _make_stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, + retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> Generator[List[AggregationResult], Any, Optional[ExplainMetrics]]: @@ -206,9 +202,7 @@ def _make_stream( def stream( self, transaction: Optional["transaction.Transaction"] = None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, + retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index 5855b7161..fc78f31fd 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -51,9 +51,7 @@ def __init__( async def get( self, transaction=None, - retry: Union[ - retries.AsyncRetry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, + retry: Union[retries.AsyncRetry, None, object] = gapic_v1.method.DEFAULT, timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, @@ -102,7 +100,7 @@ async def get( async def _make_stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> AsyncGenerator[List[AggregationResult] | query_profile_pb.ExplainMetrics, Any]: @@ -162,7 +160,7 @@ async def _make_stream( def stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/async_batch.py 
b/google/cloud/firestore_v1/async_batch.py index fed87d27f..689753fe9 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -13,7 +13,7 @@ # limitations under the License. """Helpers for batch requests to the Google Cloud Firestore API.""" - +from __future__ import annotations from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -38,8 +38,8 @@ def __init__(self, client) -> None: async def commit( self, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> list: """Commit the changes accumulated in this batch. diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index f14ec6573..275bcb9b6 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -23,6 +23,7 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ +from __future__ import annotations from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union @@ -222,10 +223,10 @@ def document(self, *document_path: str) -> AsyncDocumentReference: async def get_all( self, references: List[AsyncDocumentReference], - field_paths: Iterable[str] = None, - transaction=None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + field_paths: Iterable[str] | None = None, + transaction: AsyncTransaction | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. @@ -280,8 +281,8 @@ async def get_all( async def collections( self, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. @@ -310,8 +311,8 @@ async def recursive_delete( reference: Union[AsyncCollectionReference, AsyncDocumentReference], *, bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, - ): + chunk_size: int = 5000, + ) -> int: """Deletes documents and their subcollections, regardless of collection name. @@ -346,8 +347,8 @@ async def _recursive_delete( reference: Union[AsyncCollectionReference, AsyncDocumentReference], bulk_writer: "BulkWriter", *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, + chunk_size: int = 5000, + depth: int = 0, ) -> int: """Recursion helper for `recursive_delete.""" diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index ec15de65f..8c832b8f4 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -100,9 +100,9 @@ async def _chunkify(self, chunk_size: int): async def add( self, document_data: dict, - document_id: str = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + document_id: str | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. 
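A note on the recurring `retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT` signatures: `gapic_v1.method.DEFAULT` is an instance of a private sentinel type, so `object` stands in for it in the union rather than naming `_MethodDefault` in every public signature. A runnable sketch of the sentinel-default idiom with stand-in names (not the gapic implementation):

from typing import Union

class _MethodDefault:
    """Private sentinel type: its one instance means "argument not passed"."""

DEFAULT = _MethodDefault()

class Retry:
    """Stand-in for a retry policy object."""

def get(retry: Union[Retry, None, object] = DEFAULT) -> str:
    # Identity comparison distinguishes "not passed" from an explicit None.
    if retry is DEFAULT:
        return "library default retry"
    if retry is None:
        return "retries disabled"
    return "caller-supplied policy"

print(get())               # library default retry
print(get(retry=None))     # retries disabled
print(get(retry=Retry()))  # caller-supplied policy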
@@ -141,7 +141,7 @@ async def add( return write_result.update_time, document_ref def document( - self, document_id: str = None + self, document_id: str | None = None ) -> async_document.AsyncDocumentReference: """Create a sub-document underneath the current collection. @@ -159,9 +159,9 @@ def document( async def list_documents( self, - page_size: int = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + page_size: int | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator[DocumentReference, None]: """List all subdocuments of the current collection. @@ -193,7 +193,7 @@ async def list_documents( async def get( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -233,7 +233,7 @@ async def get( def stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index a697e8630..78c71b33f 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -13,6 +13,7 @@ # limitations under the License. """Classes for representing documents for the Google Cloud Firestore API.""" +from __future__ import annotations import datetime import logging from typing import AsyncGenerator, Iterable @@ -64,8 +65,8 @@ def __init__(self, *path, **kwargs) -> None: async def create( self, document_data: dict, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Create the current document in the Firestore database. @@ -94,8 +95,8 @@ async def set( self, document_data: dict, merge: bool = False, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Replace the current document in the Firestore database. @@ -133,9 +134,9 @@ async def set( async def update( self, field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Update an existing document in the Firestore database. @@ -290,9 +291,9 @@ async def update( async def delete( self, - option: _helpers.WriteOption = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Timestamp: """Delete the current document in the Firestore database. 
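Each touched module also gains `from __future__ import annotations`, which is what lets these files spell `float | None` in annotations while still supporting interpreters older than Python 3.10: under PEP 563 every annotation is stored as a string and never evaluated at definition time. A small demonstration:

from __future__ import annotations  # must precede all other imports

def wait(timeout: float | None = None) -> str | None:
    # Without the future import, evaluating `float | None` at definition
    # time raises TypeError on Python < 3.10; with it, the annotation is
    # kept as the literal string "float | None".
    return None if timeout is None else f"waited {timeout}s"

print(wait(1.5))
print(wait.__annotations__["timeout"])  # float | None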
@@ -324,10 +325,10 @@ async def delete( async def get( self, - field_paths: Iterable[str] = None, + field_paths: Iterable[str] | None = None, transaction=None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -393,9 +394,9 @@ async def get( async def collections( self, - page_size: int = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + page_size: int | None = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator: """List subcollections of the current document. diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 0cc9b550a..aa16725d8 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -26,7 +26,6 @@ from google.api_core import retry_async as retries from google.cloud import firestore_v1 -from google.cloud.firestore_v1 import transaction from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery @@ -42,6 +41,7 @@ if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints + from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.field_path import FieldPath @@ -177,8 +177,8 @@ async def _chunkify( async def get( self, - transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + transaction: Optional[AsyncTransaction] = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -329,8 +329,8 @@ def avg( async def _make_stream( self, - transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + transaction: Optional[AsyncTransaction] = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> AsyncGenerator[DocumentSnapshot | query_profile_pb.ExplainMetrics, Any]: @@ -404,8 +404,8 @@ async def _make_stream( def stream( self, - transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + transaction: Optional[AsyncTransaction] = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -509,8 +509,8 @@ def _get_query_class(): async def get_partitions( self, partition_count, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator[QueryPartition, None]: """Partition a query for parallelization. 
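`async_query.py` now imports `AsyncTransaction` only under `if TYPE_CHECKING:`, the standard way to break an import cycle while keeping precise annotations: the guarded block is read by type checkers but skipped at runtime, and the postponed annotations from the future import mean the name is never looked up when the code runs. A self-contained sketch, using `decimal` as a stand-in for the circular module:

from __future__ import annotations   # annotations become lazy strings
from typing import TYPE_CHECKING

if TYPE_CHECKING:                    # skipped at runtime, read by checkers
    from decimal import Decimal      # stand-in for a circular import

def describe(value: Decimal) -> str:
    # `Decimal` was never imported at runtime, but that is fine: the
    # annotation is stored as a string and not evaluated here.
    return f"value={value}"

print(describe(42))                  # value=42
print(describe.__annotations__)      # {'value': 'Decimal', 'return': 'str'}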
diff --git a/google/cloud/firestore_v1/async_stream_generator.py b/google/cloud/firestore_v1/async_stream_generator.py index c38e6eea1..c222b5d87 100644 --- a/google/cloud/firestore_v1/async_stream_generator.py +++ b/google/cloud/firestore_v1/async_stream_generator.py @@ -17,7 +17,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Optional, TypeVar from google.cloud.firestore_v1.query_profile import ( ExplainMetrics, @@ -66,10 +66,10 @@ async def __anext__(self) -> T: except StopAsyncIteration: raise - def asend(self, value: Any = None) -> Awaitable[T]: + def asend(self, value: Any = None) -> Coroutine[Any, Any, T]: return self._generator.asend(value) - def athrow(self, *args, **kwargs) -> Awaitable[T]: + def athrow(self, *args, **kwargs) -> Coroutine[Any, Any, T]: return self._generator.athrow(*args, **kwargs) def aclose(self): diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index cf751c9f0..038710929 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -45,7 +45,7 @@ class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. Args: - client (:class:`~google.cloud.firestore_v1.client.Client`): + client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`): The client that created this transaction. max_attempts (Optional[int]): The maximum number of attempts for the transaction (i.e. allowing retries). Defaults to @@ -74,7 +74,7 @@ def _add_write_pbs(self, write_pbs: list) -> None: super(AsyncTransaction, self)._add_write_pbs(write_pbs) - async def _begin(self, retry_id: bytes = None) -> None: + async def _begin(self, retry_id: bytes | None = None) -> None: """Begin the transaction. Args: @@ -152,8 +152,8 @@ async def _commit(self) -> list: async def get_all( self, references: list, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieves multiple documents from Firestore. 
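The `asend`/`athrow` return annotations change from `Awaitable[T]` to `Coroutine[Any, Any, T]` because that is what `AsyncGenerator.asend` actually returns: a coroutine object, which supports `close()` and `throw()` in addition to being awaitable. A minimal illustration:

import asyncio
from typing import Any, AsyncGenerator, Coroutine

async def countdown(n: int) -> AsyncGenerator[int, Any]:
    while n > 0:
        yield n
        n -= 1

async def main() -> None:
    gen = countdown(3)
    # asend() returns a coroutine object; Coroutine is the precise type,
    # Awaitable merely a supertype of it.
    step: Coroutine[Any, Any, int] = gen.asend(None)
    print(await step)   # 3
    await gen.aclose()  # close the generator cleanly

asyncio.run(main())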
@@ -175,7 +175,7 @@ async def get_all( async def get( self, ref_or_query: AsyncDocumentReference | AsyncQuery, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py index 97ea3d0aa..6e3d1a854 100644 --- a/google/cloud/firestore_v1/async_vector_query.py +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, TypeVar, Union from google.api_core import gapic_v1 -from google.api_core import retry_async as retries +from google.api_core import retry as retries from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_query import ( @@ -54,7 +54,7 @@ def __init__( async def get( self, transaction=None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -103,10 +103,10 @@ async def get( async def _make_stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, - ) -> AsyncGenerator[[DocumentSnapshot | query_profile_pb.ExplainMetrics], Any]: + ) -> AsyncGenerator[DocumentSnapshot | query_profile_pb.ExplainMetrics, Any]: """Internal method for stream(). Read the documents in the collection that match this query. @@ -171,7 +171,7 @@ async def _make_stream( def stream( self, transaction=None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 807c753f1..34a3baad8 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -18,8 +18,6 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create an aggregation query than direct usage of the constructor. 
""" - - from __future__ import annotations import abc @@ -32,7 +30,6 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.types import ( - RunAggregationQueryResponse, StructuredAggregationQuery, ) @@ -123,7 +120,7 @@ def _to_protobuf(self): def _query_response_to_result( - response_pb: RunAggregationQueryResponse, + response_pb, ) -> List[AggregationResult]: results = [ AggregationResult( @@ -205,7 +202,7 @@ def _to_protobuf(self) -> StructuredAggregationQuery: def _prep_stream( self, transaction=None, - retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None, + retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None, timeout: float | None = None, explain_options: Optional[ExplainOptions] = None, ) -> Tuple[dict, dict]: @@ -226,7 +223,7 @@ def get( self, transaction=None, retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault + retries.Retry, retries.AsyncRetry, None, object ] = gapic_v1.method.DEFAULT, timeout: float | None = None, *, @@ -266,9 +263,10 @@ def get( def stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, + retry: retries.Retry + | retries.AsyncRetry + | object + | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index 0827122b6..b0d50f1f4 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -13,7 +13,7 @@ # limitations under the License. """Helpers for batch requests to the Google Cloud Firestore API.""" - +from __future__ import annotations import abc from typing import Dict, Union @@ -22,6 +22,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.types import write as write_pb class BaseBatch(metaclass=abc.ABCMeta): @@ -38,9 +39,9 @@ class BaseBatch(metaclass=abc.ABCMeta): def __init__(self, client) -> None: self._client = client - self._write_pbs = [] + self._write_pbs: list[write_pb.Write] = [] self._document_references: Dict[str, BaseDocumentReference] = {} - self.write_results = None + self.write_results: list[write_pb.WriteResult] | None = None self.commit_time = None def __len__(self): @@ -49,7 +50,7 @@ def __len__(self): def __contains__(self, reference: BaseDocumentReference): return reference._document_path in self._document_references - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs: list[write_pb.Write]) -> None: """Add `Write`` protobufs to this transaction. This method intended to be over-ridden by subclasses. @@ -120,7 +121,7 @@ def update( self, reference: BaseDocumentReference, field_updates: dict, - option: _helpers.WriteOption = None, + option: _helpers.WriteOption | None = None, ) -> None: """Add a "change" to update a document. @@ -146,7 +147,9 @@ def update( self._add_write_pbs(write_pbs) def delete( - self, reference: BaseDocumentReference, option: _helpers.WriteOption = None + self, + reference: BaseDocumentReference, + option: _helpers.WriteOption | None = None, ) -> None: """Add a "change" to delete a document. @@ -171,7 +174,11 @@ class BaseWriteBatch(BaseBatch): """Base class for a/sync implementations of the `commit` RPC. 
`commit` is useful for lower volumes or when the order of write operations is important.""" - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit( + self, + retry: retries.Retry | retries.AsyncRetry | object | None, + timeout: float | None, + ): """Shared setup for async/sync :meth:`commit`.""" request = { "database": self._client._database_string, diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index f36ff357b..9b1c0bccd 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -23,11 +23,13 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from __future__ import annotations import os from typing import ( Any, AsyncGenerator, + Awaitable, Generator, Iterable, List, @@ -57,6 +59,7 @@ from google.cloud.firestore_v1.base_transaction import BaseTransaction from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path +from google.cloud.firestore_v1.services.firestore import client as firestore_client DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" @@ -219,6 +222,16 @@ def _target_helper(self, client_class) -> str: else: return client_class.DEFAULT_ENDPOINT + @property + def _target(self): + """Return the target (where the API is). + Eg. "firestore.googleapis.com" + + Returns: + str: The location of the API. + """ + return self._target_helper(firestore_client.FirestoreClient) + @property def _database_string(self): """The database string corresponding to this client's project. @@ -265,7 +278,7 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]: + def collection(self, *collection_path) -> BaseCollectionReference: raise NotImplementedError def collection_group(self, collection_id: str) -> BaseQuery: @@ -330,9 +343,11 @@ def _document_path_helper(self, *document_path) -> List[str]: def recursive_delete( self, - reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference], - bulk_writer: Optional["BulkWriter"] = None, # type: ignore - ) -> int: + reference, + *, + bulk_writer: Optional["BulkWriter"] = None, + chunk_size: int = 5000, + ) -> int | Awaitable[int]: raise NotImplementedError @staticmethod @@ -418,10 +433,10 @@ def write_option( def _prep_get_all( self, references: list, - field_paths: Iterable[str] = None, - transaction: BaseTransaction = None, - retry: retries.Retry = None, - timeout: float = None, + field_paths: Iterable[str] | None = None, + transaction: BaseTransaction | None = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, ) -> Tuple[dict, dict, dict]: """Shared setup for async/sync :meth:`get_all`.""" document_paths, reference_map = _reference_info(references) @@ -439,10 +454,10 @@ def _prep_get_all( def get_all( self, references: list, - field_paths: Iterable[str] = None, - transaction: BaseTransaction = None, - retry: retries.Retry = None, - timeout: float = None, + field_paths: Iterable[str] | None = None, + transaction=None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, ) -> Union[ AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] ]: @@ -450,8 +465,8 @@ def get_all( def 
_prep_collections( self, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" request = {"parent": "{}/documents".format(self._database_string)} @@ -461,12 +476,9 @@ def _prep_collections( def collections( self, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[ - AsyncGenerator[BaseCollectionReference[BaseQuery], Any], - Generator[BaseCollectionReference[BaseQuery], Any, Any], - ]: + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, + ): raise NotImplementedError def batch(self) -> BaseWriteBatch: @@ -583,7 +595,9 @@ def _parse_batch_get( return snapshot -def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentMask]: +def _get_doc_mask( + field_paths: Iterable[str] | None, +) -> Optional[types.common.DocumentMask]: """Get a document mask if field paths are provided. Args: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1ac1ba318..b74ced2a3 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -25,7 +25,6 @@ Generator, Generic, Iterable, - NoReturn, Optional, Tuple, Union, @@ -129,7 +128,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None) -> DocumentReference: + def document(self, document_id: Optional[str] = None): """Create a sub-document underneath the current collection. Args: @@ -177,7 +176,7 @@ def _prep_add( self, document_data: dict, document_id: Optional[str] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Tuple[DocumentReference, dict]: """Shared setup for async / sync :method:`add`""" @@ -193,7 +192,7 @@ def add( self, document_data: dict, document_id: Optional[str] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: raise NotImplementedError @@ -201,7 +200,7 @@ def add( def _prep_list_documents( self, page_size: Optional[int] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Tuple[dict, dict]: """Shared setup for async / sync :method:`list_documents`""" @@ -223,7 +222,7 @@ def _prep_list_documents( def list_documents( self, page_size: Optional[int] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Union[ Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] @@ -482,7 +481,7 @@ def end_at( def _prep_get_or_stream( self, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, ) -> Tuple[Any, dict]: """Shared setup for async / sync :meth:`get` / :meth:`stream`""" @@ -494,7 +493,7 @@ def _prep_get_or_stream( def get( self, transaction: Optional[Transaction] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: 
Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -507,14 +506,14 @@ def get( def stream( self, transaction: Optional[Transaction] = None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, ) -> StreamGenerator[DocumentSnapshot] | AsyncIterator[DocumentSnapshot]: raise NotImplementedError - def on_snapshot(self, callback) -> NoReturn: + def on_snapshot(self, callback): raise NotImplementedError def count(self, alias=None): diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index c17e10586..b16b8abac 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -21,10 +21,10 @@ Any, Dict, Iterable, - NoReturn, Optional, Tuple, Union, + Awaitable, ) from google.api_core import retry as retries @@ -181,7 +181,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.collection(*parent_path) - def collection(self, collection_id: str) -> Any: + def collection(self, collection_id: str): """Create a sub-collection underneath the current document. Args: @@ -198,8 +198,8 @@ def collection(self, collection_id: str) -> Any: def _prep_create( self, document_data: dict, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, ) -> Tuple[Any, dict]: batch = self._client.batch() batch.create(self, document_data) @@ -210,17 +210,17 @@ def _prep_create( def create( self, document_data: dict, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ) -> write.WriteResult | Awaitable[write.WriteResult]: raise NotImplementedError def _prep_set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, ) -> Tuple[Any, dict]: batch = self._client.batch() batch.set(self, document_data, merge=merge) @@ -232,17 +232,17 @@ def set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ): raise NotImplementedError def _prep_update( self, field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, ) -> Tuple[Any, dict]: batch = self._client.batch() batch.update(self, field_updates, option=option) @@ -253,17 +253,17 @@ def _prep_update( def update( self, field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + option: _helpers.WriteOption | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ): raise NotImplementedError def _prep_delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, 
) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`delete`.""" write_pb = _helpers.pb_for_delete(self._document_path, option) @@ -278,18 +278,18 @@ def _prep_delete( def delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + option: _helpers.WriteOption | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ): raise NotImplementedError def _prep_batch_get( self, - field_paths: Iterable[str] = None, + field_paths: Iterable[str] | None = None, transaction=None, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`get`.""" if isinstance(field_paths, str): @@ -312,18 +312,18 @@ def _prep_batch_get( def get( self, - field_paths: Iterable[str] = None, + field_paths: Iterable[str] | None = None, transaction=None, - retry: retries.Retry = None, - timeout: float = None, - ) -> "DocumentSnapshot": + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ) -> "DocumentSnapshot" | Awaitable["DocumentSnapshot"]: raise NotImplementedError def _prep_collections( self, - page_size: int = None, - retry: retries.Retry = None, - timeout: float = None, + page_size: int | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" request = {"parent": self._document_path, "page_size": page_size} @@ -333,13 +333,13 @@ def _prep_collections( def collections( self, - page_size: int = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> None: + page_size: int | None = None, + retry: retries.Retry | retries.AsyncRetry | None | object = None, + timeout: float | None = None, + ): raise NotImplementedError - def on_snapshot(self, callback) -> None: + def on_snapshot(self, callback): raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 3a473094a..3509bbf17 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -30,12 +30,12 @@ Coroutine, Dict, Iterable, - NoReturn, + List, Optional, Tuple, Type, - TypeVar, Union, + TypeVar, ) from google.api_core import retry as retries @@ -60,7 +60,6 @@ if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator - from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList @@ -382,16 +381,17 @@ def select(self: QueryType, field_paths: Iterable[str]) -> QueryType: def _copy( self: QueryType, *, - projection: Optional[query.StructuredQuery.Projection] = _not_passed, - field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed, - orders: Optional[Tuple[query.StructuredQuery.Order]] = _not_passed, - limit: Optional[int] = _not_passed, - limit_to_last: Optional[bool] = _not_passed, - offset: Optional[int] = _not_passed, - start_at: Optional[Tuple[dict, bool]] = _not_passed, - end_at: Optional[Tuple[dict, bool]] = _not_passed, - all_descendants: Optional[bool] = _not_passed, - recursive: 
Optional[bool] = _not_passed, + projection: Optional[query.StructuredQuery.Projection] | object = _not_passed, + field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] + | object = _not_passed, + orders: Optional[Tuple[query.StructuredQuery.Order]] | object = _not_passed, + limit: Optional[int] | object = _not_passed, + limit_to_last: Optional[bool] | object = _not_passed, + offset: Optional[int] | object = _not_passed, + start_at: Optional[Tuple[dict, bool]] | object = _not_passed, + end_at: Optional[Tuple[dict, bool]] | object = _not_passed, + all_descendants: Optional[bool] | object = _not_passed, + recursive: Optional[bool] | object = _not_passed, ) -> QueryType: return self.__class__( self._parent, @@ -630,7 +630,7 @@ def _check_snapshot(self, document_snapshot) -> None: def _cursor_helper( self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple, None], before: bool, start: bool, ) -> QueryType: @@ -687,7 +687,7 @@ def _cursor_helper( def start_at( self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple, None], ) -> QueryType: """Start query results at a particular document value. @@ -720,7 +720,7 @@ def start_at( def start_after( self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple, None], ) -> QueryType: """Start query results after a particular document value. @@ -754,7 +754,7 @@ def start_after( def end_before( self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple, None], ) -> QueryType: """End query results before a particular document value. @@ -788,7 +788,7 @@ def end_before( def end_at( self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], + document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple, None], ) -> QueryType: """End query results at a particular document value. @@ -895,7 +895,7 @@ def _normalize_orders(self) -> list: return orders - def _normalize_cursor(self, cursor, orders) -> Optional[Tuple[Any, Any]]: + def _normalize_cursor(self, cursor, orders) -> Tuple[List, bool] | None: """Helper: convert cursor to a list of values based on orders.""" if cursor is None: return None @@ -990,7 +990,7 @@ def find_nearest( *, distance_result_field: Optional[str] = None, distance_threshold: Optional[float] = None, - ) -> BaseVectorQuery: + ): raise NotImplementedError def count( @@ -1024,7 +1024,7 @@ def get( def _prep_stream( self, transaction=None, - retry: Optional[retries.Retry] = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> Tuple[dict, str, dict]: @@ -1060,7 +1060,7 @@ def stream( ): raise NotImplementedError - def on_snapshot(self, callback) -> NoReturn: + def on_snapshot(self, callback): raise NotImplementedError def recursive(self: QueryType) -> QueryType: @@ -1149,6 +1149,10 @@ def _comparator(self, doc1, doc2) -> int: return 0 + @staticmethod + def _get_collection_reference_class(): + raise NotImplementedError + def _enum_from_op_string(op_string: str) -> int: """Convert a string representation of a binary operator to an enum. 
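The removal of `-> NoReturn` from these virtual methods is deliberate: `NoReturn` promises a method never returns normally, so a checker flags every subclass override that does return something. Stubs that exist to be overridden should carry the intended return type, or no annotation at all, as in this sketch:

class Base:
    # A `-> NoReturn` annotation here would make every useful override
    # an incompatible one; annotate the type subclasses must return
    # instead (or leave the stub unannotated).
    def get(self) -> int:
        raise NotImplementedError

class Impl(Base):
    def get(self) -> int:   # compatible override
        return 42

print(Impl().get())         # 42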
@@ -1245,7 +1249,7 @@ def _filter_pb(field_or_unary) -> StructuredQuery.Filter: raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) -def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: +def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]: """Convert a cursor pair to a protobuf. If ``cursor_pair`` is :data:`None`, just returns :data:`None`. @@ -1264,6 +1268,8 @@ def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: data, before = cursor_pair value_pbs = [_helpers.encode_value(value) for value in data] return query.Cursor(values=value_pbs, before=before) + else: + return None def _query_response_to_snapshot( @@ -1402,8 +1408,8 @@ def _get_query_class(self): def _prep_get_partitions( self, partition_count, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, + retry: retries.Retry | object | None = None, + timeout: float | None = None, ) -> Tuple[dict, dict]: self._validate_partition_query() parent_path, expected_prefix = self._parent._parent_info() @@ -1429,11 +1435,7 @@ def get_partitions( partition_count, retry: Optional[retries.Retry] = None, timeout: Optional[float] = None, - ) -> NoReturn: - raise NotImplementedError - - @staticmethod - def _get_collection_reference_class() -> Type["BaseCollectionGroup"]: + ): raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 752c83169..92e54c81c 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -21,7 +21,6 @@ AsyncGenerator, Coroutine, Generator, - NoReturn, Optional, Union, ) @@ -36,18 +35,7 @@ from google.cloud.firestore_v1.document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.stream_generator import StreamGenerator - - -_CANT_BEGIN: str -_CANT_COMMIT: str -_CANT_RETRY_READ_ONLY: str -_CANT_ROLLBACK: str -_EXCEED_ATTEMPTS_TEMPLATE: str -_INITIAL_SLEEP: float -_MAX_SLEEP: float -_MISSING_ID_TEMPLATE: str -_MULTIPLIER: float -_WRITE_READ_ONLY: str + from google.cloud.firestore_v1.types import write as write_pb MAX_ATTEMPTS = 5 @@ -78,7 +66,7 @@ def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: self._read_only = read_only self._id = None - def _add_write_pbs(self, write_pbs) -> NoReturn: + def _add_write_pbs(self, write_pbs: list[write_pb.Write]): raise NotImplementedError def _options_protobuf( @@ -143,13 +131,13 @@ def _clean_up(self) -> None: This intended to occur on success or failure of the associated RPCs. 
""" - self._write_pbs = [] + self._write_pbs: list[write_pb.Write] = [] self._id = None - def _begin(self, retry_id=None) -> NoReturn: + def _begin(self, retry_id=None): raise NotImplementedError - def _rollback(self) -> NoReturn: + def _rollback(self): raise NotImplementedError def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: @@ -158,8 +146,8 @@ def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: def get_all( self, references: list, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, ) -> ( Generator[DocumentSnapshot, Any, None] | Coroutine[Any, Any, AsyncGenerator[DocumentSnapshot, Any]] @@ -169,8 +157,8 @@ def get_all( def get( self, ref_or_query, - retry: retries.Retry = None, - timeout: float = None, + retry: retries.Retry | retries.AsyncRetry | object | None = None, + timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, ) -> ( @@ -205,7 +193,7 @@ def _reset(self) -> None: self.current_id = None self.retry_id = None - def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn: + def _pre_commit(self, transaction, *args, **kwargs): raise NotImplementedError def __call__(self, transaction, *args, **kwargs): diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index 30c79bc7e..f5a4403c8 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -99,7 +99,7 @@ def _to_protobuf(self) -> query.StructuredQuery: def _prep_stream( self, transaction=None, - retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None, + retry: Union[retries.Retry, retries.AsyncRetry, object, None] = None, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> Tuple[dict, str, dict]: @@ -120,7 +120,10 @@ def _prep_stream( def get( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry + | retries.AsyncRetry + | object + | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -153,7 +156,10 @@ def find_nearest( def stream( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry + | retries.AsyncRetry + | object + | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 406cdb122..cc98c6503 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -13,7 +13,7 @@ # limitations under the License. """Helpers for batch requests to the Google Cloud Firestore API.""" - +from __future__ import annotations from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -38,7 +38,9 @@ def __init__(self, client) -> None: super(WriteBatch, self).__init__(client=client) def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + self, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> list: """Commit the changes accumulated in this batch. 
diff --git a/google/cloud/firestore_v1/bulk_batch.py b/google/cloud/firestore_v1/bulk_batch.py index 631310beb..29a3e509f 100644 --- a/google/cloud/firestore_v1/bulk_batch.py +++ b/google/cloud/firestore_v1/bulk_batch.py @@ -13,6 +13,7 @@ # limitations under the License. """Helpers for batch requests to the Google Cloud Firestore API.""" +from __future__ import annotations from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -46,7 +47,9 @@ def __init__(self, client) -> None: super(BulkWriteBatch, self).__init__(client=client) def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None + self, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> BatchWriteResponse: """Writes the changes accumulated in this batch. @@ -81,7 +84,7 @@ def commit( return save_response - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit(self, retry: retries.Retry | object | None, timeout: float | None): request = { "database": self._client._database_string, "writes": self._write_pbs, diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index 4c1c7bde9..ec0fa4881 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -24,7 +24,7 @@ import logging import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Deque, Dict, List, Optional, Union from google.rpc import status_pb2 # type: ignore @@ -82,7 +82,7 @@ class AsyncBulkWriterMixin: wrapped in a decorator which ensures that the `SendMode` is honored. """ - def _with_send_mode(fn): + def _with_send_mode(fn: Callable): # type: ignore """Decorates a method to ensure it is only called via the executor (IFF the SendMode value is SendMode.parallel!). @@ -117,8 +117,10 @@ def wrapper(self, *args, **kwargs): return wrapper @_with_send_mode - def _send_batch( - self, batch: BulkWriteBatch, operations: List["BulkWriterOperation"] + def _send_batch( # type: ignore + self: "BulkWriter", + batch: BulkWriteBatch, + operations: List["BulkWriterOperation"], ): """Sends a batch without regard to rate limits, meaning limits must have already been checked. To that end, do not call this directly; instead, @@ -138,12 +140,12 @@ def _send_batch( self._process_response(batch, response, operations) - def _process_response( - self, + def _process_response( # type: ignore + self: "BulkWriter", batch: BulkWriteBatch, response: BatchWriteResponse, operations: List["BulkWriterOperation"], - ) -> None: + ): """Invokes submitted callbacks for each batch and each operation within each batch. As this is called from `_send_batch()`, this is parallelized if we are in that mode. 
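`AsyncBulkWriterMixin`'s methods now annotate `self: "BulkWriter"` explicitly. That is the usual way to type a mixin whose methods touch attributes that exist only on the host class: the checker then resolves `self._...` against the named class instead of the mixin. A stripped-down sketch (the attribute and method names here are invented):

class SenderMixin:
    # Annotating `self` tells the checker these methods only run on a
    # BulkWriter host, so `self._count` resolves even though the mixin
    # itself never defines it.
    def _send(self: "BulkWriter") -> int:
        self._count += 1
        return self._count

class BulkWriter(SenderMixin):
    def __init__(self) -> None:
        self._count = 0

print(BulkWriter()._send())   # 1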
@@ -180,10 +182,10 @@ def _process_response( operation.attempts += 1 self._retry_operation(operation) - def _retry_operation( - self, + def _retry_operation( # type: ignore + self: "BulkWriter", operation: "BulkWriterOperation", - ) -> concurrent.futures.Future: + ): delay: int = 0 if self._options.retry == BulkRetry.exponential: delay = operation.attempts**2 # pragma: NO COVER @@ -257,7 +259,7 @@ class BulkWriter(AsyncBulkWriterMixin): def __init__( self, - client: "BaseClient" = None, + client: Optional["BaseClient"] = None, options: Optional["BulkWriterOptions"] = None, ): # Because `BulkWriter` instances are all synchronous/blocking on the @@ -266,9 +268,10 @@ def __init__( # `BulkWriter` parallelizes all of its network I/O without the developer # having to worry about awaiting async methods, so we must convert an # AsyncClient instance into a plain Client instance. - self._client = ( - client._to_sync_copy() if type(client).__name__ == "AsyncClient" else client - ) + if type(client).__name__ == "AsyncClient": + self._client = client._to_sync_copy() # type: ignore + else: + self._client = client self._options = options or BulkWriterOptions() self._send_mode = self._options.mode @@ -284,9 +287,9 @@ def __init__( # the raw operation with the `datetime` of its next scheduled attempt. # `self._retries` must always remain sorted for efficient reads, so it is # required to only ever add elements via `bisect.insort`. - self._retries: collections.deque["OperationRetry"] = collections.deque([]) + self._retries: Deque["OperationRetry"] = collections.deque([]) - self._queued_batches = collections.deque([]) + self._queued_batches: Deque[List[BulkWriterOperation]] = collections.deque([]) self._is_open: bool = True # This list will go on to store the future returned from each submission @@ -441,7 +444,7 @@ def _enqueue_current_batch(self): # here we make sure that is running. self._ensure_sending() - def _send_until_queue_is_empty(self): + def _send_until_queue_is_empty(self) -> None: """First domino in the sending codepath. This does not need to be parallelized for two reasons: @@ -488,8 +491,9 @@ def _send_until_queue_is_empty(self): self._pending_batch_futures.append(future) self._schedule_ready_retries() + return None - def _schedule_ready_retries(self): + def _schedule_ready_retries(self) -> None: """Grabs all ready retries and re-queues them.""" # Because `self._retries` always exists in a sorted state (thanks to only @@ -503,6 +507,7 @@ def _schedule_ready_retries(self): for _ in range(take_until_index): retry: OperationRetry = self._retries.popleft() retry.retry(self) + return None def _request_send(self, batch_size: int) -> bool: # Set up this boolean to avoid repeatedly taking tokens if we're only @@ -519,8 +524,8 @@ def _request_send(self, batch_size: int) -> bool: ) # Ask for tokens each pass through this loop until they are granted, # and then stop. - have_received_tokens = ( - have_received_tokens or self._rate_limiter.take_tokens(batch_size) + have_received_tokens = have_received_tokens or bool( + self._rate_limiter.take_tokens(batch_size) ) if not under_threshold or not have_received_tokens: # Try again until both checks are true. 
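The switch to `self._retries: Deque["OperationRetry"]` and the other container annotations in this file follow from a basic inference limit: an empty literal like `collections.deque([])` or `{}` tells the checker nothing about element types, so they must be declared at the assignment. For example:

import collections
from typing import Deque, List

class WorkQueue:
    def __init__(self) -> None:
        # Empty literals carry no element type; the annotation supplies it.
        self.batches: Deque[List[str]] = collections.deque([])
        self.fields: dict = {}   # plain `dict` when values are heterogeneous

q = WorkQueue()
q.batches.append(["a", "b"])
print(q.batches.popleft())       # ['a', 'b']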
@@ -705,20 +710,24 @@ def update( def on_write_result( self, - callback: Callable[[BaseDocumentReference, WriteResult, "BulkWriter"], None], + callback: Optional[ + Callable[[BaseDocumentReference, WriteResult, "BulkWriter"], None] + ], ) -> None: """Sets a callback that will be invoked once for every successful operation.""" self._success_callback = callback or BulkWriter._default_on_success def on_batch_result( self, - callback: Callable[[BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None], + callback: Optional[ + Callable[[BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None] + ], ) -> None: """Sets a callback that will be invoked once for every successful batch.""" self._batch_callback = callback or BulkWriter._default_on_batch def on_write_error( - self, callback: Callable[["BulkWriteFailure", "BulkWriter"], bool] + self, callback: Optional[Callable[["BulkWriteFailure", "BulkWriter"], bool]] ) -> None: """Sets a callback that will be invoked once for every batch that contains an error.""" @@ -739,6 +748,9 @@ class BulkWriterOperation: similar writes to the same document. """ + def __init__(self, attempts: int = 0): + self.attempts = attempts + def add_to_batch(self, batch: BulkWriteBatch): """Adds `self` to the supplied batch.""" assert isinstance(batch, BulkWriteBatch) @@ -781,7 +793,7 @@ class BaseOperationRetry: Python 3.6 is dropped and `dataclasses` becomes universal. """ - def __lt__(self, other: "OperationRetry"): + def __lt__(self: "OperationRetry", other: "OperationRetry"): # type: ignore """Allows use of `bisect` to maintain a sorted list of `OperationRetry` instances, which in turn allows us to cheaply grab all that are ready to run.""" @@ -791,7 +803,7 @@ def __lt__(self, other: "OperationRetry"): return self.run_at < other return NotImplemented # pragma: NO COVER - def retry(self, bulk_writer: BulkWriter) -> None: + def retry(self: "OperationRetry", bulk_writer: BulkWriter) -> None: # type: ignore """Call this after waiting any necessary time to re-add the enclosed operation to the supplied BulkWriter's internal queue.""" if isinstance(self.operation, BulkWriterCreateOperation): diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 8bdaf7f81..23c6b36ef 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -23,6 +23,7 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from __future__ import annotations from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Union @@ -109,16 +110,6 @@ def _firestore_api(self): firestore_client, ) - @property - def _target(self): - """Return the target (where the API is). - Eg. "firestore.googleapis.com" - - Returns: - str: The location of the API. - """ - return self._target_helper(firestore_client.FirestoreClient) - def collection(self, *collection_path: str) -> CollectionReference: """Get a reference to a collection. @@ -210,10 +201,10 @@ def document(self, *document_path: str) -> DocumentReference: def get_all( self, references: list, - field_paths: Iterable[str] = None, - transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + field_paths: Iterable[str] | None = None, + transaction: Transaction | None = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a batch of documents. 
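Widening the three callback setters to `Optional[...]` matches their runtime contract: passing `None` restores the class default through the `callback or <default>` idiom. A self-contained illustration of that pattern (the `Emitter` class here is hypothetical):

from typing import Any, Callable, Optional


class Emitter:
    @staticmethod
    def _default_on_event(payload: Any) -> None:
        print("default handler:", payload)

    def __init__(self) -> None:
        self._callback: Callable[[Any], None] = Emitter._default_on_event

    def on_event(self, callback: Optional[Callable[[Any], None]]) -> None:
        # `callback or default` means None (or any falsy value) restores
        # the default instead of storing None and failing at call time.
        self._callback = callback or Emitter._default_on_event


emitter = Emitter()
emitter.on_event(None)                 # keeps the default handler
emitter.on_event(lambda p: print(p))   # installs a custom handler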
@@ -268,8 +259,8 @@ def get_all( def collections( self, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Generator[Any, Any, None]: """List top-level collections of the client's database. @@ -299,7 +290,7 @@ def recursive_delete( reference: Union[CollectionReference, DocumentReference], *, bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, + chunk_size: int = 5000, ) -> int: """Deletes documents and their subcollections, regardless of collection name. @@ -336,8 +327,8 @@ def _recursive_delete( reference: Union[CollectionReference, DocumentReference], bulk_writer: "BulkWriter", *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, + chunk_size: int = 5000, + depth: int = 0, ) -> int: """Recursion helper for `recursive_delete.""" diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 372dacd7b..cd6929b68 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -93,7 +93,7 @@ def add( self, document_data: dict, document_id: Union[str, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Union[float, None] = None, ) -> Tuple[Any, Any]: """Create a document in the Firestore database with the provided data. @@ -135,7 +135,7 @@ def add( def list_documents( self, page_size: Union[int, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Union[float, None] = None, ) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. @@ -170,7 +170,7 @@ def _chunkify(self, chunk_size: int): def get( self, transaction: Union[transaction.Transaction, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Union[float, None] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -210,7 +210,7 @@ def get( def stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 305d10df6..0c7d7872f 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -13,6 +13,7 @@ # limitations under the License. """Classes for representing documents for the Google Cloud Firestore API.""" +from __future__ import annotations import datetime import logging from typing import Any, Callable, Generator, Iterable @@ -65,8 +66,8 @@ def __init__(self, *path, **kwargs) -> None: def create( self, document_data: dict, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Create a document in the Firestore database. 
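The recurring `retry: retries.Retry | object | None = gapic_v1.method.DEFAULT` annotation is needed because `gapic_v1.method.DEFAULT` is a plain sentinel object, not a `Retry` instance, so a bare `retries.Retry` annotation rejects the default value. A sketch of how such a sentinel is typically filtered out before a call is made (the helper below is hypothetical, shown only to illustrate the sentinel handling):

from __future__ import annotations

from google.api_core import gapic_v1
from google.api_core import retry as retries


def build_call_kwargs(
    retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
    timeout: float | None = None,
) -> dict:
    kwargs: dict = {}
    if retry is not gapic_v1.method.DEFAULT:
        # An explicit retry (including None, which disables retries)
        # is forwarded; the sentinel is simply omitted.
        kwargs["retry"] = retry
    if timeout is not None:
        kwargs["timeout"] = timeout
    return kwargs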
@@ -102,8 +103,8 @@ def set( self, document_data: dict, merge: bool = False, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Create / replace / merge a document in the Firestore database. @@ -169,9 +170,9 @@ def set( def update( self, field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> write.WriteResult: """Update an existing document in the Firestore database. @@ -326,9 +327,9 @@ def update( def delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + option: _helpers.WriteOption | None = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Timestamp: """Delete the current document in the Firestore database. @@ -360,10 +361,10 @@ def delete( def get( self, - field_paths: Iterable[str] = None, + field_paths: Iterable[str] | None = None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -430,9 +431,9 @@ def get( def collections( self, - page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + page_size: int | None = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Generator[Any, Any, None]: """List subcollections of the current document. diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index df7d10a78..c3383cbb8 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -13,10 +13,10 @@ # limitations under the License. """Utilities for managing / converting field paths to / from strings.""" - +from __future__ import annotations import re from collections import abc -from typing import Iterable +from typing import Iterable, cast _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}" @@ -53,7 +53,7 @@ def _tokenize_field_path(path: str): get_token = TOKENS_REGEX.match match = get_token(path) while match is not None: - type_ = match.lastgroup + type_ = cast(str, match.lastgroup) value = match.group(type_) yield value pos = match.end() @@ -62,7 +62,7 @@ def _tokenize_field_path(path: str): raise ValueError("Path {} not consumed, residue: {}".format(path, path[pos:])) -def split_field_path(path: str): +def split_field_path(path: str | None): """Split a field path into valid elements (without dots). 
Args: diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 9395d05b9..08144577b 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -17,6 +17,7 @@ from typing import Any from google.cloud.firestore_v1._helpers import decode_value +from google.cloud.firestore_v1._helpers import GeoPoint class TypeOrder(Enum): @@ -150,6 +151,10 @@ def compare_timestamps(left, right) -> Any: def compare_geo_points(left, right) -> Any: left_value = decode_value(left, None) right_value = decode_value(right, None) + if not isinstance(left_value, GeoPoint) or not isinstance( + right_value, GeoPoint + ): + raise AttributeError("invalid geopoint encountered") cmp = (left_value.latitude > right_value.latitude) - ( left_value.latitude < right_value.latitude ) diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 3ae0c3d0b..0b52afc83 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -137,7 +137,7 @@ def __init__( def get( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -350,8 +350,8 @@ def avg( def _make_stream( self, transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, - timeout: Optional[float] = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, explain_options: Optional[ExplainOptions] = None, ) -> Generator[DocumentSnapshot, Any, Optional[ExplainMetrics]]: """Internal method for stream(). Read the documents in the collection @@ -443,9 +443,9 @@ def _make_stream( def stream( self, - transaction: Optional[transaction.Transaction] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, - timeout: Optional[float] = None, + transaction: transaction.Transaction | None = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, ) -> StreamGenerator[DocumentSnapshot]: @@ -578,8 +578,8 @@ def _get_query_class(): def get_partitions( self, partition_count, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, + timeout: float | None = None, ) -> Generator[QueryPartition, None, None]: """Partition a query for parallelization. 
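The isinstance guard added to `compare_geo_points` turns attribute errors on malformed decoded values into an explicit failure. For reference, here is the same comparison shown standalone (cmp-style result of -1, 0, or 1; latitude compared first, with longitude as the tiebreaker in the lines just past the hunk context):

from google.cloud.firestore_v1._helpers import GeoPoint


def compare_decoded_geo_points(left_value: object, right_value: object) -> int:
    if not isinstance(left_value, GeoPoint) or not isinstance(right_value, GeoPoint):
        raise AttributeError("invalid geopoint encountered")
    cmp = (left_value.latitude > right_value.latitude) - (
        left_value.latitude < right_value.latitude
    )
    if cmp != 0:
        return cmp
    return (left_value.longitude > right_value.longitude) - (
        left_value.longitude < right_value.longitude
    )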
diff --git a/google/cloud/firestore_v1/rate_limiter.py b/google/cloud/firestore_v1/rate_limiter.py
index 4cd06d866..ff362e195 100644
--- a/google/cloud/firestore_v1/rate_limiter.py
+++ b/google/cloud/firestore_v1/rate_limiter.py
@@ -14,7 +14,7 @@
 
 import datetime
 import warnings
-from typing import NoReturn, Optional
+from typing import Optional
 
 
 def utcnow():
@@ -110,7 +110,7 @@ def _start_clock(self):
         self._start = self._start or utcnow()
         self._last_refill = self._last_refill or utcnow()
 
-    def take_tokens(self, num: Optional[int] = 1, allow_less: bool = False) -> int:
+    def take_tokens(self, num: int = 1, allow_less: bool = False) -> int:
         """Returns the number of available tokens, up to the amount requested."""
         self._start_clock()
         self._check_phase()
@@ -125,7 +125,7 @@ def take_tokens(self, num: Optional[int] = 1, allow_less: bool = False) -> int:
             return _num_to_take
         return 0
 
-    def _check_phase(self):
+    def _check_phase(self) -> None:
         """Increments or decrements [_phase] depending on traffic.
 
         Every [_phase_length] seconds, if > 50% of available traffic was used
@@ -134,6 +134,8 @@ def _check_phase(self):
         This is a no-op unless a new [_phase_length] number of seconds since
         the start was crossed since it was last called.
         """
+        if self._start is None:
+            raise TypeError("RateLimiter error: unset _start value")
         age: datetime.timedelta = (
             datetime.datetime.now(datetime.timezone.utc) - self._start
         )
@@ -157,14 +159,16 @@
         if operations_last_phase and self._phase > previous_phase:
             self._increase_maximum_tokens()
 
-    def _increase_maximum_tokens(self) -> NoReturn:
+    def _increase_maximum_tokens(self) -> None:
         self._maximum_tokens = round(self._maximum_tokens * 1.5)
         if self._global_max_tokens is not None:
             self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens)
 
-    def _refill(self) -> NoReturn:
+    def _refill(self) -> None:
         """Replenishes any tokens that should have regenerated since the last
         operation."""
+        if self._last_refill is None:
+            raise TypeError("RateLimiter error: unset _last_refill value")
         now: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
         time_since_last_refill: datetime.timedelta = now - self._last_refill
 
diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py
index b18a71254..37afd5fb0 100644
--- a/google/cloud/firestore_v1/transaction.py
+++ b/google/cloud/firestore_v1/transaction.py
@@ -74,7 +74,7 @@ def _add_write_pbs(self, write_pbs: list) -> None:
 
         super(Transaction, self)._add_write_pbs(write_pbs)
 
-    def _begin(self, retry_id: bytes = None) -> None:
+    def _begin(self, retry_id: bytes | None = None) -> None:
         """Begin the transaction.
 
         Args:
@@ -152,8 +152,8 @@ def _commit(self) -> list:
     def get_all(
         self,
         references: list,
-        retry: retries.Retry = gapic_v1.method.DEFAULT,
-        timeout: float = None,
+        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
+        timeout: float | None = None,
     ) -> Generator[DocumentSnapshot, Any, None]:
         """Retrieves multiple documents from Firestore.
 
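A note on the `NoReturn` -> `None` corrections above: `typing.NoReturn` asserts that a callable never returns control (every path raises or loops forever), whereas `_increase_maximum_tokens` and `_refill` return normally after mutating state. Minimal contrast, for reference:

from typing import NoReturn


def always_raises() -> NoReturn:
    # A genuine NoReturn: every code path raises.
    raise RuntimeError("never returns normally")


def refill_like() -> None:
    # Returns normally, so None (not NoReturn) is the accurate annotation.
    return None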
@@ -175,7 +175,7 @@ def get_all( def get( self, ref_or_query: DocumentReference | Query, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/google/cloud/firestore_v1/vector.py b/google/cloud/firestore_v1/vector.py index 3349b57e1..4deebdd5b 100644 --- a/google/cloud/firestore_v1/vector.py +++ b/google/cloud/firestore_v1/vector.py @@ -14,7 +14,7 @@ # limitations under the License. import collections -from typing import Sequence, Tuple +from typing import Sequence class Vector(collections.abc.Sequence): @@ -23,18 +23,20 @@ class Vector(collections.abc.Sequence): Underlying object will be converted to a map representation in Firestore API. """ - _value: Tuple[float] = () + _value: Sequence[float] = () def __init__(self, value: Sequence[float]): self._value = tuple([float(v) for v in value]) - def __getitem__(self, arg: int): + def __getitem__(self, arg): return self._value[arg] def __len__(self): return len(self._value) def __eq__(self, other: object) -> bool: + if not isinstance(other, Vector): + return False return self._value == other._value def __repr__(self): diff --git a/google/cloud/firestore_v1/vector_query.py b/google/cloud/firestore_v1/vector_query.py index 9e2d4ad0f..77bf6dbdf 100644 --- a/google/cloud/firestore_v1/vector_query.py +++ b/google/cloud/firestore_v1/vector_query.py @@ -57,7 +57,7 @@ def __init__( def get( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, @@ -122,7 +122,7 @@ def _get_stream_iterator(self, transaction, retry, timeout, explain_options=None def _make_stream( self, transaction: Optional["transaction.Transaction"] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, ) -> Generator[DocumentSnapshot, Any, Optional[ExplainMetrics]]: @@ -192,7 +192,7 @@ def _make_stream( def stream( self, transaction: Optional["transaction.Transaction"] = None, - retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index e7eddcdea..856c771a1 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -99,7 +99,7 @@ def test_baseclient__firestore_api_helper_w_already(): def test_baseclient__firestore_api_helper_wo_emulator(): client = _make_default_base_client() client_options = client._client_options = mock.Mock() - target = client._target = mock.Mock() + target = client._target assert client._firestore_api_internal is None transport_class = mock.Mock() @@ -130,7 +130,6 @@ def test_baseclient__firestore_api_helper_w_emulator(): client = _make_default_base_client() client_options = client._client_options = mock.Mock() - target = client._target = mock.Mock() emulator_channel = client._emulator_channel = mock.Mock() assert client._firestore_api_internal is None @@ -145,7 +144,7 @@ def test_baseclient__firestore_api_helper_w_emulator(): emulator_channel.assert_called_once_with(transport_class) 
transport_class.assert_called_once_with( - host=target, + host=emulator_host, channel=emulator_channel.return_value, ) client_class.assert_called_once_with( diff --git a/tests/unit/v1/test_bulk_writer.py b/tests/unit/v1/test_bulk_writer.py index ac7d2e1da..17486600b 100644 --- a/tests/unit/v1/test_bulk_writer.py +++ b/tests/unit/v1/test_bulk_writer.py @@ -136,6 +136,15 @@ def test_basebulkwriter_ctor_explicit(self): options = BulkWriterOptions(retry=BulkRetry.immediate) self._basebulkwriter_ctor_helper(options=options) + def test_bulkwriteroperation_ctor(self): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOperation + + op = BulkWriterOperation() + assert op.attempts == 0 + attempts = 9 + op2 = BulkWriterOperation(attempts) + assert op2.attempts == attempts + def _doc_iter(self, client, num: int, ids: Optional[List[str]] = None): for _ in range(num): id: Optional[str] = ids[_] if ids else None diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index 8b723b14f..1942a5298 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -178,6 +178,20 @@ def test_order_compare_w_failure_to_find_type(): assert message.startswith("Unknown TypeOrder") +@pytest.mark.parametrize("invalid_point_is_left", [True, False]) +def test_order_compare_invalid_geo_points(invalid_point_is_left): + """ + comparing invalid geopoints should raise exception + """ + target = _make_order() + points = [_array_value(), _geoPoint_value(10, 10)] + if not invalid_point_is_left: + # reverse points + points = points[::-1] + with pytest.raises(AttributeError): + target.compare_geo_points(*points) + + def test_order_all_value_present(): from google.cloud.firestore_v1.order import _TYPE_ORDER_MAP, TypeOrder diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py index 3767108ae..1ed1d6053 100644 --- a/tests/unit/v1/test_rate_limiter.py +++ b/tests/unit/v1/test_rate_limiter.py @@ -233,3 +233,25 @@ def test_utcnow(): ): now = rate_limiter.utcnow() assert isinstance(now, datetime.datetime) + + +def test_rate_limiter_check_phase_error(): + """ + calling _check_phase with no _start time raises TypeError + """ + ramp = rate_limiter.RateLimiter( + global_max_tokens=499, + ) + with pytest.raises(TypeError): + ramp._check_phase() + + +def test_rate_limiter_refill_error(): + """ + calling _refill with no _last_refill raises TypeError + """ + ramp = rate_limiter.RateLimiter( + global_max_tokens=499, + ) + with pytest.raises(TypeError): + ramp._refill() diff --git a/tests/unit/v1/test_vector.py b/tests/unit/v1/test_vector.py index e411eac47..a28a05525 100644 --- a/tests/unit/v1/test_vector.py +++ b/tests/unit/v1/test_vector.py @@ -56,6 +56,13 @@ def test_compare_vector(): assert vector1 == vector2 +def test_compare_different_type(): + vector1 = Vector([1.0, 2.0, 3.0]) + vector2 = [1.0, 2.0, 3.0] + + assert vector1 != vector2 + + def test_vector_get_items(): vector = Vector([1.0, 2.0, 3.0])
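The `__eq__` guard exercised by `test_compare_different_type` means cross-type comparisons now return `False` instead of raising `AttributeError` on the missing `_value`. A quick usage check against the public `Vector` API:

from google.cloud.firestore_v1.vector import Vector

v = Vector([1, 2, 3])            # values are coerced to float
assert v == Vector([1.0, 2.0, 3.0])
assert v != [1.0, 2.0, 3.0]      # a plain sequence is never equal
assert (v == "not a vector") is False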