Commit

Merge branch 'main' into owl-bot-update-lock-5cddfe2fb5019bbf78335bc55f15bc13e18354a56b3ff46e1834f8e540807f05

daniel-sanche authored Dec 19, 2024
2 parents 3314c9a + a1596a3 commit 1117399
Showing 38 changed files with 436 additions and 338 deletions.
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/__init__.py
@@ -21,7 +21,7 @@

from google.cloud.firestore_v1 import gapic_version as package_version

__version__ = package_version.__version__
__version__: str = package_version.__version__

from typing import List

63 changes: 41 additions & 22 deletions google/cloud/firestore_v1/_helpers.py
@@ -13,7 +13,7 @@
# limitations under the License.

"""Common helpers shared across Google Cloud Firestore modules."""

from __future__ import annotations
import datetime
import json
from typing import (
@@ -22,14 +22,17 @@
Generator,
Iterator,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Union,
cast,
TYPE_CHECKING,
)

import grpc # type: ignore
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore
from google.protobuf import struct_pb2
@@ -44,6 +47,9 @@
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1.vector import Vector

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1 import DocumentSnapshot

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict

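Note on this hunk: the new `if TYPE_CHECKING:` block imports `DocumentSnapshot` for annotations only, so the runtime never executes an import that could be circular; `TYPE_CHECKING` is `False` outside of type checkers. A minimal sketch of the pattern (the `mypackage` module is made up for illustration):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # True only under mypy/pyright; skipped at runtime
    from mypackage.documents import DocumentSnapshot  # hypothetical module

def first_snapshot(snapshots: list[DocumentSnapshot]) -> DocumentSnapshot | None:
    # With PEP 563 deferred evaluation (the __future__ import above),
    # annotations stay unevaluated, so the guarded import is enough.
    return snapshots[0] if snapshots else None
```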
@@ -234,7 +240,9 @@ def encode_dict(values_dict) -> dict:
return {key: encode_value(value) for key, value in values_dict.items()}


def document_snapshot_to_protobuf(snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot") -> Optional["google.cloud.firestore_v1.types.Document"]: # type: ignore
def document_snapshot_to_protobuf(
snapshot: "DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:
from google.cloud.firestore_v1.types import Document

if not snapshot.exists:
@@ -405,7 +413,8 @@ def decode_dict(value_fields, client) -> Union[dict, Vector]:
if res.get("__type__", None) == "__vector__":
# Vector data type is represented as mapping.
# {"__type__":"__vector__", "value": [1.0, 2.0, 3.0]}.
return Vector(res["value"])
values = cast(Sequence[float], res["value"])
return Vector(values)

return res

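Note on this hunk: `typing.cast` does no runtime conversion or checking; it only re-types `res["value"]` (otherwise `Any`) as `Sequence[float]` for the checker before the `Vector` is built. A standalone sketch of the same move, using the `__vector__` sentinel mapping shown in the comment above:

```python
from typing import Any, Sequence, cast

def decode_vector(res: dict[str, Any]) -> tuple[float, ...]:
    if res.get("__type__") == "__vector__":
        # cast() is a no-op at runtime; it only narrows the static type
        values = cast(Sequence[float], res["value"])
        return tuple(values)
    raise ValueError("not a vector mapping")

print(decode_vector({"__type__": "__vector__", "value": [1.0, 2.0, 3.0]}))  # (1.0, 2.0, 3.0)
```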
@@ -504,7 +513,7 @@ def __init__(self, document_data) -> None:
self.increments = {}
self.minimums = {}
self.maximums = {}
self.set_fields = {}
self.set_fields: dict = {}
self.empty_document = False

prefix_path = FieldPath()
@@ -566,7 +575,9 @@ def transform_paths(self):
+ list(self.minimums)
)

def _get_update_mask(self, allow_empty_mask=False) -> None:
def _get_update_mask(
self, allow_empty_mask=False
) -> Optional[types.common.DocumentMask]:
return None

def get_update_pb(
@@ -730,9 +741,9 @@ class DocumentExtractorForMerge(DocumentExtractor):

def __init__(self, document_data) -> None:
super(DocumentExtractorForMerge, self).__init__(document_data)
self.data_merge = []
self.transform_merge = []
self.merge = []
self.data_merge: list = []
self.transform_merge: list = []
self.merge: list = []

def _apply_merge_all(self) -> None:
self.data_merge = sorted(self.field_paths + self.deleted_fields)
@@ -786,7 +797,7 @@ def _apply_merge_paths(self, merge) -> None:
self.data_merge.append(field_path)

# Clear out data for fields not merged.
merged_set_fields = {}
merged_set_fields: dict = {}
for field_path in self.data_merge:
value = get_field_value(self.document_data, field_path)
set_field_value(merged_set_fields, field_path, value)
@@ -1019,7 +1030,7 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
class WriteOption(object):
"""Option used to assert a condition on a write operation."""

def modify_write(self, write, no_create_msg=None) -> NoReturn:
def modify_write(self, write, no_create_msg=None) -> None:
"""Modify a ``Write`` protobuf based on the state of this write option.
This is a virtual method intended to be implemented by subclasses.
@@ -1059,7 +1070,7 @@ def __eq__(self, other):
return NotImplemented
return self._last_update_time == other._last_update_time

def modify_write(self, write, **unused_kwargs) -> None:
def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
"""Modify a ``Write`` protobuf based on the state of this write option.
The ``last_update_time`` is added to ``write_pb`` as an "update time"
@@ -1096,7 +1107,7 @@ def __eq__(self, other):
return NotImplemented
return self._exists == other._exists

def modify_write(self, write, **unused_kwargs) -> None:
def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
"""Modify a ``Write`` protobuf based on the state of this write option.
If:
@@ -1115,7 +1126,9 @@ def modify_write(self, write, **unused_kwargs) -> None:
write._pb.current_document.CopyFrom(current_doc._pb)

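Note on the `WriteOption` hunks above: the base class's virtual `modify_write` is now annotated `-> None` rather than `-> NoReturn` (overrides return normally, so `NoReturn` was wrong for the inherited contract), and the subclass overrides take `*unused_args` so a positional `no_create_msg` from the base signature no longer breaks them. A self-contained sketch of that compatibility concern, with simplified stand-in classes:

```python
class WriteOption:
    def modify_write(self, write, no_create_msg=None) -> None:
        # Virtual method: annotating None (not NoReturn) keeps normally
        # returning overrides type-compatible with the base signature.
        raise NotImplementedError

class ExistsOption(WriteOption):
    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        # *unused_args absorbs positional extras (e.g. no_create_msg)
        # passed by callers that only know the base signature.
        write["current_document"] = {"exists": True}

option: WriteOption = ExistsOption()
write: dict = {}
option.modify_write(write, "no_create_msg is ignored")
print(write)  # {'current_document': {'exists': True}}
```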

def make_retry_timeout_kwargs(retry, timeout) -> dict:
def make_retry_timeout_kwargs(
retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None
) -> dict:
"""Helper fo API methods which take optional 'retry' / 'timeout' args."""
kwargs = {}

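The diff shows only the new signature of `make_retry_timeout_kwargs`; for context, a sketch of the usual body behind such a helper (assumed, not shown in this commit): `gapic_v1.method.DEFAULT` is a sentinel meaning "use the method's default policy", while an explicit `None` disables retries.

```python
from __future__ import annotations

from google.api_core import gapic_v1
from google.api_core import retry as retries

def make_retry_timeout_kwargs(
    retry: retries.Retry | retries.AsyncRetry | object | None,
    timeout: float | None,
) -> dict:
    kwargs: dict = {}
    if retry is not gapic_v1.method.DEFAULT:
        kwargs["retry"] = retry  # forward an explicit policy, or an explicit None
    if timeout is not None:
        kwargs["timeout"] = timeout
    return kwargs
```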
@@ -1152,8 +1165,8 @@ def compare_timestamps(

def deserialize_bundle(
serialized: Union[str, bytes],
client: "google.cloud.firestore_v1.client.BaseClient", # type: ignore
) -> "google.cloud.firestore_bundle.FirestoreBundle": # type: ignore
client: "google.cloud.firestore_v1.client.BaseClient",
) -> "google.cloud.firestore_bundle.FirestoreBundle":
"""Inverse operation to a `FirestoreBundle` instance's `build()` method.
Args:
@@ -1211,7 +1224,7 @@ def deserialize_bundle(
# Create and add our BundleElement
bundle_element: BundleElement
try:
bundle_element: BundleElement = BundleElement.from_json(json.dumps(data)) # type: ignore
bundle_element = BundleElement.from_json(json.dumps(data))
except AttributeError as e:
# Some bad serialization formats cannot be universally deserialized.
if e.args[0] == "'dict' object has no attribute 'find'": # pragma: NO COVER
@@ -1235,18 +1248,22 @@

if "__end__" not in allowed_next_element_types:
raise ValueError("Unexpected end to serialized FirestoreBundle")

# state machine guarantees bundle and metadata have been populated
bundle = cast(FirestoreBundle, bundle)
metadata_bundle_element = cast(BundleElement, metadata_bundle_element)
# Now, finally add the metadata element
bundle._add_bundle_element(
metadata_bundle_element,
client=client,
type="metadata", # type: ignore
type="metadata",
)

return bundle

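Note on this hunk: these `cast` calls record an invariant the type checker cannot infer, namely that the element-parsing state machine has already guaranteed `bundle` and `metadata_bundle_element` are non-None by this point. A generic sketch of that narrowing pattern:

```python
from typing import Optional, cast

def first_header(lines: list[str]) -> str:
    header: Optional[str] = None
    for line in lines:
        if line.startswith("#"):
            header = line
            break
    # Suppose the format guarantees a header exists; the checker cannot
    # see that invariant, so cast() asserts it statically (no runtime check).
    return cast(str, header)

print(first_header(["# title", "body"]))  # "# title"
```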

def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict, None, None]: # type: ignore
def _parse_bundle_elements_data(
serialized: Union[str, bytes]
) -> Generator[Dict, None, None]:
"""Reads through a serialized FirestoreBundle and yields JSON chunks that
were created via `BundleElement.to_json(bundle_element)`.
@@ -1290,7 +1307,7 @@ def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict

def _get_documents_from_bundle(
bundle, *, query_name: Optional[str] = None
) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]: # type: ignore
) -> Generator["DocumentSnapshot", None, None]:
from google.cloud.firestore_bundle.bundle import _BundledDocument

bundled_doc: _BundledDocument
@@ -1304,7 +1321,9 @@ def _get_document_from_bundle(
bundle,
*,
document_id: str,
) -> Optional["google.cloud.firestore.DocumentSnapshot"]: # type: ignore
) -> Optional["DocumentSnapshot"]:
bundled_doc = bundle.documents.get(document_id)
if bundled_doc:
return bundled_doc.snapshot
else:
return None
12 changes: 3 additions & 9 deletions google/cloud/firestore_v1/aggregation.py
@@ -52,9 +52,7 @@ def __init__(
def get(
self,
transaction=None,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
*,
explain_options: Optional[ExplainOptions] = None,
@@ -131,9 +129,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
def _make_stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
explain_options: Optional[ExplainOptions] = None,
) -> Generator[List[AggregationResult], Any, Optional[ExplainMetrics]]:
@@ -206,9 +202,7 @@ def _make_stream(
def stream(
self,
transaction: Optional["transaction.Transaction"] = None,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
*,
explain_options: Optional[ExplainOptions] = None,
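A recurring change in this file (and in its async twin below): the public `retry` parameters are now typed `Union[retries.Retry, None, object]` instead of naming the private `gapic_v1.method._MethodDefault` class; `object` still admits the `DEFAULT` sentinel without leaking a private type into the signature. A runnable sketch of how the three accepted forms behave (the function is illustrative, not the library's):

```python
from __future__ import annotations

from typing import Union

from google.api_core import gapic_v1
from google.api_core import retry as retries

def get(retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT) -> str:
    if retry is gapic_v1.method.DEFAULT:
        return "method default retry policy"  # sentinel, compared by identity
    if retry is None:
        return "retries disabled"
    return f"custom policy: {type(retry).__name__}"

print(get())                                   # method default retry policy
print(get(retry=None))                         # retries disabled
print(get(retry=retries.Retry(timeout=60.0)))  # custom policy: Retry
```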
8 changes: 3 additions & 5 deletions google/cloud/firestore_v1/async_aggregation.py
@@ -51,9 +51,7 @@ def __init__(
async def get(
self,
transaction=None,
retry: Union[
retries.AsyncRetry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
retry: Union[retries.AsyncRetry, None, object] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
*,
explain_options: Optional[ExplainOptions] = None,
@@ -102,7 +100,7 @@ async def get(
async def _make_stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
explain_options: Optional[ExplainOptions] = None,
) -> AsyncGenerator[List[AggregationResult] | query_profile_pb.ExplainMetrics, Any]:
@@ -162,7 +160,7 @@ async def _make_stream(
def stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
*,
explain_options: Optional[ExplainOptions] = None,
6 changes: 3 additions & 3 deletions google/cloud/firestore_v1/async_batch.py
@@ -13,7 +13,7 @@
# limitations under the License.

"""Helpers for batch requests to the Google Cloud Firestore API."""

from __future__ import annotations

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
@@ -38,8 +38,8 @@ def __init__(self, client) -> None:

async def commit(
self,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> list:
"""Commit the changes accumulated in this batch.
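The `timeout: float = None` to `timeout: float | None = None` change here (and the matching ones in `async_client.py` below) fixes implicit-Optional annotations: under PEP 484 as enforced by current mypy, a `None` default no longer implies `Optional`. A two-function sketch:

```python
from __future__ import annotations

def commit_old(timeout: float = None) -> list:  # mypy error: implicit Optional
    return []

def commit_new(timeout: float | None = None) -> list:  # explicit, checker-clean
    return []
```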
21 changes: 11 additions & 10 deletions google/cloud/firestore_v1/async_client.py
@@ -23,6 +23,7 @@
* a :class:`~google.cloud.firestore_v1.client.Client` owns a
:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union

@@ -222,10 +223,10 @@ def document(self, *document_path: str) -> AsyncDocumentReference:
async def get_all(
self,
references: List[AsyncDocumentReference],
field_paths: Iterable[str] = None,
transaction=None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
field_paths: Iterable[str] | None = None,
transaction: AsyncTransaction | None = None,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> AsyncGenerator[DocumentSnapshot, Any]:
"""Retrieve a batch of documents.
@@ -280,8 +281,8 @@ async def get_all(

async def collections(
self,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> AsyncGenerator[AsyncCollectionReference, Any]:
"""List top-level collections of the client's database.
@@ -310,8 +311,8 @@ async def recursive_delete(
reference: Union[AsyncCollectionReference, AsyncDocumentReference],
*,
bulk_writer: Optional["BulkWriter"] = None,
chunk_size: Optional[int] = 5000,
):
chunk_size: int = 5000,
) -> int:
"""Deletes documents and their subcollections, regardless of collection
name.
@@ -346,8 +347,8 @@ async def _recursive_delete(
reference: Union[AsyncCollectionReference, AsyncDocumentReference],
bulk_writer: "BulkWriter",
*,
chunk_size: Optional[int] = 5000,
depth: Optional[int] = 0,
chunk_size: int = 5000,
depth: int = 0,
) -> int:
"""Recursion helper for `recursive_delete."""

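Note on the `recursive_delete` hunks: `chunk_size` was never allowed to be `None`, so `Optional[int] = 5000` overstated the accepted values; `Optional` means "may be None", not "has a default". The tightened `int`, plus the new `-> int` return type, make the contract explicit, roughly:

```python
# Illustrative signature only; the real method walks and deletes documents.
async def recursive_delete(reference, *, chunk_size: int = 5000) -> int:
    deleted = 0
    # ... issue deletes in batches of `chunk_size`, accumulating `deleted` ...
    return deleted
```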