diff --git a/.coveragerc b/.coveragerc index 1ee3561b23..659df7b6e9 100644 --- a/.coveragerc +++ b/.coveragerc @@ -29,6 +29,9 @@ exclude_lines = @overload + # Those are not supposed to be hit + assert_never\(\w+\) + ignore_errors = True omit = diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..aa7840e194 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +Initial relay spec implementation. For information on how to use +it, check out the docs in here: https://strawberry.rocks/docs/guides/relay diff --git a/docs/README.md b/docs/README.md index 75a6e7b34d..c3af64722c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -44,6 +44,7 @@ - [Dealing with errors](./guides/errors.md) - [Federation](./guides/federation.md) - [Federation V1](./guides/federation-v1.md) +- [Relay](./guides/relay.md) - [Custom extensions](./guides/custom-extensions.md) - [File upload](./guides/file-upload.md) - [Pagination](./guides/pagination/overview.md) diff --git a/docs/errors/relay-wrong-annotation.md b/docs/errors/relay-wrong-annotation.md new file mode 100644 index 0000000000..0722423043 --- /dev/null +++ b/docs/errors/relay-wrong-annotation.md @@ -0,0 +1,70 @@ +--- +title: Relay wrong annotation Error +--- + +# Relay wrong annotation error + +## Description + +This error is thrown when a field on a relay connection has a wrong +type annotation. For example, the following code will throw this error: + +```python +from typing import List + +import strawberry + + +@strawberry.type +class NonNodeSubclassType: + ... + + +@strawberry.type +class Query: + @strawberry.relay.connection + def some_connection(self) -> int: + ... + + @strawberry.relay.connection + def some_other_connection(self) -> List[NonNodeSubclassType]: + ... +``` + +This happens because when defining a custom resolver for the connection, +it expects the type annotation to be one of: `Iterable[]`, +`Iterator[]`, `AsyncIterable[]` or `AsyncIterator[]` +- `Iterator[]` +- `AsyncIterable[]` +- `AsyncIterator[ List[NodeSubclassType]: + ... + + @strawberry.relay.connection + def some_other_connection(self) -> Iterable[NodeSubclassType]: + ... +``` diff --git a/docs/errors/relay-wrong-node-resolver-annotation.md b/docs/errors/relay-wrong-node-resolver-annotation.md new file mode 100644 index 0000000000..62b6e6a930 --- /dev/null +++ b/docs/errors/relay-wrong-node-resolver-annotation.md @@ -0,0 +1,66 @@ +--- +title: Relay wrong annotation Error +--- + +# Relay wrong annotation error + +## Description + +This error is thrown when a field on a relay connection was defined with +a `node_converter` that has the wrong type type annotation. For example, +the following code will throw this error: + +```python +from typing import Iterable + +import strawberry + + +@strawberry.type +class NonNodeSubclassType: + ... + + +def node_converter(node_id: str) -> NonNodeSubclassType: + ... + + +@strawberry.type +class Query: + @strawberry.relay.connection(NonNodeSubclassType) + def some_connection(self) -> Iterable[str]: + ... +``` + +This happens because when defining a `node_converter`, it is expected +to be a function that receives the iterable element as its single argument, +and should return the correct strawberry `Node` implemented type. + +## How to fix this error + +You can fix this error by annotating the `node_converter` function to +return the correct strawberry `Node` implemented type. + +For example: + +```python +from typing import Iterable + +import strawberry + + +@strawberry.type +class NodeSubclassType(strawberry.relay.Node): + ... + + +def node_converter(node_id: str) -> NodeSubclassType: + ... + + +@strawberry.type +class Query: + @strawberry.relay.connection(node_converter=node_converter) + def some_connection(self) -> Iterable[str]: + ... +``` diff --git a/docs/guides/pagination/overview.md b/docs/guides/pagination/overview.md index 0d29c67986..a7c8289eab 100644 --- a/docs/guides/pagination/overview.md +++ b/docs/guides/pagination/overview.md @@ -117,6 +117,14 @@ need a reliable and consistent way to handle pagination. ### Cursor based pagination + + +Strawberry provides a cursor based pagination implementing the +[relay spec](https://relay.dev/docs/guides/graphql-server-specification/). +You can read more about it in the [relay](./input-types) page. + + + Cursor based pagination, also known as keyset pagination, works by returning a pointer to a specific item in the dataset. On subsequent requests, the server returns results after the given pointer. This method addresses the drawbacks of using offset pagination, but does so by making certain trade offs: diff --git a/docs/guides/relay.md b/docs/guides/relay.md new file mode 100644 index 0000000000..613e246d65 --- /dev/null +++ b/docs/guides/relay.md @@ -0,0 +1,394 @@ +--- +title: Relay +--- + +# Relay Guide + +## What is Relay? + +The relay spec defines some interfaces that GraphQL servers can follow to allow +clients to interact with them in a more efficient way. The spec makes two +core assumptions about a GraphQL server: + +1. It provides a mechanism for refetching an object +2. It provides a description of how to page through connections. + +You can read more about the relay spec +[here](https://relay.dev/docs/en/graphql-server-specification/). + +### Relay implementation example + +Suppose we have the following type: + +```python +@strawberry.type +class Fruit: + name: str + weight: str +``` + +We want it to have a globally unique ID, a way to retrieve a paginated results +list of it and a way to refetch if if necessary. For that, we need to inherit it +from the `Node` interface and implement its abstract methods: `resolve_id`, +and `resolve_nodes`. + +```python +from strawberry import relay + + +@strawberry.type +class Fruit(relay.Node): + code: relay.NodeID[str] + name: str + weight: float + + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Optional[Iterable[str]] = None, + required: bool = False, + ): + if node_ids is not None: + return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] + + return fruits.values() + + +# Assume we have a dict mapping the fruits code to the Fruit object itself +fruits: Dict[int, Fruit] +``` + +With that, our `Fruit` type know knows how to retrieve a `Fruit` instance given its +`id`, and also how to retrieve that `id`. + +Now we can expose it in the schema for retrieval and pagination like: + +```python +@strawberry.type +class Query: + node: relay.Node + fruits: relay.Connection[Fruit] +``` + +This will generate a schema like this: + +```graphql +scalar GlobalID + +interface Node { + id: GlobalID! +} + +type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String +} + +type Fruit implements Node { + id: GlobalID! + name: String! + weight: Float! +} + +type FruitEdge { + cursor: String! + node: Fruit! +} + +type FruitConnection { + pageInfo: PageInfo! + edges: [FruitEdge!]! + totalCount: Int +} + +type Query { + node(id: GlobalID!): Node! + fruits( + before: String = null + after: String = null + first: Int = null + last: Int = null + ): FruitConnection! +} +``` + +With only that we have a way to query `node` to retrieve any `Node` implemented +type in our schema (which includes our `Fruit` type), and also a way to retrieve +a list of fruits with pagination. + +For example, to retrieve a single fruit given its unique ID: + +```graphql +query { + node(id: "") { + id + ... on Fruit { + name + weight + } + } +} +``` + +Or to retrieve the first 10 fruits available: + +```graphql +query { + fruitConnection(first: 10) { + pageInfo { + firstCursor + endCursor + hasNextPage + hasPreviousPage + } + edges { + # node here is the Fruit type + node { + id + name + weight + } + } + } +} +``` + +### The node field + +As demonstrated above, the `Node` field can be used to retrieve/refetch any +object in the schema that implements the `Node` interface. + +It can be defined in in the `Query` objects in 4 ways: + +- `node: Node`: This will define a field that accepts a `GlobalID!` and returns + a `Node` instance. This is the most basic way to define it. +- `node: Optional[Node]`: The same as `Node`, but if the given object doesn't + exist, it will return `null`. +- `node: List[Node]`: This will define a field that accepts `[GlobalID!]!` and + returns a list of `Node` instances. They can even be from different types. +- `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list + can contain `null` values if the given objects don't exist. + +### Custom connection pagination + +The default `Connection` implementation uses a limit/offset approach to paginate +the results. This is a basic approach and might be enough for most use cases. + + + +`Connection` implementes the limit/offset by using slices. That means that you can +override what the slice does by customizing the `__getitem__` method of the object +returned by `resolve_nodes`. + +For example, when working with `Django`, `resolve_nodes` can return a `QuerySet`, +meaning that the slice on it will translate to a `LIMIT`/`OFFSET` in the SQL +query, which will fetch only the data that is needed from the database. + +Also note that if that object doesn't have a `__getitem__` attribute, it will +use `itertools.islice` to paginate it, meaning that when a generator is being +resolved it will only generate as much results as needed for the given pagination, +the worst case scenario being the last results needing to be returned. + + + +You may want to use a different approach to paginate your results. For example, +a cursor-based approach. For that you need to subclass the `Connection` type +and implement your own `from_nodes` method. For example, suppose that in our +exaple above, we want to use the fruit's weight as the cursor, we can implement +it like that: + +```python +from strawberry import relay + + +@strawberry.type +class FruitCustomPaginationConnection(relay.Connection[Fruit]): + @classmethod + def from_nodes( + cls, + nodes: Iterable[Fruit], + *, + info: Optional[Info] = None, + total_count: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + ): + # Note that this is a showcase implementation and is far from + # being optimal performance wise + edges_mapping = { + relay.to_base64("fruit_name", n.name): strawberry.relay.Edge( + node=n, + cursor=relay.to_base64("fruit_name", n.name), + ) + for n in sorted(nodes, key=lambda f: f.name) + } + edges = list(edges_mapping.values()) + first_edge = edges[0] if edges else None + last_edge = edges[-1] if edges else None + + if after is not None: + after_edge_idx = edges.index(edges_mapping[after]) + edges = [e for e in edges if edges.index(e) > after_edge_idx] + + if before is not None: + before_edge_idx = edges.index(edges_mapping[before]) + edges = [e for e in edges if edges.index(e) < before_edge_idx] + + if first is not None: + edges = edges[:first] + + if last is not None: + edges = edges[-last:] + + return cls( + edges=edges, + page_info=strawberry.relay.PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=( + first_edge is not None and bool(edges) and edges[0] != first_edge + ), + has_next_page=( + last_edge is not None and bool(edges) and edges[-1] != last_edge + ), + ), + ) + + +@strawberry.type +class Query: + fruits: FruitCustomPaginationConnection +``` + + + +In the example above we specialized the `FruitCustomPaginationConnection` by +inheriting it from `relay.Connection[Fruit]`. We could still keep it generic by +inheriting it from `relay.Connection[relay.NodeType]` and then specialize it +when defining the field. + + + +### Custom connection resolver + +We can define custom resolvers for the `Connection` as a way to pre-filter +the results. All that needs to be done is to decorate the resolver with +`@strawberry.relay.connection` and return an `Iterator`/`AsyncIterator` of that +given `Node` type in it. For example, suppose we want to return the pagination +of all fruits whose name starts with a given string: + +```python +@strawberry.type +class Query: + @relay.connection + def fruits_with_filter( + self, + info: Info, + name_endswith: str, + ) -> Iterable[Fruit]: + for f in fruits.values(): + if f.name.endswith(name_endswith): + yield f +``` + +This will generate a schema like this: + +```graphql +type Query { + fruitsWithFilter( + nameEndswith: String! + before: String = null + after: String = null + first: Int = null + last: Int = null + ): FruitConnection! +} +``` + +The custom resolver can to be annotated with any of the following: + +- `List[]` +- `Iterator[]` +- `Iterable[]` +- `AsyncIterator[]` +- `AsyncIterable[]` +- `Generator[, Any, Any]` +- `AsyncGenerator[, Any]` + + + +If your custom resolver returns something different than the expected type +(e.g. a django model, and you are not using the django integration), you can pass +a `node_converter` function to the `Connection` to convert it properly, like: + +```python +def fruit_converter(model: models.Fruit) -> Fruit: + return Fruit(id=model.pk, name=model.name, weight=model.weight) + + +@strawberry.type +class Query: + @relay.connection(node_converter=fruit_converter) + def fruits_with_filter( + self, + info: Info, + name_endswith: str, + ) -> Iterable[Fruit]: + return models.Fruit.objects.filter(name__endswith=name_endswith) +``` + +The main advantage of this approach instead of converting it inside the custom +resolver is that the `Connection` will paginate the `QuerySet` first, which in +case of django will make sure that only the paginated results are fetched from the +database. After that, the `fruit_converter` function will be called for each result +to retrieve the correct object for it. + +We used django for this example, but the same applies to any other other +similar use case, like SQLAlchemy, etc. + + + +### The GlobalID scalar + +The `GlobalID` scalar is a special object that contains all the info necessary to +identify and retrieve a given object that implements the `Node` interface. + +It can for example be useful in a mutation, to receive and object and retrieve +it in its resolver. For example: + +```python +@strawberry.type +class Mutation: + @strawberry.mutation + def update_fruit_weight( + self, + info: Info, + id: relay.GlobalID, + weight: float, + ) -> Fruit: + # resolve_node will return the Fruit object + fruit = id.resolve_node(info, ensure_type=Fruit) + fruit.weight = weight + return fruit + + @strawberry.mutation + async def update_fruit_weight_async( + self, + info: Info, + id: relay.GlobalID, + weight: float, + ) -> Fruit: + # aresolve_node will return an awaitable that returns the Fruit object + fruit = await id.aresolve_node(info, ensure_type=Fruit) + fruit.weight = weight + return fruit +``` + +In the example above, you can also access the type name directly with `id.type_name`, +the raw node ID with `id.id`, or even resolve the type itself with `id.resolve_type(info)`. diff --git a/pyproject.toml b/pyproject.toml index f88be99f47..7a7e5fdec3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ tests/codegen/snapshots/ addopts = "-s --emoji --mypy-ini-file=mypy.ini --benchmark-disable" DJANGO_SETTINGS_MODULE = "tests.django.django_settings" testpaths = ["tests/"] -markers = ["django", "starlette", "channels"] +markers = ["django", "starlette", "channels", "relay"] asyncio_mode = "auto" filterwarnings = [ "ignore::DeprecationWarning:strawberry.*.resolver", diff --git a/strawberry/__init__.py b/strawberry/__init__.py index 4ab62fa742..0a6cd0d256 100644 --- a/strawberry/__init__.py +++ b/strawberry/__init__.py @@ -1,4 +1,4 @@ -from . import experimental, federation +from . import experimental, federation, relay from .arguments import argument from .auto import auto from .custom_scalar import scalar @@ -42,4 +42,5 @@ "union", "auto", "asdict", + "relay", ] diff --git a/strawberry/field.py b/strawberry/field.py index a9801703d2..c532256a35 100644 --- a/strawberry/field.py +++ b/strawberry/field.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: import builtins - from typing_extensions import Literal + from typing_extensions import Literal, Self from strawberry.arguments import StrawberryArgument from strawberry.extensions.field_extension import FieldExtension @@ -129,7 +129,7 @@ def __init__( try: self.default_value = default_factory() except TypeError as exc: - raise InvalidDefaultFactoryError() from exc + raise InvalidDefaultFactoryError from exc self.is_subscription = is_subscription @@ -297,7 +297,7 @@ def type_params(self) -> List[TypeVar]: def copy_with( self, type_var_map: Mapping[TypeVar, Union[StrawberryType, builtins.type]] - ) -> StrawberryField: + ) -> Self: new_type: Union[StrawberryType, type] = self.type # TODO: Remove with creation of StrawberryObject. Will act same as other @@ -317,7 +317,7 @@ def copy_with( else None ) - return StrawberryField( + return type(self)( python_name=self.python_name, graphql_name=self.graphql_name, # TODO: do we need to wrap this in `StrawberryAnnotation`? diff --git a/strawberry/object_type.py b/strawberry/object_type.py index 262a8e3a91..9b10a27d61 100644 --- a/strawberry/object_type.py +++ b/strawberry/object_type.py @@ -3,6 +3,7 @@ import sys import types from typing import ( + TYPE_CHECKING, Callable, Dict, List, @@ -180,9 +181,25 @@ def _process_type( return cls +if TYPE_CHECKING: + # Avoid circular import issues + from .relay import ConnectionField, NodeField, connection, node + + field_descriptors = ( + field, + node, + connection, + StrawberryField, + NodeField, + ConnectionField, + ) +else: + field_descriptors = (field, StrawberryField) + + @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def type( cls: T, @@ -199,7 +216,7 @@ def type( @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def type( *, @@ -261,7 +278,7 @@ def wrap(cls): @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def input( cls: T, @@ -275,7 +292,7 @@ def input( @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def input( *, @@ -311,7 +328,7 @@ def input( @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def interface( cls: T, @@ -325,7 +342,7 @@ def interface( @overload @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def interface( *, @@ -337,7 +354,7 @@ def interface( @__dataclass_transform__( - order_default=True, kw_only_default=True, field_descriptors=(field, StrawberryField) + order_default=True, kw_only_default=True, field_descriptors=field_descriptors ) def interface( cls: Optional[T] = None, diff --git a/strawberry/relay/__init__.py b/strawberry/relay/__init__.py new file mode 100644 index 0000000000..77668820b0 --- /dev/null +++ b/strawberry/relay/__init__.py @@ -0,0 +1,30 @@ +from .fields import ConnectionField, NodeField, RelayField, connection, node +from .types import ( + Connection, + Edge, + GlobalID, + GlobalIDValueError, + Node, + NodeID, + NodeType, + PageInfo, +) +from .utils import from_base64, to_base64 + +__all__ = [ + "Connection", + "ConnectionField", + "Edge", + "GlobalID", + "GlobalIDValueError", + "Node", + "NodeField", + "NodeID", + "NodeType", + "PageInfo", + "RelayField", + "connection", + "from_base64", + "node", + "to_base64", +] diff --git a/strawberry/relay/exceptions.py b/strawberry/relay/exceptions.py new file mode 100644 index 0000000000..81ef548046 --- /dev/null +++ b/strawberry/relay/exceptions.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional, cast + +from strawberry.exceptions.exception import StrawberryException +from strawberry.exceptions.utils.source_finder import SourceFinder +from strawberry.utils.cached_property import cached_property + +if TYPE_CHECKING: + from strawberry.exceptions.exception_source import ExceptionSource + from strawberry.types.fields.resolver import StrawberryResolver + + +class RelayWrongAnnotationError(StrawberryException): + def __init__(self, field_name: str, resolver: StrawberryResolver): + self.resolver = resolver.wrapped_func + self.field_name = field_name + + self.message = ( + f'Unable to determine the connection type of field "{field_name}". ' + "It should be annotated with a return value of `List[]`, " + "`Iterable[]`, `Iterator[]`, " + "`AsyncIterable[]` or `AsyncIterator[]`" + ) + self.rich_message = ( + f"Wrong annotation for field `[underline]{self.field_name}[/]`" + ) + self.suggestion = ( + "To fix this error you can annotate the return it using " + "a return value of `List[]`, " + "`Iterable[]`, `Iterator[]`, " + "`AsyncIterable[]` or `AsyncIterator[]`" + ) + self.annotation_message = "relay custom resolver wrong annotation" + + super().__init__(self.message) + + @cached_property + def exception_source(self) -> Optional[ExceptionSource]: + if self.resolver is None: + return None # pragma: no cover + + source_finder = SourceFinder() + return source_finder.find_function_from_object(cast(Callable, self.resolver)) + + +class RelayWrongNodeResolverAnnotationError(StrawberryException): + def __init__(self, field_name: str, resolver: StrawberryResolver): + self.resolver = resolver.wrapped_func + self.field_name = field_name + + self.message = ( + f'Unable to determine the connection type of field "{field_name}". ' + "The `node_resolver` function should be annotated with a return value " + "of ``" + ) + self.rich_message = ( + "Wrong annotation for field `node_resolver` function used " + "in the `@relay.connection` decorator of field " + "[underline]{self.field_name}[/]`" + ) + self.suggestion = ( + "To fix this error you can annotate the `node_resolver` function " + "using a return value of ``" + ) + self.annotation_message = "relay node_resolver wrong annotation" + + super().__init__(self.message) + + @cached_property + def exception_source(self) -> Optional[ExceptionSource]: + if self.resolver is None: + return None # pragma: no cover + + source_finder = SourceFinder() + return source_finder.find_function_from_object(cast(Callable, self.resolver)) diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py new file mode 100644 index 0000000000..bc15937731 --- /dev/null +++ b/strawberry/relay/fields.py @@ -0,0 +1,674 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import inspect +import sys +from collections import defaultdict +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from typing_extensions import Literal, Self, get_args, get_origin, get_type_hints + +from strawberry.annotation import StrawberryAnnotation +from strawberry.arguments import StrawberryArgument +from strawberry.exceptions.missing_return_annotation import MissingReturnAnnotationError +from strawberry.field import _RESOLVER_TYPE, StrawberryField +from strawberry.lazy_type import LazyType +from strawberry.scalars import ID +from strawberry.type import StrawberryList, StrawberryOptional, StrawberryType +from strawberry.types.fields.resolver import StrawberryResolver +from strawberry.types.types import TypeDefinition +from strawberry.utils.aio import asyncgen_to_list, resolve_awaitable +from strawberry.utils.cached_property import cached_property + +from .exceptions import ( + RelayWrongAnnotationError, + RelayWrongNodeResolverAnnotationError, +) +from .types import Connection, GlobalID, Node, NodeIterableType, NodeType + +if TYPE_CHECKING: + from strawberry.permission import BasePermission + from strawberry.types.info import Info + from strawberry.utils.await_maybe import AwaitableOrValue + +_T = TypeVar("_T") + + +class RelayField(StrawberryField): + """Base relay field, containing utilities for both Node and Connection fields.""" + + default_args: Dict[str, StrawberryArgument] + + def __init__( + self, + *args, + node_converter: Optional[Callable[[object], Node]] = None, + **kwargs, + ): + self.node_converter = node_converter + super().__init__(*args, **kwargs) + + @property + def arguments(self) -> List[StrawberryArgument]: + args = { + **self.default_args, + **{arg.python_name: arg for arg in super().arguments}, + } + return list(args.values()) + + @cached_property + def is_basic_field(self): + return False + + @cached_property + def is_optional(self): + type_ = self.type + if isinstance(type_, StrawberryList): + type_ = type_.of_type + + return isinstance(type_, StrawberryOptional) + + @cached_property + def is_list(self): + type_ = self.type + if isinstance(type_, StrawberryOptional): + type_ = type_.of_type + + return isinstance(type_, StrawberryList) + + def copy_with( + self, + type_var_map: Mapping[TypeVar, Union[StrawberryType, type]], + ) -> Self: + retval = super().copy_with(type_var_map) + retval.default_args = self.default_args + return retval + + +class NodeField(RelayField): + """Relay Node field. + + This field is used to fetch a single object by its ID or multiple + objects given a list of IDs. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not self.base_resolver and self.is_list: + self.default_args = { + "ids": StrawberryArgument( + python_name="ids", + graphql_name=None, + type_annotation=StrawberryAnnotation(List[ID]), + description="The IDs of the objects.", + ), + } + elif not self.base_resolver: + self.default_args = { + "id": StrawberryArgument( + python_name="id", + graphql_name=None, + type_annotation=StrawberryAnnotation(ID), + description="The ID of the object.", + ), + } + + def __call__(self, resolver): + raise NotImplementedError + + def get_result( + self, + source: Any, + info: Optional[Info], + args: List[Any], + kwargs: Dict[str, Any], + ) -> Union[Awaitable[Any], Any]: + assert info is not None + resolver = self.resolve_nodes if self.is_list else self.resolve_node + + return resolver(source, info, args, kwargs) + + def resolve_node( + self, + source: Any, + info: Info, + args: List[Any], + kwargs: Dict[str, Any], + ) -> AwaitableOrValue[Optional[Node]]: + gid = GlobalID.from_id(kwargs["id"]) + return gid.resolve_type(info).resolve_node( + gid.node_id, + info=info, + required=not self.is_optional, + ) + + def resolve_nodes( + self, + source: Any, + info: Info, + args: List[Any], + kwargs: Dict[str, Any], + ) -> AwaitableOrValue[List[Node]]: + gids: List[GlobalID] = [GlobalID.from_id(id) for id in kwargs["ids"]] + + nodes_map: DefaultDict[Type[Node], List[str]] = defaultdict(list) + # Store the index of the node in the list of nodes of the same type + # so that we can return them in the same order while also supporting different + # types + index_map: Dict[GlobalID, Tuple[Type[Node], int]] = {} + for gid in gids: + node_t = gid.resolve_type(info) + nodes_map[node_t].append(gid.node_id) + index_map[gid] = (node_t, len(nodes_map[node_t]) - 1) + + resolved_nodes = { + node_t: node_t.resolve_nodes( + info=info, + node_ids=node_ids, + required=not self.is_optional, + ) + for node_t, node_ids in nodes_map.items() + } + awaitable_nodes = { + node_t: nodes + for node_t, nodes in resolved_nodes.items() + if inspect.isawaitable(nodes) + } + # Async generators are not awaitable, so we need to handle them separately + asyncgen_nodes = { + node_t: nodes + for node_t, nodes in resolved_nodes.items() + if inspect.isasyncgen(nodes) + } + + if awaitable_nodes or asyncgen_nodes: + + async def resolve(resolved=resolved_nodes): + resolved.update( + zip( + [ + *awaitable_nodes.keys(), + *asyncgen_nodes.keys(), + ], + # Resolve all awaitable nodes concurrently + await asyncio.gather( + *awaitable_nodes.values(), + *( + asyncgen_to_list(nodes) + for nodes in asyncgen_nodes.values() + ), + ), + ) + ) + + # Resolve any generator to lists + resolved = {node_t: list(nodes) for node_t, nodes in resolved.items()} + return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in gids] + + return resolve() + + # Resolve any generator to lists + resolved = { + node_t: list(cast(Iterator[Node], nodes)) + for node_t, nodes in resolved_nodes.items() + } + return [resolved[index_map[gid][0]][index_map[gid][1]] for gid in gids] + + +class ConnectionField(RelayField): + """Relay Connection field. + + Do not instantiate this directly. Instead, use `@relay.connection` + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.default_args = { + "before": StrawberryArgument( + python_name="before", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[str]), + description=( + "Returns the items in the list that come before the " + "specified cursor." + ), + default=None, + ), + "after": StrawberryArgument( + python_name="after", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[str]), + description=( + "Returns the items in the list that come after the " + "specified cursor." + ), + default=None, + ), + "first": StrawberryArgument( + python_name="first", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[int]), + description="Returns the first n items from the list.", + default=None, + ), + "last": StrawberryArgument( + python_name="last", + graphql_name=None, + type_annotation=StrawberryAnnotation(Optional[int]), + description=( + "Returns the items in the list that come after the " + "specified cursor." + ), + default=None, + ), + } + + def __call__(self, resolver: _RESOLVER_TYPE): + retval = super().__call__(resolver) + assert self.base_resolver + + field_name = cast(Callable, self.base_resolver.wrapped_func).__name__ + namespace = sys.modules[resolver.__module__].__dict__ + resolved = get_type_hints(cast(Type, resolver), namespace).get("return") + if resolved is None: + raise MissingReturnAnnotationError( + field_name, + resolver=StrawberryResolver(resolver), + ) + + origin = get_origin(resolved) + + is_connection = ( + origin and isinstance(origin, type) and issubclass(origin, Connection) + ) + is_iterable = ( + origin + and isinstance(origin, type) + and issubclass(origin, (Iterator, AsyncIterator, Iterable, AsyncIterable)) + ) + if not is_connection and not is_iterable: + raise RelayWrongAnnotationError( + field_name=field_name, + resolver=StrawberryResolver(resolver), + ) + + if is_iterable and not is_connection and self.type_annotation is None: + if self.node_converter is not None: + ntype = get_type_hints(self.node_converter).get("return") + if ntype is None: + raise MissingReturnAnnotationError( + field_name, + resolver=self.base_resolver, + ) + if not isinstance(ntype, type) or not issubclass(ntype, Node): + raise RelayWrongNodeResolverAnnotationError( + field_name, + resolver=self.base_resolver, + ) + else: + ntype = get_args(resolved)[0] + if not issubclass(ntype, Node): + raise RelayWrongAnnotationError( + field_name, + resolver=StrawberryResolver(resolver), + ) + + self.type_annotation = StrawberryAnnotation( + Connection[ntype], # type: ignore[valid-type] + namespace=namespace, + ) + + return retval + + @cached_property + def resolver_args(self) -> Set[str]: + resolver = self.base_resolver + if not resolver: + return set() + + if isinstance(resolver, StrawberryResolver): + resolver = resolver.wrapped_func # type: ignore[assignment] + + return set(inspect.signature(cast(Callable, resolver)).parameters.keys()) + + def get_result( + self, + source: Any, + info: Optional[Info], + args: List[Any], + kwargs: Dict[str, Any], + ) -> Union[Awaitable[Any], Any]: + assert info is not None + type_def = info.return_type._type_definition # type:ignore + assert isinstance(type_def, TypeDefinition) + + field_type = type_def.type_var_map[cast(TypeVar, NodeType)] + if isinstance(field_type, LazyType): + field_type = field_type.resolve_type() + + if self.base_resolver is not None: + # If base_resolver is not self.conn_resolver, + # then it is defined to something + assert self.base_resolver + + resolver_args = self.resolver_args + resolver_kwargs = { + # Consider both args not in default args and the ones specified + # by the resolver, in case they want to check + # "first"/"last"/"before"/"after" + k: v + for k, v in kwargs.items() + if k in resolver_args + } + nodes = self.base_resolver(*args, **resolver_kwargs) + else: + nodes = None + + return self.resolver(source, info, args, kwargs, nodes=nodes) + + def resolver( + self, + source: Any, + info: Info, + args: List[Any], + kwargs: Dict[str, Any], + *, + nodes: AwaitableOrValue[ + Optional[Union[Iterable[Node], Connection[Node]]] + ] = None, + ): + # The base_resolver might have resolved to a Connection directly + if isinstance(nodes, Connection): + return nodes + + return_type = cast(Connection[Node], info.return_type) + type_def = return_type._type_definition # type:ignore + assert isinstance(type_def, TypeDefinition) + + field_type = type_def.type_var_map[cast(TypeVar, NodeType)] + if isinstance(field_type, LazyType): + field_type = field_type.resolve_type() + + if nodes is None: + nodes = cast(Node, field_type).resolve_nodes(info=info) + + if inspect.isawaitable(nodes): + return resolve_awaitable( + nodes, + lambda resolved: self.resolver( + source, + info, + args, + kwargs, + nodes=resolved, + ), + ) + + # Avoid info being passed twice in case the custom resolver has one + kwargs.pop("info", None) + return self.resolve_connection(cast(Iterable[Node], nodes), info, **kwargs) + + def resolve_connection( + self, + nodes: NodeIterableType[NodeType], + info: Info, + **kwargs, + ): + return_type = cast(Connection[Node], info.return_type) + kwargs.setdefault("info", info) + return return_type.from_nodes( + nodes, + node_converter=self.node_converter, + **kwargs, + ) + + +def node( + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + node_converter: Optional[Callable[[object], NodeType]] = None, + # This init parameter is used by pyright to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property to create a relay query field. + + Examples: + Annotating something like this: + + >>> @strawberry.type + >>> class X: + ... some_node: SomeType = relay.node(description="ABC") + + Will produce a query like this that returns `SomeType` given its id. + + ``` + query { + someNode (id: ID) { + id + ... + } + } + ``` + + """ + return NodeField( + python_name=None, + graphql_name=name, + type_annotation=None, + description=description, + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + node_converter=node_converter, + ) + + +@overload +def connection( + *, + resolver: _RESOLVER_TYPE[NodeIterableType[NodeType]], + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + init: Literal[False] = False, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + graphql_type: Optional[Any] = None, +) -> Connection[NodeType]: + ... + + +@overload +def connection( + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + init: Literal[True] = True, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + graphql_type: Optional[Any] = None, + node_converter: Optional[Callable[[Any], NodeType]] = None, +) -> Any: + ... + + +@overload +def connection( + resolver: _RESOLVER_TYPE[NodeIterableType[NodeType]], + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + graphql_type: Optional[Any] = None, +) -> ConnectionField: + ... + + +@overload +def connection( + resolver: _RESOLVER_TYPE[NodeIterableType[_T]], + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + graphql_type: Optional[Any] = None, + node_converter: Callable[[_T], NodeType], +) -> ConnectionField: + ... + + +def connection( + resolver: Optional[_RESOLVER_TYPE[Any]] = None, + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = dataclasses.MISSING, + default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, + metadata: Optional[Mapping[Any, Any]] = None, + directives: Optional[Sequence[object]] = (), + # This init parameter is used by pyright to determine whether this field + graphql_type: Optional[Any] = None, + node_converter: Optional[Callable[[Any], NodeType]] = None, + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init: Literal[True, False, None] = None, +) -> Any: + """Annotate a property or a method to create a relay connection field. + + Relay connections_ are mostly used for pagination purposes. This decorator + helps creating a complete relay endpoint that provides default arguments + and has a default implementation for the connection slicing. + + Note that when setting a resolver to this field, it is expected for this + resolver to return an iterable of the expected node type, not the connection + itself. That iterable will then be paginated accordingly. So, the main use + case for this is to provide a filtered iterable of nodes by using some custom + filter arguments. + + Examples: + Annotating something like this: + + >>> @strawberry.type + >>> class X: + ... some_node: relay.Connection[SomeType] = relay.connection( + ... description="ABC" + ... ) + ... + ... @relay.connection(description="ABC") + ... def get_some_nodes(self, age: int) -> Iterable[SomeType]: + ... ... + + Will produce a query like this: + + ``` + query { + someNode ( + before: String + after: String + first: String + after: String + age: Int + ) { + totalCount + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + ... + } + } + } + } + ``` + + .. _Relay connections: + https://relay.dev/graphql/connections.htm + + """ + f = ConnectionField( + python_name=None, + graphql_name=name, + description=description, + type_annotation=StrawberryAnnotation.from_annotation(graphql_type), + is_subscription=is_subscription, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives or (), + node_converter=node_converter, + ) + if resolver is not None: + f = f(resolver) + return f diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py new file mode 100644 index 0000000000..0c151c201c --- /dev/null +++ b/strawberry/relay/types.py @@ -0,0 +1,907 @@ +from __future__ import annotations + +import dataclasses +import inspect +import itertools +import sys +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + ClassVar, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from typing_extensions import ( + Annotated, + Literal, + Self, + TypeAlias, + get_args, + get_origin, + get_type_hints, +) + +from strawberry.field import field +from strawberry.lazy_type import LazyType +from strawberry.object_type import interface, type +from strawberry.private import StrawberryPrivate +from strawberry.scalars import ID # noqa: TCH001 +from strawberry.type import StrawberryContainer +from strawberry.types.types import TypeDefinition +from strawberry.utils.aio import aenumerate, aislice, resolve_awaitable +from strawberry.utils.inspect import in_async_context + +from .utils import from_base64, to_base64 + +if TYPE_CHECKING: + from strawberry.schema.schema import Schema + from strawberry.types.info import Info + from strawberry.utils.await_maybe import AwaitableOrValue + +_T = TypeVar("_T") +_R = TypeVar("_R") + +NodeIterableType: TypeAlias = Union[ + Iterator[_T], + Iterable[_T], + AsyncIterator[_T], + AsyncIterable[_T], +] +NodeType = TypeVar("NodeType", bound="Node") + +PREFIX = "arrayconnection" + + +class GlobalIDValueError(ValueError): + """GlobalID value error, usually related to parsing or serialization.""" + + +@dataclasses.dataclass(order=True, frozen=True) +class GlobalID: + """Global ID for relay types. + + Different from `strawberry.ID`, this ID wraps the original object ID in a string + that contains both its GraphQL type name and the ID itself, and encodes it + to a base64_ string. + + This object contains helpers to work with that, including method to retrieve + the python object type or even the encoded node itself. + + Attributes: + type_name: + The type name part of the id + node_id: + The node id part of the id + + .. _base64: + https://en.wikipedia.org/wiki/Base64 + + """ + + _nodes_cache: ClassVar[Dict[Tuple[Schema, str], Type[Node]]] = {} + + type_name: str + node_id: str + + def __post_init__(self): + if not isinstance(self.type_name, str): + raise GlobalIDValueError( + f"type_name is expected to be a string, found {repr(self.type_name)}" + ) + if not isinstance(self.node_id, str): + raise GlobalIDValueError( + f"node_id is expected to be a string, found {repr(self.node_id)}" + ) + + def __str__(self): + return to_base64(self.type_name, self.node_id) + + @classmethod + def from_id(cls, value: Union[str, ID]): + """Create a new GlobalID from parsing the given value. + + Args: + value: + The value to be parsed, as a base64 string in the + "TypeName:NodeID" format + + Returns: + An instance of GLobalID + + Raises: + GlobalIDValueError: + If the value is not in a GLobalID format + + """ + try: + type_name, node_id = from_base64(value) + except ValueError as e: + raise GlobalIDValueError(str(e)) from e + + return cls(type_name=type_name, node_id=node_id) + + def to_id(self) -> ID: + return ID(str(self)) + + def resolve_type(self, info: Info) -> Type[Node]: + """Resolve the internal type name to its type itself. + + Args: + info: + The strawberry execution info resolve the type name from + + Returns: + The resolved GraphQL type for the execution info + + """ + schema = info.schema + # Put the schema in the key so that different schemas can have different types + key = (schema, self.type_name) + origin = self._nodes_cache.get(key) + + if origin is None: + type_def = info.schema.get_type_by_name(self.type_name) + assert isinstance(type_def, TypeDefinition) + origin = ( + type_def.origin.resolve_type + if isinstance(origin, LazyType) + else type_def.origin + ) + assert issubclass(origin, Node) + self._nodes_cache[key] = origin + + return origin + + @overload + def resolve_node( + self, + info: Info, + *, + required: Literal[True] = ..., + ensure_type: Type[_T], + ) -> _T: + ... + + @overload + def resolve_node( + self, + info: Info, + *, + required: Literal[True], + ensure_type: None = ..., + ) -> Node: + ... + + @overload + def resolve_node( + self, + info: Info, + *, + required: bool = ..., + ensure_type: None = ..., + ) -> Optional[Node]: + ... + + def resolve_node(self, info, *, required=False, ensure_type=None) -> Any: + """Resolve the type name and node id info to the node itself. + + Tip: When you know the expected type, calling `ensure_type` should help + not only to enforce it, but also help with typing since it will know that, + if this function returns successfully, the retval should be of that + type and not `Node`. + + Args: + info: + The strawberry execution info resolve the type name from + required: + If the value is required to exist. Note that asking to ensure + the type automatically makes required true. + ensure_type: + Optionally check if the returned node is really an instance + of this type. + + Returns: + The resolved node + + Raises: + TypeError: + If ensure_type was provided and the type is not an instance of it + + """ + n_type = self.resolve_type(info) + node = n_type.resolve_node( + self.node_id, + info=info, + required=required or ensure_type is not None, + ) + + if node is not None and ensure_type is not None: + origin = get_origin(ensure_type) + if origin and origin is Union: + ensure_type = tuple(get_args(ensure_type)) + + if not isinstance(node, ensure_type): + raise TypeError(f"{ensure_type} expected, found {repr(node)}") + + return node + + @overload + async def aresolve_node( + self, + info: Info, + *, + required: Literal[True] = ..., + ensure_type: Type[_T], + ) -> _T: + ... + + @overload + async def aresolve_node( + self, + info: Info, + *, + required: Literal[True], + ensure_type: None = ..., + ) -> Node: + ... + + @overload + async def aresolve_node( + self, + info: Info, + *, + required: bool = ..., + ensure_type: None = ..., + ) -> Optional[Node]: + ... + + async def aresolve_node(self, info, *, required=False, ensure_type=None) -> Any: + """Resolve the type name and node id info to the node itself. + + Tip: When you know the expected type, calling `ensure_type` should help + not only to enforce it, but also help with typing since it will know that, + if this function returns successfully, the retval should be of that + type and not `Node`. + + Args: + info: + The strawberry execution info resolve the type name from + required: + If the value is required to exist. Note that asking to ensure + the type automatically makes required true. + ensure_type: + Optionally check if the returned node is really an instance + of this type. + + Returns: + The resolved node + + Raises: + TypeError: + If ensure_type was provided and the type is not an instance of it + + """ + n_type = self.resolve_type(info) + node = cast( + Awaitable[Node], + n_type.resolve_node( + self.node_id, + info=info, + required=required or ensure_type is not None, + ), + ) + + res = await node if node is not None else None + + if ensure_type is not None: + origin = get_origin(ensure_type) + if origin and origin is Union: + ensure_type = tuple(get_args(ensure_type)) + + if not isinstance(res, ensure_type): + raise TypeError(f"{ensure_type} expected, found {repr(res)}") + + return res + + +class NodeIDPrivate(StrawberryPrivate): + ... + + +NodeID: TypeAlias = Annotated[_T, NodeIDPrivate()] + + +@interface(description="An object with a Globally Unique ID") +class Node: + """Node interface for GraphQL types. + + Subclasses must type the id field using `NodeID`. It will be private to the + schema because it will be converted to a global ID and exposed as `id: GlobalID!` + + The following methods can also be implemented: + resolve_id: + (Optional) Called to resolve the node's id. Can be overriden to + customize how the id is retrieved (e.g. in case you don't want + to define a `NodeID` field) + resolve_nodes: + Called to retrieve an iterable of node given their ids + resolve_node: + (Optional) Called to retrieve a node given its id. If not defined + the default implementation will call `.resolve_nodes` with that + single node id. + + Example: + >>> @strawberry.type + ... class Fruit(Node): + ... id: NodeID[int] + ... name: str + ... + ... @classmethod + ... def resolve_nodes(cls, *, info, node_ids, required=False): + ... # Return an iterable of fruits in here + ... ... + + """ + + _id_attr: ClassVar[str] + + def __init_subclass__(cls, **kwargs): + annotations = get_type_hints(cls, include_extras=True) + candidates = [ + attr + for attr, annotation in annotations.items() + if ( + get_origin(annotation) is Annotated + and any( + isinstance(argument, NodeIDPrivate) + for argument in get_args(annotation) + ) + ) + ] + + if len(candidates) == 0: + raise TypeError(f"No field annotated with `NodeID` found on {cls!r}") + if len(candidates) > 1: + raise TypeError( + f"More than one field annotated with `NodeID` found on {cls!r}" + ) + + cls._id_attr = candidates[0] + + @field(name="id", description="The Globally Unique ID of this object") + @classmethod + def _id(cls, root: Node, info: Info) -> ID: + # FIXME: We want to support both integration objects that doesn't define + # a resolve_id and also the ones that does override it. Is there a better + # way of handling this? + if isinstance(root, Node): + resolve_id = root.__class__.resolve_id + else: + parent_type = info._raw_info.parent_type + type_def = info.schema.get_type_by_name(parent_type.name) + assert isinstance(type_def, TypeDefinition) + resolve_id = type_def.origin.resolve_id + + node_id = resolve_id(root, info=info) + resolve_typename = ( + root.__class__.resolve_typename + if isinstance(root, Node) + else cls.resolve_typename + ) + type_name = resolve_typename(root, info) + assert type_name + + if inspect.isawaitable(node_id): + return cast( + ID, + resolve_awaitable( + node_id, + lambda resolved: GlobalID( + type_name=type_name, + node_id=str(resolved), + ).to_id(), + ), + ) + + # If node_id is not str, GlobalID will raise an error for us + return GlobalID(type_name=type_name, node_id=str(node_id)).to_id() + + @classmethod + def resolve_id( + cls, + root: Self, + *, + info: Info, + ) -> AwaitableOrValue[str]: + """Resolve the node id. + + By default this will return `getattr(root, )`, where + is the field typed with `NodeID`. + + You can override this method to provide a custom implementation. + + Args: + info: + The strawberry execution info resolve the type name from + root: + The node to resolve + + Returns: + The resolved id (which is expected to be str) + + """ + return getattr(root, cls._id_attr) + + @classmethod + def resolve_typename(cls, root: Self, info: Info): + return info.path.typename + + @overload + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + ) -> AwaitableOrValue[Iterable[Self]]: + ... + + @overload + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Iterable[str], + required: Literal[True], + ) -> AwaitableOrValue[Iterable[Self]]: + ... + + @overload + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Optional[Iterable[str]] = None, + required: Literal[False] = ..., + ) -> AwaitableOrValue[Iterable[Optional[Self]]]: + ... + + @overload + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Optional[Iterable[str]] = None, + required: bool, + ) -> Union[ + AwaitableOrValue[Iterable[Self]], + AwaitableOrValue[Iterable[Optional[Self]]], + ]: + ... + + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Optional[Iterable[str]] = None, + required: bool = False, + ): + """Resolve a list of nodes. + + This method *should* be defined by anyone implementing the `Node` interface. + + Args: + info: + The strawberry execution info resolve the type name from + node_ids: + Optional list of ids that, when provided, should be used to filter + the results to only contain the nodes of those ids. When empty, + all nodes of this type shall be returned. + required: + If `True`, all `node_ids` requested must exist. If they don't, + an error must be raised. If `False`, missing nodes should be + returned as `None`. It only makes sense when passing a list of + `node_ids`, otherwise it will should ignored. + + Returns: + An iterable of resolved nodes. + + """ + raise NotImplementedError + + @overload + @classmethod + def resolve_node( + cls, + node_id: str, + *, + info: Info, + required: Literal[True], + ) -> AwaitableOrValue[Self]: + ... + + @overload + @classmethod + def resolve_node( + cls, + node_id: str, + *, + info: Info, + required: Literal[False] = ..., + ) -> AwaitableOrValue[Optional[Self]]: + ... + + @overload + @classmethod + def resolve_node( + cls, + node_id: str, + *, + info: Info, + required: bool, + ) -> AwaitableOrValue[Optional[Self]]: + ... + + @classmethod + def resolve_node( + cls, + node_id: str, + *, + info: Info, + required: bool = False, + ) -> AwaitableOrValue[Optional[Self]]: + """Resolve a node given its id. + + This method is a convenience method that calls `resolve_nodes` for + a single node id. + + Args: + info: + The strawberry execution info resolve the type name from + node_id: + The id of the node to be retrieved + required: + if the node is required or not to exist. If not, then None + should be returned if it doesn't exist. Otherwise an exception + should be raised. + + Returns: + The resolved node or None if it was not found + + """ + retval = cls.resolve_nodes(info=info, node_ids=[node_id], required=required) + + if inspect.isawaitable(retval): + return resolve_awaitable(retval, lambda resolved: next(iter(resolved))) + + return next(iter(cast(Iterable[Self], retval))) + + +@type(description="Information to aid in pagination.") +class PageInfo: + """Information to aid in pagination. + + Attributes: + has_next_page: + When paginating forwards, are there more items? + has_previous_page: + When paginating backwards, are there more items? + start_cursor: + When paginating backwards, the cursor to continue + end_cursor: + When paginating forwards, the cursor to continue + + """ + + has_next_page: bool = field( + description="When paginating forwards, are there more items?", + ) + has_previous_page: bool = field( + description="When paginating backwards, are there more items?", + ) + start_cursor: Optional[str] = field( + description="When paginating backwards, the cursor to continue.", + ) + end_cursor: Optional[str] = field( + description="When paginating forwards, the cursor to continue.", + ) + + +@type(description="An edge in a connection.") +class Edge(Generic[NodeType]): + """An edge in a connection. + + Attributes: + cursor: + A cursor for use in pagination + node: + The item at the end of the edge + + """ + + cursor: str = field( + description="A cursor for use in pagination", + ) + node: NodeType = field( + description="The item at the end of the edge", + ) + + @classmethod + def from_node(cls, node: NodeType, *, cursor: Any = None): + return cls(cursor=to_base64(PREFIX, cursor), node=node) + + +@type(description="A connection to a list of items.") +class Connection(Generic[NodeType]): + """A connection to a list of items. + + Attributes: + page_info: + Pagination data for this connection + edges: + Contains the nodes in this connection + + """ + + page_info: PageInfo = field( + description="Pagination data for this connection", + ) + edges: List[Edge[NodeType]] = field( + description="Contains the nodes in this connection", + ) + + @overload + @classmethod + def from_nodes( + cls, + nodes: Union[ + Iterator[NodeType], + AsyncIterator[NodeType], + Iterable[NodeType], + AsyncIterable[NodeType], + ], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs, + ) -> Self: + ... + + @overload + @classmethod + def from_nodes( + cls, + nodes: NodeIterableType[_T], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + node_converter: Callable[[_T], NodeType], + **kwargs, + ) -> Self: + ... + + @classmethod + def from_nodes( + cls, + nodes: Union[NodeIterableType[_T], NodeIterableType[NodeType]], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + node_converter: Optional[Callable[[_T], NodeType]] = None, + **kwargs, + ): + """Resolve a connection from the list of nodes. + + This uses the described Relay Pagination algorithm_ + + Args: + info: + The strawberry execution info resolve the type name from + nodes: + An iterable of nodes to transform to a connection + before: + Returns the items in the list that come before the specified cursor + after: + Returns the items in the list that come after the specified cursor + first: + Returns the first n items from the list + last: + Returns the items in the list that come after the specified cursor + + Returns: + The resolved `Connection` + + .. _Relay Pagination algorithm: + https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm + + """ + max_results = info.schema.config.relay_max_results + start = 0 + end: Optional[int] = None + + if after: + after_type, after_parsed = from_base64(after) + assert after_type == PREFIX + start = int(after_parsed) + 1 + if before: + before_type, before_parsed = from_base64(before) + assert before_type == PREFIX + end = int(before_parsed) + + if isinstance(first, int): + if first < 0: + raise ValueError("Argument 'first' must be a non-negative integer.") + + if first > max_results: + raise ValueError( + f"Argument 'first' cannot be higher than {max_results}." + ) + + if end is not None: + start = max(0, end - 1) + + end = start + first + if isinstance(last, int): + if last < 0: + raise ValueError("Argument 'last' must be a non-negative integer.") + + if last > max_results: + raise ValueError( + f"Argument 'last' cannot be higher than {max_results}." + ) + + if end is not None: + start = max(start, end - last) + else: + end = sys.maxsize + + if end is None: + end = max_results + + expected = end - start if end != sys.maxsize else None + # Overfetch by 1 to check if we have a next result + overfetch = end + 1 if end != sys.maxsize else end + + type_def = cast(TypeDefinition, cls._type_definition) # type:ignore + field_def = type_def.get_field("edges") + assert field_def + + field = field_def.type + while isinstance(field, StrawberryContainer): + field = field.of_type + + edge_class = cast(Edge[NodeType], field) + + if isinstance(nodes, (AsyncIterator, AsyncIterable)) and in_async_context(): + + async def resolver(): + try: + iterator = cast( + Union[AsyncIterator[NodeType], AsyncIterable[NodeType]], + cast(Sequence, nodes)[start:overfetch], + ) + except TypeError: + # FIXME: Why mypy isn't narrowing this based on the if above? + assert isinstance(nodes, (AsyncIterator, AsyncIterable)) + iterator = aislice( + nodes, # type: ignore[arg-type] + start, + overfetch, + ) + + assert isinstance(iterator, (AsyncIterator, AsyncIterable)) + edges: List[Edge] = [ + edge_class.from_node( + ( + node_converter(cast(_T, v)) + if node_converter is not None + else cast(NodeType, v) # type: ignore[redundant-cast] + ), + cursor=start + i, + ) + async for i, v in aenumerate(iterator) + ] + + has_previous_page = start > 0 + if expected is not None and len(edges) == expected + 1: + # Remove the overfetched result + edges = edges[:-1] + has_next_page = True + elif end == sys.maxsize: + # Last was asked without any after/before + assert last is not None + original_len = len(edges) + edges = edges[-last:] + has_next_page = False + has_previous_page = len(edges) != original_len + else: + has_next_page = False + + return cls( + edges=edges, + page_info=PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) + + return resolver() + + try: + iterator = cast( + Union[Iterator[NodeType], Iterable[NodeType]], + cast(Sequence, nodes)[start:overfetch], + ) + except TypeError: + assert isinstance(nodes, (Iterable, Iterator)) + iterator = itertools.islice( + nodes, # type: ignore[arg-type] + start, + overfetch, + ) + + edges = [ + edge_class.from_node( + ( + node_converter(cast(_T, v)) + if node_converter is not None + else cast(NodeType, v) # type: ignore[redundant-cast] + ), + cursor=start + i, + ) + for i, v in enumerate(iterator) + ] + + has_previous_page = start > 0 + if expected is not None and len(edges) == expected + 1: + # Remove the overfetched result + edges = edges[:-1] + has_next_page = True + elif end == sys.maxsize: + # Last was asked without any after/before + assert last is not None + original_len = len(edges) + edges = edges[-last:] + has_next_page = False + has_previous_page = len(edges) != original_len + else: + has_next_page = False + + return cls( + edges=edges, + page_info=PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py new file mode 100644 index 0000000000..758c8c4035 --- /dev/null +++ b/strawberry/relay/utils.py @@ -0,0 +1,63 @@ +import base64 +from typing import Any, Tuple, Union +from typing_extensions import assert_never + +from strawberry.types.types import TypeDefinition + + +def from_base64(value: str) -> Tuple[str, str]: + """Parse the base64 encoded relay value. + + Args: + value: + The value to be parsed + + Returns: + A tuple of (TypeName, NodeID). + + Raises: + ValueError: + If the value is not in the expected format + + """ + try: + res = base64.b64decode(value.encode()).decode().split(":") + except Exception as e: + raise ValueError(str(e)) from e + + if len(res) != 2: + raise ValueError(f"{res} expected to contain only 2 items") + + return res[0], res[1] + + +def to_base64(type_: Union[str, type, TypeDefinition], node_id: Any) -> str: + """Encode the type name and node id to a base64 string. + + Args: + type_: + The GraphQL type, type definition or type name. + node_id: + The node id itself + + Returns: + A tuple of (TypeName, NodeID). + + Raises: + ValueError: + If the value is not a valid GraphQL type or name + + """ + try: + if isinstance(type_, str): + type_name = type_ + elif isinstance(type_, TypeDefinition): + type_name = type_.name + elif isinstance(type_, type): + type_name = type_._type_definition.name # type:ignore + else: # pragma: no cover + assert_never(type_) + except Exception as e: + raise ValueError(f"{type_} is not a valid GraphQL type or name") from e + + return base64.b64encode(f"{type_name}:{node_id}".encode()).decode() diff --git a/strawberry/schema/config.py b/strawberry/schema/config.py index c23b557e58..18afbb7940 100644 --- a/strawberry/schema/config.py +++ b/strawberry/schema/config.py @@ -11,6 +11,7 @@ class StrawberryConfig: auto_camel_case: InitVar[bool] = None # pyright: reportGeneralTypeIssues=false name_converter: NameConverter = field(default_factory=NameConverter) default_resolver: Callable[[Any, str], object] = getattr + relay_max_results: int = 100 def __post_init__( self, diff --git a/strawberry/types/type_resolver.py b/strawberry/types/type_resolver.py index 6cd13e4f51..6ec72a492d 100644 --- a/strawberry/types/type_resolver.py +++ b/strawberry/types/type_resolver.py @@ -2,7 +2,8 @@ import dataclasses import sys -from typing import TYPE_CHECKING, Dict, List, Type, TypeVar +from typing import Dict, List, Type, TypeVar +from typing_extensions import get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.exceptions import ( @@ -10,12 +11,48 @@ FieldWithResolverAndDefaultValueError, PrivateStrawberryFieldError, ) +from strawberry.field import StrawberryField from strawberry.private import is_private from strawberry.unset import UNSET from strawberry.utils.inspect import get_specialized_type_var_map -if TYPE_CHECKING: - from strawberry.field import StrawberryField + +def _get_field_for_type(type_: Type) -> Type[StrawberryField]: + # Deferred import to avoid import cycles + from strawberry.relay import Connection, ConnectionField, Node, NodeField + + # Supoort for "foo: Node" + if isinstance(type_, type) and issubclass(type_, Node): + return NodeField + + # Support for "foo: SpecializedConnection" + if isinstance(type_, type) and issubclass(type_, Connection): + return ConnectionField + + type_origin = get_origin(type_) + + # Support for "foo: Connection[Foo]" + if isinstance(type_origin, type) and issubclass( + type_origin, + Connection, + ): + return ConnectionField + + type_args = get_args(type_) + + # Support for "foo: Optional[Node]" and "foo: List[Node]" + if any(isinstance(arg, type) and issubclass(arg, Node) for arg in type_args): + return NodeField + + # Support for "foo: List[Optional[Node]]" + if isinstance(type_origin, type) and issubclass(type_origin, List): + if any( + isinstance(arg, type) and issubclass(arg, Node) + for arg in get_args(type_args[0]) + ): + return NodeField + + return StrawberryField def _get_fields(cls: Type) -> List[StrawberryField]: @@ -51,9 +88,6 @@ class if one is not set by either using an explicit strawberry.field(name=...) o passing a named function (i.e. not an anonymous lambda) to strawberry.field (typically as a decorator). """ - # Deferred import to avoid import cycles - from strawberry.field import StrawberryField - fields: Dict[str, StrawberryField] = {} # before trying to find any fields, let's first add the fields defined in @@ -157,7 +191,8 @@ class if one is not set by either using an explicit strawberry.field(name=...) o ) # Create a StrawberryField, for fields of Types #1 and #2a - field = StrawberryField( + field_class = _get_field_for_type(field_type) + field = field_class( python_name=field.name, graphql_name=None, type_annotation=StrawberryAnnotation( diff --git a/strawberry/utils/aio.py b/strawberry/utils/aio.py new file mode 100644 index 0000000000..ce0bd91427 --- /dev/null +++ b/strawberry/utils/aio.py @@ -0,0 +1,65 @@ +import itertools +import sys +from typing import ( + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +itertools.islice +_T = TypeVar("_T") +_R = TypeVar("_R") + + +async def aenumerate( + iterable: Union[AsyncIterator[_T], AsyncIterable[_T]], +) -> AsyncIterator[Tuple[int, _T]]: + """Async version of enumerate.""" + i = 0 + async for element in iterable: + yield i, element + i += 1 + + +async def aislice( + aiterable: Union[AsyncIterator[_T], AsyncIterable[_T]], + start: Optional[int] = None, + stop: Optional[int] = None, + step: Optional[int] = None, +) -> AsyncIterator[_T]: + """Async version of itertools.islice.""" + # This is based on + it = iter(range(start or 0, stop or sys.maxsize, step or 1)) + try: + nexti = next(it) + except StopIteration: + return + + try: + async for i, element in aenumerate(aiterable): + if i == nexti: + yield element + nexti = next(it) + except StopIteration: + return + + +async def asyncgen_to_list(generator: AsyncGenerator[_T, Any]) -> List[_T]: + """Convert an async generator to a list.""" + return [element async for element in generator] + + +async def resolve_awaitable( + awaitable: Awaitable[_T], + callback: Callable[[_T], _R], +) -> _R: + """Resolves an awaitable object and calls a callback with the resolved value.""" + return callback(await awaitable) diff --git a/strawberry/utils/inspect.py b/strawberry/utils/inspect.py index 2a749c8669..b11d882c00 100644 --- a/strawberry/utils/inspect.py +++ b/strawberry/utils/inspect.py @@ -1,9 +1,21 @@ +import asyncio import inspect from functools import lru_cache from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload from typing_extensions import Literal, get_args +def in_async_context() -> bool: + # Based on the way django checks if there's an event loop in the current thread + # https://github.com/django/django/blob/main/django/utils/asyncio.py + try: + asyncio.get_running_loop() + except RuntimeError: + return False + else: + return True + + @lru_cache(maxsize=250) def get_func_args(func: Callable[[Any], Any]): """Returns a list of arguments for the function""" diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index ac5f4fa0e1..fa3ab229e5 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -1,7 +1,7 @@ import sys -from collections.abc import AsyncGenerator from typing import ( # type: ignore Any, + AsyncGenerator, Callable, ClassVar, Generic, diff --git a/tests/pyright/test_relay.py b/tests/pyright/test_relay.py new file mode 100644 index 0000000000..d4d73ef91d --- /dev/null +++ b/tests/pyright/test_relay.py @@ -0,0 +1,261 @@ +from .utils import Result, requires_pyright, run_pyright, skip_on_windows + +pytestmark = [skip_on_windows, requires_pyright] + + +CODE = """ +from typing import ( + Any, + AsyncIterator, + AsyncGenerator, + AsyncIterable, + Generator, + Iterable, + Iterator, + List, + Optional, + Union, +) + +import strawberry +from strawberry.types import Info +from typing_extensions import Self + + +@strawberry.type +class Fruit(strawberry.relay.Node): + id: strawberry.relay.NodeID[int] + name: str + color: str + + +@strawberry.type +class FruitCustomPaginationConnection(strawberry.relay.Connection[Fruit]): + @strawberry.field + def something(self) -> str: + return "foobar" + + @classmethod + def from_nodes( + cls, + nodes: Union[ + Iterator[Fruit], + AsyncIterator[Fruit], + Iterable[Fruit], + AsyncIterable[Fruit], + ], + *, + info: Optional[Info[Any, Any]] = None, + total_count: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> Self: + ... + + +class FruitAlike: + ... + + +def fruit_converter(fruit_alike: FruitAlike) -> Fruit: + ... + + +@strawberry.type +class Query: + node: strawberry.relay.Node + nodes: List[strawberry.relay.Node] + node_optional: Optional[strawberry.relay.Node] + nodes_optional: List[Optional[strawberry.relay.Node]] + fruits: strawberry.relay.Connection[Fruit] + fruits_conn: strawberry.relay.Connection[Fruit] = strawberry.relay.connection() + fruits_custom_pagination: FruitCustomPaginationConnection + + @strawberry.relay.connection + def fruits_custom_resolver( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> List[Fruit]: + ... + + @strawberry.relay.connection(node_converter=fruit_converter) + def fruits_custom_resolver_with_node_converter( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> List[FruitAlike]: + ... + + @strawberry.relay.connection + def fruits_custom_resolver_iterator( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> Iterator[Fruit]: + ... + + @strawberry.relay.connection + def fruits_custom_resolver_iterable( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> Iterable[Fruit]: + ... + + @strawberry.relay.connection + def fruits_custom_resolver_generator( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> Generator[Fruit, None, None]: + ... + + @strawberry.relay.connection + async def fruits_custom_resolver_async_iterator( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> AsyncIterator[Fruit]: + ... + + @strawberry.relay.connection + async def fruits_custom_resolver_async_iterable( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> AsyncIterable[Fruit]: + ... + + @strawberry.relay.connection + async def fruits_custom_resolver_async_generator( + self, + info: Info[Any, Any], + name_endswith: Optional[str] = None, + ) -> AsyncGenerator[Fruit, None]: + ... + +reveal_type(Query.node) +reveal_type(Query.nodes) +reveal_type(Query.node_optional) +reveal_type(Query.nodes_optional) +reveal_type(Query.fruits) +reveal_type(Query.fruits_conn) +reveal_type(Query.fruits_custom_pagination) +reveal_type(Query.fruits_custom_resolver) +reveal_type(Query.fruits_custom_resolver_with_node_converter) +reveal_type(Query.fruits_custom_resolver_iterator) +reveal_type(Query.fruits_custom_resolver_iterable) +reveal_type(Query.fruits_custom_resolver_generator) +reveal_type(Query.fruits_custom_resolver_async_iterator) +reveal_type(Query.fruits_custom_resolver_async_iterable) +reveal_type(Query.fruits_custom_resolver_async_generator) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="information", + message='Type of "Query.node" is "Node"', + line=136, + column=13, + ), + Result( + type="information", + message='Type of "Query.nodes" is "List[Node]"', + line=137, + column=13, + ), + Result( + type="information", + message='Type of "Query.node_optional" is "Node | None"', + line=138, + column=13, + ), + Result( + type="information", + message='Type of "Query.nodes_optional" is "List[Node | None]"', + line=139, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits" is "Connection[Fruit]"', + line=140, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_conn" is "Connection[Fruit]"', + line=141, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_pagination" is ' + '"FruitCustomPaginationConnection"', + line=142, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver" is "ConnectionField"', + line=143, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_with_node_converter" is ' + '"Any"', + line=144, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_iterator" is ' + '"ConnectionField"', + line=145, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_iterable" is ' + '"ConnectionField"', + line=146, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_generator" is ' + '"ConnectionField"', + line=147, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_async_iterator" is ' + '"ConnectionField"', + line=148, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_async_iterable" is ' + '"ConnectionField"', + line=149, + column=13, + ), + Result( + type="information", + message='Type of "Query.fruits_custom_resolver_async_generator" is ' + '"ConnectionField"', + line=150, + column=13, + ), + ] diff --git a/tests/relay/__init__.py b/tests/relay/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/relay/schema.gql b/tests/relay/schema.gql new file mode 100644 index 0000000000..9e554e3a8e --- /dev/null +++ b/tests/relay/schema.gql @@ -0,0 +1,263 @@ +type Fruit implements Node { + """The Globally Unique ID of this object""" + id: ID! + name: String! + color: String! +} + +type FruitAsync implements Node { + """The Globally Unique ID of this object""" + id: ID! + name: String! + color: String! +} + +"""A connection to a list of items.""" +type FruitAsyncConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [FruitAsyncEdge!]! +} + +"""An edge in a connection.""" +type FruitAsyncEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: FruitAsync! +} + +"""A connection to a list of items.""" +type FruitConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [FruitEdge!]! +} + +"""An edge in a connection.""" +type FruitEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Fruit! +} + +type FruitFruitCustomPaginationConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [FruitEdge!]! + something: String! +} + +"""An object with a Globally Unique ID""" +interface Node { + """The Globally Unique ID of this object""" + id: ID! +} + +"""Information to aid in pagination.""" +type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String +} + +type Query { + node( + """The ID of the object.""" + id: ID! + ): Node! + nodes( + """The IDs of the objects.""" + ids: [ID!]! + ): [Node!]! + nodeOptional( + """The ID of the object.""" + id: ID! + ): Node + nodesOptional( + """The IDs of the objects.""" + ids: [ID!]! + ): [Node]! + fruits( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + ): FruitConnection! + fruitsAsync( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + ): FruitAsyncConnection! + fruitsCustomPagination( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + ): FruitFruitCustomPaginationConnection! + fruitsCustomResolver( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverWithNodeConverter( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverWithNodeConverterForwardRef( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverIterator( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverIterable( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverGenerator( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverAsyncIterable( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverAsyncIterator( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! + fruitsCustomResolverAsyncGenerator( + """Returns the items in the list that come before the specified cursor.""" + before: String = null + + """Returns the items in the list that come after the specified cursor.""" + after: String = null + + """Returns the first n items from the list.""" + first: Int = null + + """Returns the items in the list that come after the specified cursor.""" + last: Int = null + nameEndswith: String = null + ): FruitConnection! +} diff --git a/tests/relay/schema.py b/tests/relay/schema.py new file mode 100644 index 0000000000..41fb47916c --- /dev/null +++ b/tests/relay/schema.py @@ -0,0 +1,262 @@ +from collections import namedtuple +from typing import ( + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Generator, + Iterable, + Iterator, + List, + Optional, +) + +import strawberry +from strawberry.relay.utils import to_base64 +from strawberry.types import Info + + +@strawberry.type +class Fruit(strawberry.relay.Node): + id: strawberry.relay.NodeID[int] + name: str + color: str + + @classmethod + def resolve_nodes( + cls, + *, + info: Info, + node_ids: Optional[Iterable[str]] = None, + required: bool = False, + ): + if node_ids is not None: + return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] + + return fruits.values() + + +@strawberry.type +class FruitAsync(strawberry.relay.Node): + id: strawberry.relay.NodeID[int] + name: str + color: str + + @classmethod + async def resolve_nodes( + cls, + *, + info: Optional[Info] = None, + node_ids: Optional[Iterable[str]] = None, + required: bool = False, + ): + if node_ids is not None: + return [ + fruits_async[nid] if required else fruits_async.get(nid) + for nid in node_ids + ] + + return fruits_async.values() + + +@strawberry.type +class FruitCustomPaginationConnection(strawberry.relay.Connection[Fruit]): + @strawberry.field + def something(self) -> str: + return "foobar" + + @classmethod + def from_nodes( + cls, + nodes: Iterable[Fruit], + *, + info: Optional[Info] = None, + total_count: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ): + edges_mapping = { + to_base64("fruit_name", n.name): strawberry.relay.Edge( + node=n, + cursor=to_base64("fruit_name", n.name), + ) + for n in sorted(nodes, key=lambda f: f.name) + } + edges = list(edges_mapping.values()) + first_edge = edges[0] if edges else None + last_edge = edges[-1] if edges else None + + if after is not None: + after_edge_idx = edges.index(edges_mapping[after]) + edges = [e for e in edges if edges.index(e) > after_edge_idx] + + if before is not None: + before_edge_idx = edges.index(edges_mapping[before]) + edges = [e for e in edges if edges.index(e) < before_edge_idx] + + if first is not None: + edges = edges[:first] + + if last is not None: + edges = edges[-last:] + + return cls( + edges=edges, + page_info=strawberry.relay.PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=first_edge is not None + and bool(edges) + and edges[0] != first_edge, + has_next_page=last_edge is not None + and bool(edges) + and edges[-1] != last_edge, + ), + ) + + +fruits = { + str(f.id): f + for f in [ + Fruit(id=1, name="Banana", color="yellow"), + Fruit(id=2, name="Apple", color="red"), + Fruit(id=3, name="Pineapple", color="yellow"), + Fruit(id=4, name="Grape", color="purple"), + Fruit(id=5, name="Orange", color="orange"), + ] +} +fruits_async = { + k: FruitAsync(id=v.id, name=v.name, color=v.color) for k, v in fruits.items() +} + + +FruitAlike = namedtuple("FruitAlike", ["id", "name", "color"]) + + +def fruit_converter(fruit_alike: FruitAlike) -> Fruit: + return Fruit( + id=fruit_alike.id, + name=fruit_alike.name, + color=fruit_alike.color, + ) + + +def fruit_converter_forward_ref(fruit_alike: FruitAlike) -> "Fruit": + return Fruit( + id=fruit_alike.id, + name=fruit_alike.name, + color=fruit_alike.color, + ) + + +@strawberry.type +class Query: + node: strawberry.relay.Node + nodes: List[strawberry.relay.Node] + node_optional: Optional[strawberry.relay.Node] + nodes_optional: List[Optional[strawberry.relay.Node]] + fruits: strawberry.relay.Connection[Fruit] + fruits_async: strawberry.relay.Connection[FruitAsync] + fruits_custom_pagination: FruitCustomPaginationConnection + + @strawberry.relay.connection + def fruits_custom_resolver( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> List[Fruit]: + return [ + f + for f in fruits.values() + if name_endswith is None or f.name.endswith(name_endswith) + ] + + @strawberry.relay.connection(node_converter=fruit_converter) + def fruits_custom_resolver_with_node_converter( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> List[FruitAlike]: + return [ + FruitAlike(f.id, f.name, f.color) + for f in fruits.values() + if name_endswith is None or f.name.endswith(name_endswith) + ] + + @strawberry.relay.connection(node_converter=fruit_converter_forward_ref) + def fruits_custom_resolver_with_node_converter_forward_ref( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> List[FruitAlike]: + return [ + FruitAlike(f.id, f.name, f.color) + for f in fruits.values() + if name_endswith is None or f.name.endswith(name_endswith) + ] + + @strawberry.relay.connection + def fruits_custom_resolver_iterator( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> Iterator[Fruit]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + @strawberry.relay.connection + def fruits_custom_resolver_iterable( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> Iterator[Fruit]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + @strawberry.relay.connection + def fruits_custom_resolver_generator( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> Generator[Fruit, None, None]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + @strawberry.relay.connection + async def fruits_custom_resolver_async_iterable( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> AsyncIterable[Fruit]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + @strawberry.relay.connection + async def fruits_custom_resolver_async_iterator( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> AsyncIterator[Fruit]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + @strawberry.relay.connection + async def fruits_custom_resolver_async_generator( + self, + info: Info, + name_endswith: Optional[str] = None, + ) -> AsyncGenerator[Fruit, None]: + for f in fruits.values(): + if name_endswith is None or f.name.endswith(name_endswith): + yield f + + +schema = strawberry.Schema(query=Query) diff --git a/tests/relay/test_exceptions.py b/tests/relay/test_exceptions.py new file mode 100644 index 0000000000..e0509298b8 --- /dev/null +++ b/tests/relay/test_exceptions.py @@ -0,0 +1,72 @@ +from typing import Iterable, Iterator, List + +import pytest + +import strawberry +from strawberry.exceptions.missing_return_annotation import MissingReturnAnnotationError +from strawberry.relay.exceptions import ( + RelayWrongAnnotationError, + RelayWrongNodeResolverAnnotationError, +) +from tests.relay.schema import Fruit + + +@strawberry.type +class NonNodeType: + foo: str + + +@pytest.mark.raises_strawberry_exception( + MissingReturnAnnotationError, + match=( + 'Return annotation missing for field "custom_resolver", ' + "did you forget to add it?" + ), +) +def test_raises_error_on_connection_missing_annotation(): + @strawberry.type + class Query: + @strawberry.relay.connection # type: ignore + def custom_resolver(self): + ... + + +@pytest.mark.raises_strawberry_exception( + RelayWrongAnnotationError, + match=( + 'Unable to determine the connection type of field "custom_resolver". ' + r"It should be annotated with a return value of `List\[\]`, " + r"`Iterable\[\]`, `Iterator\[\]`, " + r"`AsyncIterable\[\]` or `AsyncIterator\[\]`" + ), +) +@pytest.mark.parametrize( + "annotation", + [Fruit, List[int], List[object], Iterable[int], Iterator[int], List[NonNodeType]], +) +def test_raises_error_on_connection_with_wrong_annotation(annotation): + @strawberry.type + class Query: + @strawberry.relay.connection + def custom_resolver(self) -> annotation: + ... + + +@pytest.mark.raises_strawberry_exception( + RelayWrongNodeResolverAnnotationError, + match=( + 'Unable to determine the connection type of field "custom_resolver". ' + "The `node_resolver` function should be annotated with a return value " + "of ``" + ), +) +@pytest.mark.parametrize("annotation", [int, object, NonNodeType]) +def test_raises_error_on_connection_with_wrong_node_resolver_annotation(annotation): + def node_converter(n: Fruit) -> annotation: + ... + + @strawberry.type + class Query: + @strawberry.relay.connection(node_converter=node_converter) # type: ignore + def custom_resolver(self) -> List[Fruit]: + ... diff --git a/tests/relay/test_fields.py b/tests/relay/test_fields.py new file mode 100644 index 0000000000..f756e62361 --- /dev/null +++ b/tests/relay/test_fields.py @@ -0,0 +1,1182 @@ +import pytest + +import strawberry +from strawberry.relay.fields import ConnectionField, NodeField +from strawberry.relay.types import NodeType +from strawberry.relay.utils import to_base64 + +from .schema import schema + + +def test_type_uses_node_field(): + @strawberry.type + class Query: + node: strawberry.relay.Node + + node_field = Query._type_definition.get_field("node") # type: ignore + assert isinstance(node_field, NodeField) + + copied = node_field.copy_with({}) + assert copied.default_args == node_field.default_args + + +def test_type_uses_connection_field(): + @strawberry.type + class Fruit: + ... + + @strawberry.type + class Query: + connection: strawberry.relay.Connection + + connection_field = Query._type_definition.get_field("connection") # type: ignore + assert isinstance(connection_field, ConnectionField) + + copied = connection_field.copy_with({NodeType: Fruit}) + assert copied.default_args == connection_field.default_args + + +def test_query_node(): + result = schema.execute_sync( + """ + query TestQuery ($id: ID!) { + node (id: $id) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "id": to_base64("Fruit", 2), + }, + ) + assert result.errors is None + assert result.data == { + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + } + + +def test_query_node_optional(): + result = schema.execute_sync( + """ + query TestQuery ($id: ID!) { + nodeOptional (id: $id) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "id": to_base64("Fruit", 999), + }, + ) + assert result.errors is None + assert result.data == {"nodeOptional": None} + + +async def test_query_node_async(): + result = await schema.execute( + """ + query TestQuery ($id: ID!) { + node (id: $id) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "id": to_base64("Fruit", 2), + }, + ) + assert result.errors is None + assert result.data == { + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + } + + +async def test_query_node_optional_async(): + result = await schema.execute( + """ + query TestQuery ($id: ID!) { + nodeOptional (id: $id) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "id": to_base64("Fruit", 999), + }, + ) + assert result.errors is None + assert result.data == {"nodeOptional": None} + + +def test_query_nodes(): + result = schema.execute_sync( + """ + query TestQuery ($ids: [ID!]!) { + nodes (ids: $ids) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "ids": [to_base64("Fruit", 2), to_base64("Fruit", 4)], + }, + ) + assert result.errors is None + assert result.data == { + "nodes": [ + { + "id": to_base64("Fruit", 2), + "name": "Apple", + "color": "red", + }, + { + "id": to_base64("Fruit", 4), + "name": "Grape", + "color": "purple", + }, + ], + } + + +def test_query_nodes_optional(): + result = schema.execute_sync( + """ + query TestQuery ($ids: [ID!]!) { + nodesOptional (ids: $ids) { + ... on Node { + id + } + ... on Fruit { + name + color + } + } + } + """, + variable_values={ + "ids": [ + to_base64("Fruit", 2), + to_base64("Fruit", 999), + to_base64("Fruit", 4), + ], + }, + ) + assert result.errors is None + assert result.data == { + "nodesOptional": [ + { + "id": to_base64("Fruit", 2), + "name": "Apple", + "color": "red", + }, + None, + { + "id": to_base64("Fruit", 4), + "name": "Grape", + "color": "purple", + }, + ], + } + + +async def test_query_nodes_async(): + result = await schema.execute( + """ + query TestQuery ($ids: [ID!]!) { + nodes (ids: $ids) { + ... on Node { + id + } + ... on Fruit { + name + color + } + ... on FruitAsync { + name + color + } + } + } + """, + variable_values={ + "ids": [ + to_base64("Fruit", 2), + to_base64("Fruit", 4), + to_base64("FruitAsync", 2), + ], + }, + ) + assert result.errors is None + assert result.data == { + "nodes": [ + { + "id": to_base64("Fruit", 2), + "name": "Apple", + "color": "red", + }, + { + "id": to_base64("Fruit", 4), + "name": "Grape", + "color": "purple", + }, + { + "id": to_base64("FruitAsync", 2), + "name": "Apple", + "color": "red", + }, + ], + } + + +async def test_query_nodes_optional_async(): + result = await schema.execute( + """ + query TestQuery ($ids: [ID!]!) { + nodesOptional (ids: $ids) { + ... on Node { + id + } + ... on Fruit { + name + color + } + ... on FruitAsync { + name + color + } + } + } + """, + variable_values={ + "ids": [ + to_base64("Fruit", 2), + to_base64("FruitAsync", 999), + to_base64("Fruit", 4), + to_base64("Fruit", 999), + to_base64("FruitAsync", 2), + ], + }, + ) + assert result.errors is None + assert result.data == { + "nodesOptional": [ + { + "id": to_base64("Fruit", 2), + "name": "Apple", + "color": "red", + }, + None, + { + "id": to_base64("Fruit", 4), + "name": "Grape", + "color": "purple", + }, + None, + { + "id": to_base64("FruitAsync", 2), + "name": "Apple", + "color": "red", + }, + ], + } + + +fruits_query = """ +query TestQuery ( + $first: Int = null + $last: Int = null + $before: String = null, + $after: String = null, +) {{ + {} ( + first: $first + last: $last + before: $before + after: $after + ) {{ + pageInfo {{ + hasNextPage + hasPreviousPage + startCursor + endCursor + }} + edges {{ + cursor + node {{ + id + name + color + }} + }} + }} +}} +""" + +attrs = [ + "fruits", + "fruitsCustomResolver", + "fruitsCustomResolverWithNodeConverter", + "fruitsCustomResolverWithNodeConverterForwardRef", + "fruitsCustomResolverIterator", + "fruitsCustomResolverIterable", + "fruitsCustomResolverGenerator", +] +async_attrs = [ + *attrs, + "fruitsCustomResolverAsyncIterator", + "fruitsCustomResolverAsyncIterable", + "fruitsCustomResolverAsyncGenerator", +] + + +@pytest.mark.parametrize("query_attr", attrs) +def test_query_connection(query_attr: str): + result = schema.execute_sync( + fruits_query.format(query_attr), + variable_values={}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjA=", + "node": { + "id": to_base64("Fruit", 1), + "color": "yellow", + "name": "Banana", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjQ=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": to_base64("arrayconnection", "0"), + "endCursor": to_base64("arrayconnection", "4"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", async_attrs) +async def test_query_connection_async(query_attr: str): + result = await schema.execute( + fruits_query.format(query_attr), + variable_values={}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjA=", + "node": { + "id": to_base64("Fruit", 1), + "color": "yellow", + "name": "Banana", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjQ=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": to_base64("arrayconnection", "0"), + "endCursor": to_base64("arrayconnection", "4"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", attrs) +def test_query_connection_filtering_first(query_attr: str): + result = schema.execute_sync( + fruits_query.format(query_attr), + variable_values={"first": 2}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjA=", + "node": { + "id": to_base64("Fruit", 1), + "color": "yellow", + "name": "Banana", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": to_base64("arrayconnection", "0"), + "endCursor": to_base64("arrayconnection", "1"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", attrs) +def test_query_connection_filtering_first_with_after(query_attr: str): + result = schema.execute_sync( + fruits_query.format(query_attr), + variable_values={"first": 2, "after": to_base64("arrayconnection", "1")}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "2"), + "endCursor": to_base64("arrayconnection", "3"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", attrs) +def test_query_connection_filtering_last(query_attr: str): + result = schema.execute_sync( + fruits_query.format(query_attr), + variable_values={"last": 2}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjQ=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "3"), + "endCursor": to_base64("arrayconnection", "4"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", attrs) +def test_query_connection_filtering_last_with_before(query_attr: str): + result = schema.execute_sync( + fruits_query.format(query_attr), + variable_values={"last": 2, "before": to_base64("arrayconnection", "4")}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "2"), + "endCursor": to_base64("arrayconnection", "3"), + }, + } + } + + +fruits_custom_query = """ +query TestQuery ( + $first: Int = null + $last: Int = null + $before: String = null, + $after: String = null, +) { + fruitsCustomPagination ( + first: $first + last: $last + before: $before + after: $after + ) { + something + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + id + name + color + } + } + } +} +""" + + +def test_query_custom_connection(): + result = schema.execute_sync( + fruits_custom_query, + variable_values={}, + ) + assert result.errors is None + assert result.data == { + "fruitsCustomPagination": { + "something": "foobar", + "edges": [ + { + "cursor": "ZnJ1aXRfbmFtZTpBcHBsZQ==", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpCYW5hbmE=", + "node": { + "id": to_base64("Fruit", 1), + "color": "yellow", + "name": "Banana", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpHcmFwZQ==", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpPcmFuZ2U=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpQaW5lYXBwbGU=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + ], + "pageInfo": { + "startCursor": to_base64("fruit_name", "Apple"), + "endCursor": to_base64("fruit_name", "Pineapple"), + "hasNextPage": False, + "hasPreviousPage": False, + }, + } + } + + +def test_query_custom_connection_filtering_first(): + result = schema.execute_sync( + fruits_custom_query, + variable_values={"first": 2}, + ) + assert result.errors is None + assert result.data == { + "fruitsCustomPagination": { + "something": "foobar", + "edges": [ + { + "cursor": "ZnJ1aXRfbmFtZTpBcHBsZQ==", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpCYW5hbmE=", + "node": { + "id": to_base64("Fruit", 1), + "color": "yellow", + "name": "Banana", + }, + }, + ], + "pageInfo": { + "startCursor": to_base64("fruit_name", "Apple"), + "endCursor": to_base64("fruit_name", "Banana"), + "hasNextPage": True, + "hasPreviousPage": False, + }, + } + } + + +def test_query_custom_connection_filtering_first_with_after(): + result = schema.execute_sync( + fruits_custom_query, + variable_values={"first": 2, "after": to_base64("fruit_name", "Banana")}, + ) + assert result.errors is None + assert result.data == { + "fruitsCustomPagination": { + "something": "foobar", + "edges": [ + { + "cursor": "ZnJ1aXRfbmFtZTpHcmFwZQ==", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpPcmFuZ2U=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": to_base64("fruit_name", "Grape"), + "endCursor": to_base64("fruit_name", "Orange"), + }, + } + } + + +def test_query_custom_connection_filtering_last(): + result = schema.execute_sync( + fruits_custom_query, + variable_values={"last": 2}, + ) + assert result.errors is None + assert result.data == { + "fruitsCustomPagination": { + "something": "foobar", + "edges": [ + { + "cursor": "ZnJ1aXRfbmFtZTpPcmFuZ2U=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpQaW5lYXBwbGU=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": to_base64("fruit_name", "Orange"), + "endCursor": to_base64("fruit_name", "Pineapple"), + }, + } + } + + +def test_query_custom_connection_filtering_first_with_before(): + result = schema.execute_sync( + fruits_custom_query, + variable_values={ + "last": 2, + "before": to_base64("fruit_name", "Pineapple"), + }, + ) + assert result.errors is None + assert result.data == { + "fruitsCustomPagination": { + "something": "foobar", + "edges": [ + { + "cursor": "ZnJ1aXRfbmFtZTpHcmFwZQ==", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "ZnJ1aXRfbmFtZTpPcmFuZ2U=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": to_base64("fruit_name", "Grape"), + "endCursor": to_base64("fruit_name", "Orange"), + }, + } + } + + +fruits_query_custom_resolver = """ +query TestQuery ( + $first: Int = null + $last: Int = null + $before: String = null, + $after: String = null, + $nameEndswith: String = null +) {{ + {} ( + first: $first + last: $last + before: $before + after: $after + nameEndswith: $nameEndswith + ) {{ + pageInfo {{ + hasNextPage + hasPreviousPage + startCursor + endCursor + }} + edges {{ + cursor + node {{ + id + name + color + }} + }} + }} +}} +""" + +custom_attrs = [ + "fruitsCustomResolver", + "fruitsCustomResolverWithNodeConverter", + "fruitsCustomResolverWithNodeConverterForwardRef", + "fruitsCustomResolverIterator", + "fruitsCustomResolverIterable", + "fruitsCustomResolverGenerator", +] +custom_async_attrs = [ + *attrs, + "fruitsCustomResolverAsyncIterator", + "fruitsCustomResolverAsyncIterable", + "fruitsCustomResolverAsyncGenerator", +] + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_connection_custom_resolver(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"nameEndswith": "e"}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjA=", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": to_base64("arrayconnection", "0"), + "endCursor": to_base64("arrayconnection", "3"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_connection_custom_resolver_filtering_first(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"first": 2, "nameEndswith": "e"}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjA=", + "node": { + "id": to_base64("Fruit", 2), + "color": "red", + "name": "Apple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": to_base64("arrayconnection", "0"), + "endCursor": to_base64("arrayconnection", "1"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_connection_custom_resolver_filtering_first_with_after(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={ + "first": 2, + "after": to_base64("arrayconnection", "1"), + "nameEndswith": "e", + }, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "2"), + "endCursor": to_base64("arrayconnection", "3"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_connection_custom_resolver_filtering_last(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"last": 2, "nameEndswith": "e"}, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjM=", + "node": { + "id": to_base64("Fruit", 5), + "color": "orange", + "name": "Orange", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "2"), + "endCursor": to_base64("arrayconnection", "3"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_connection_custom_resolver_filtering_last_with_before(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={ + "last": 2, + "before": to_base64("arrayconnection", "3"), + "nameEndswith": "e", + }, + ) + assert result.errors is None + assert result.data == { + query_attr: { + "edges": [ + { + "cursor": "YXJyYXljb25uZWN0aW9uOjE=", + "node": { + "id": to_base64("Fruit", 3), + "color": "yellow", + "name": "Pineapple", + }, + }, + { + "cursor": "YXJyYXljb25uZWN0aW9uOjI=", + "node": { + "id": to_base64("Fruit", 4), + "color": "purple", + "name": "Grape", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": to_base64("arrayconnection", "1"), + "endCursor": to_base64("arrayconnection", "2"), + }, + } + } + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_first_negative(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"first": -1}, + ) + assert result.errors is not None + assert ( + result.errors[0].message == "Argument 'first' must be a non-negative integer." + ) + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_first_higher_than_max_results(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"first": 500}, + ) + assert result.errors is not None + assert result.errors[0].message == "Argument 'first' cannot be higher than 100." + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_last_negative(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"last": -1}, + ) + assert result.errors is not None + assert result.errors[0].message == "Argument 'last' must be a non-negative integer." + + +@pytest.mark.parametrize("query_attr", custom_attrs) +def test_query_last_higher_than_max_results(query_attr: str): + result = schema.execute_sync( + fruits_query_custom_resolver.format(query_attr), + variable_values={"last": 500}, + ) + assert result.errors is not None + assert result.errors[0].message == "Argument 'last' cannot be higher than 100." diff --git a/tests/relay/test_schema.py b/tests/relay/test_schema.py new file mode 100644 index 0000000000..8f937759b3 --- /dev/null +++ b/tests/relay/test_schema.py @@ -0,0 +1,16 @@ +import pathlib + +from .schema import schema + + +def test_schema(): + schema_output = str(schema).strip("\n").strip(" ") + output = pathlib.Path(__file__).parent / "schema.gql" + if not output.exists(): + with output.open("w") as f: + f.write(schema_output + "\n") + + with output.open() as f: + expected = f.read().strip("\n").strip(" ") + + assert schema_output == expected diff --git a/tests/relay/test_types.py b/tests/relay/test_types.py new file mode 100644 index 0000000000..f7a1407073 --- /dev/null +++ b/tests/relay/test_types.py @@ -0,0 +1,170 @@ +from typing import Any, Optional, Union, cast +from typing_extensions import assert_type + +import pytest + +import strawberry +from strawberry import relay +from strawberry.relay.types import GlobalIDValueError, Node +from strawberry.types.info import Info + +from .schema import Fruit, FruitAsync, schema + + +class FakeInfo: + schema = schema + + +# We only need that info contains the schema for the tests +fake_info = cast(Info, FakeInfo()) + + +@pytest.mark.parametrize("type_name", [None, 1, 1.1]) +def test_global_id_wrong_type_name(type_name: Any): + with pytest.raises(GlobalIDValueError) as exc_info: + strawberry.relay.GlobalID(type_name=type_name, node_id="foobar") + + +@pytest.mark.parametrize("node_id", [None, 1, 1.1]) +def test_global_id_wrong_type_node_id(node_id: Any): + with pytest.raises(GlobalIDValueError) as exc_info: + strawberry.relay.GlobalID(type_name="foobar", node_id=node_id) + + +def test_global_id_from_id(): + gid = strawberry.relay.GlobalID.from_id("Zm9vYmFyOjE=") + assert gid.type_name == "foobar" + assert gid.node_id == "1" + + +@pytest.mark.parametrize("value", ["foobar", ["Zm9vYmFy"], 123]) +def test_global_id_from_id_error(value: Any): + with pytest.raises(GlobalIDValueError) as exc_info: + strawberry.relay.GlobalID.from_id(value) + + +def test_global_id_resolve_type(): + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="1") + type_ = gid.resolve_type(fake_info) + assert type_ is Fruit + + +def test_global_id_resolve_node(): + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="1") + fruit = gid.resolve_node(fake_info) + assert isinstance(fruit, Fruit) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +def test_global_id_resolve_node_non_existing(): + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="999") + fruit = gid.resolve_node(fake_info) + assert_type(fruit, Optional[Node]) + assert fruit is None + + +def test_global_id_resolve_node_non_existing_but_required(): + with pytest.raises(KeyError): + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="999") + fruit = gid.resolve_node(fake_info, required=True) + + +def test_global_id_resolve_node_ensure_type(): + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="1") + fruit = gid.resolve_node(fake_info, ensure_type=Fruit) + assert_type(fruit, Fruit) + assert isinstance(fruit, Fruit) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +def test_global_id_resolve_node_ensure_type_with_union(): + class Foo: + ... + + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="1") + fruit = gid.resolve_node(fake_info, ensure_type=Union[Fruit, Foo]) + assert_type(fruit, Union[Fruit, Foo]) + assert isinstance(fruit, Fruit) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +def test_global_id_resolve_node_ensure_type_wrong_type(): + class Foo: + ... + + gid = strawberry.relay.GlobalID(type_name="Fruit", node_id="1") + with pytest.raises(TypeError): + fruit = gid.resolve_node(fake_info, ensure_type=Foo) + + +async def test_global_id_aresolve_node(): + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="1") + fruit = await gid.aresolve_node(fake_info) + assert_type(fruit, Optional[Node]) + assert isinstance(fruit, FruitAsync) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +async def test_global_id_aresolve_node_non_existing(): + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="999") + fruit = await gid.aresolve_node(fake_info) + assert_type(fruit, Optional[Node]) + assert fruit is None + + +async def test_global_id_aresolve_node_non_existing_but_required(): + with pytest.raises(KeyError): + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="999") + fruit = await gid.aresolve_node(fake_info, required=True) + + +async def test_global_id_aresolve_node_ensure_type(): + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="1") + fruit = await gid.aresolve_node(fake_info, ensure_type=FruitAsync) + assert_type(fruit, FruitAsync) + assert isinstance(fruit, FruitAsync) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +async def test_global_id_aresolve_node_ensure_type_with_union(): + class Foo: + ... + + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="1") + fruit = await gid.aresolve_node(fake_info, ensure_type=Union[FruitAsync, Foo]) + assert_type(fruit, Union[FruitAsync, Foo]) + assert isinstance(fruit, FruitAsync) + assert fruit.id == 1 + assert fruit.name == "Banana" + + +async def test_global_id_aresolve_node_ensure_type_wrong_type(): + class Foo: + ... + + gid = strawberry.relay.GlobalID(type_name="FruitAsync", node_id="1") + with pytest.raises(TypeError): + fruit = await gid.aresolve_node(fake_info, ensure_type=Foo) + + +async def test_node_no_id(): + with pytest.raises(TypeError): + + @strawberry.type + class Foo(relay.Node): + foo: str + bar: str + + +async def test_node_more_than_one_id(): + with pytest.raises(TypeError): + + @strawberry.type + class Foo(relay.Node): + foo: relay.NodeID[str] + bar: relay.NodeID[str] diff --git a/tests/relay/test_utils.py b/tests/relay/test_utils.py new file mode 100644 index 0000000000..d6f7c471aa --- /dev/null +++ b/tests/relay/test_utils.py @@ -0,0 +1,55 @@ +from typing import Any + +import pytest + +from strawberry.relay.utils import from_base64, to_base64 + +from .schema import Fruit + + +def test_from_base64(): + type_name, node_id = from_base64("Zm9vYmFyOjE=") + assert type_name == "foobar" + assert node_id == "1" + + +@pytest.mark.parametrize("value", [None, 1, 1.1, "dsadfas"]) +def test_from_base64_non_base64(value: Any): + with pytest.raises(ValueError): + type_name, node_id = from_base64(value) + + +@pytest.mark.parametrize( + "value", + [ + "Zm9vYmFy", # "foobar" + "Zm9vYmFyOjE6Mg==", # "foobar:1:2" + ], +) +def test_from_base64_wrong_number_of_args(value: Any): + with pytest.raises(ValueError): + type_name, node_id = from_base64(value) + + +def test_to_base64(): + value = to_base64("foobar", "1") + assert value == "Zm9vYmFyOjE=" + + +def test_to_base64_with_type(): + value = to_base64(Fruit, "1") + assert value == "RnJ1aXQ6MQ==" + + +def test_to_base64_with_typedef(): + value = to_base64( + Fruit._type_definition, # type: ignore + "1", + ) + assert value == "RnJ1aXQ6MQ==" + + +@pytest.mark.parametrize("value", [None, 1, 1.1, object()]) +def test_to_base64_with_invalid_type(value: Any): + with pytest.raises(ValueError): + value = to_base64(value, "1") diff --git a/tests/test_aio.py b/tests/test_aio.py new file mode 100644 index 0000000000..9df1d79f1b --- /dev/null +++ b/tests/test_aio.py @@ -0,0 +1,63 @@ +from strawberry.utils.aio import ( + aenumerate, + aislice, + asyncgen_to_list, + resolve_awaitable, +) + + +async def test_aenumerate(): + async def gen(): + yield "a" + yield "b" + yield "c" + yield "d" + + res = [(i, v) async for i, v in aenumerate(gen())] + assert res == [(0, "a"), (1, "b"), (2, "c"), (3, "d")] + + +async def test_aslice(): + async def gen(): + yield "a" + yield "b" + raise AssertionError("should never be called") # pragma: no cover + yield "c" # pragma: no cover + + res = [] + async for v in aislice(gen(), 0, 2): + res.append(v) + + assert res == ["a", "b"] + + +async def test_aslice_with_step(): + async def gen(): + yield "a" + yield "b" + yield "c" + raise AssertionError("should never be called") # pragma: no cover + yield "d" # pragma: no cover + yield "e" # pragma: no cover + + res = [] + async for v in aislice(gen(), 0, 4, 2): + res.append(v) + + assert res == ["a", "c"] + + +async def test_asyncgen_to_list(): + async def gen(): + yield "a" + yield "b" + yield "c" + + assert await asyncgen_to_list(gen()) == ["a", "b", "c"] + + +async def test_resolve_awaitable(): + async def awaitable(): + return 1 + + assert await resolve_awaitable(awaitable(), lambda v: v + 1) == 2 diff --git a/tests/test_inspect.py b/tests/test_inspect.py new file mode 100644 index 0000000000..b2bfcec6a1 --- /dev/null +++ b/tests/test_inspect.py @@ -0,0 +1,16 @@ +from strawberry.utils.inspect import in_async_context + + +def test_in_async_context_sync(): + assert not in_async_context() + + +async def test_in_async_context_async(): + assert in_async_context() + + +async def test_in_async_context_async_with_inner_sync_function(): + def inner_sync_function(): + assert in_async_context() + + inner_sync_function()