Skip to content

Commit

Permalink
Implement query complexity estimation (#5880)
Browse files Browse the repository at this point in the history
* add complexity library and expose method to get estimators

* initialreview changes

* export default complexity estimators instead of getter

* Apply suggestions from code review

Co-authored-by: MacondoExpress <[email protected]>

---------

Co-authored-by: MacondoExpress <[email protected]>
  • Loading branch information
a-alle and MacondoExpress authored Dec 19, 2024
1 parent 07ea342 commit 3d93a18
Show file tree
Hide file tree
Showing 15 changed files with 634 additions and 61 deletions.
1 change: 1 addition & 0 deletions packages/graphql/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"dot-prop": "^6.0.1",
"graphql-compose": "^9.0.8",
"graphql-parse-resolve-info": "^4.12.3",
"graphql-query-complexity": "^1.0.0",
"graphql-relay": "^0.10.0",
"jose": "^5.0.0",
"pluralize": "^8.0.0",
Expand Down
106 changes: 106 additions & 0 deletions packages/graphql/src/classes/ComplexityEstimatorHelper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { type DefinitionNode,type GraphQLFieldExtensions, type GraphQLSchema, GraphQLInterfaceType, GraphQLObjectType, Kind } from "graphql";
import type { ComplexityEstimator} from "graphql-query-complexity";
import { fieldExtensionsEstimator, simpleEstimator } from "graphql-query-complexity";

export class ComplexityEstimatorHelper {
private objectTypeNameToFieldNamesMapForComplexityExtensions: Map<string, string[]>;
private useComplexityEstimators: boolean;

constructor(useComplexityEstimators: boolean) {
this.useComplexityEstimators = useComplexityEstimators
this.objectTypeNameToFieldNamesMapForComplexityExtensions = new Map<string, string[]>();
}

public registerField(parentObjectTypeNameOrInterfaceTypeName: string, fieldName: string): void {
if(this.useComplexityEstimators) {
const existingFieldsForTypeName = this.objectTypeNameToFieldNamesMapForComplexityExtensions.get(parentObjectTypeNameOrInterfaceTypeName) ?? []
this.objectTypeNameToFieldNamesMapForComplexityExtensions.set(parentObjectTypeNameOrInterfaceTypeName, existingFieldsForTypeName.concat(fieldName))
}
}

public hydrateDefinitionNodeWithComplexityExtensions(definition: DefinitionNode): DefinitionNode {
if(definition.kind !== Kind.OBJECT_TYPE_DEFINITION && definition.kind !== Kind.INTERFACE_TYPE_DEFINITION) {
return definition;
}
if(!this.objectTypeNameToFieldNamesMapForComplexityExtensions.has(definition.name.value)) {
return definition
}

const fieldsWithComplexity = definition.fields?.map(f => {
const hasFieldComplexityEstimator = this.getFieldsForParentTypeName(definition.name.value).find(fieldName => fieldName === f.name.value)
if (!hasFieldComplexityEstimator) {
return f
}
return {
...f,
extensions: {
// COMPLEXITY FORMULA
// c = c_child + lvl_limit * c_field, where
// c_field = 1
// lvl_limit defaults to 1
// c_child comes from simpleEstimator
complexity: ({childComplexity, args}) => {
const fieldDefaultComplexity = 1
const defaultLimitIfNotProvided = 1
if(args.limit ?? args.first) {
return childComplexity + (args.limit ?? args.first) * fieldDefaultComplexity
}
return childComplexity + defaultLimitIfNotProvided

},
},
}
})
return {
...definition,
fields: fieldsWithComplexity,
}
}


public hydrateSchemaFromSDLWithASTNodeExtensions(schema: GraphQLSchema): void {
const types = schema.getTypeMap();
Object.values(types).forEach((type) => {
if (type instanceof GraphQLObjectType || type instanceof GraphQLInterfaceType) {
const fields = type.getFields();
Object.values(fields).forEach((field) => {
if (field.astNode && 'extensions' in field.astNode) {
field.extensions = field.astNode.extensions as GraphQLFieldExtensions<any, any, any>;
}
});
}
});
}

public getComplexityEstimators(): ComplexityEstimator[] {
if (!this.useComplexityEstimators) {
return []
}
return [
fieldExtensionsEstimator(),
simpleEstimator({ defaultComplexity: 1 }),
];
}

private getFieldsForParentTypeName(parentObjectTypeNameOrInterfaceTypeName: string): string[] {
return this.objectTypeNameToFieldNamesMapForComplexityExtensions.get(parentObjectTypeNameOrInterfaceTypeName) || []
}
}
7 changes: 7 additions & 0 deletions packages/graphql/src/classes/Neo4jGraphQL.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import { Neo4jGraphQLSubscriptionsCDCEngine } from "./subscription/Neo4jGraphQLS
import { assertIndexesAndConstraints } from "./utils/asserts-indexes-and-constraints";
import { generateResolverComposition } from "./utils/generate-resolvers-composition";
import checkNeo4jCompat from "./utils/verify-database";
import { ComplexityEstimatorHelper } from "./ComplexityEstimatorHelper";

type TypeDefinitions = string | DocumentNode | TypeDefinitions[] | (() => TypeDefinitions);

Expand All @@ -75,6 +76,7 @@ class Neo4jGraphQL {
private jwtFieldsMap?: Map<string, string>;

private schemaModel?: Neo4jGraphQLSchemaModel;
private complexityEstimatorHelper: ComplexityEstimatorHelper;

private executableSchema?: Promise<GraphQLSchema>;
private subgraphSchema?: Promise<GraphQLSchema>;
Expand Down Expand Up @@ -108,6 +110,8 @@ class Neo4jGraphQL {

this.authorization = new Neo4jGraphQLAuthorization(authorizationSettings);
}

this.complexityEstimatorHelper = new ComplexityEstimatorHelper(!!this.features.complexityEstimators);
}

public async getSchema(): Promise<GraphQLSchema> {
Expand Down Expand Up @@ -393,6 +397,7 @@ class Neo4jGraphQL {
features: this.features,
userCustomResolvers: this.resolvers,
schemaModel: this.schemaModel,
complexityEstimatorHelper: this.complexityEstimatorHelper,
});

if (this.validate) {
Expand All @@ -406,6 +411,7 @@ class Neo4jGraphQL {
typeDefs,
resolvers,
});
this.complexityEstimatorHelper.hydrateSchemaFromSDLWithASTNodeExtensions(schema);

resolve(this.composeSchema(schema));
});
Expand Down Expand Up @@ -457,6 +463,7 @@ class Neo4jGraphQL {
userCustomResolvers: this.resolvers,
subgraph,
schemaModel: this.schemaModel,
complexityEstimatorHelper: this.complexityEstimatorHelper,
});

if (this.validate) {
Expand Down
3 changes: 3 additions & 0 deletions packages/graphql/src/classes/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
* limitations under the License.
*/

import { fieldExtensionsEstimator, simpleEstimator } from "graphql-query-complexity";

export * from "./Error";
export { GraphElement } from "./GraphElement";
export { Neo4jDatabaseInfo } from "./Neo4jDatabaseInfo";
export { default as Neo4jGraphQL, Neo4jGraphQLConstructor } from "./Neo4jGraphQL";
export { default as Node, NodeConstructor } from "./Node";
export { default as Relationship } from "./Relationship";
export const DefaultComplexityEstimators = [fieldExtensionsEstimator(), simpleEstimator({ defaultComplexity: 1 })];
4 changes: 4 additions & 0 deletions packages/graphql/src/schema/augment/fulltext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ import {
withFulltextWhereInputType,
} from "../generation/fulltext-input";
import { fulltextResolver } from "../resolvers/query/fulltext";
import { type ComplexityEstimatorHelper } from "../../classes/ComplexityEstimatorHelper";

export function augmentFulltextSchema({
composer,
concreteEntityAdapter,
complexityEstimatorHelper,
features,
}: {
composer: SchemaComposer;
concreteEntityAdapter: ConcreteEntityAdapter;
complexityEstimatorHelper: ComplexityEstimatorHelper
features?: Neo4jFeaturesSettings;
}) {
if (!concreteEntityAdapter.annotations.fulltext) {
Expand All @@ -61,6 +64,7 @@ export function augmentFulltextSchema({
after: GraphQLString,
};

complexityEstimatorHelper.registerField("Query", index.queryName);
composer.Query.addFields({
[index.queryName]: {
type: withFulltextResultTypeConnection({ composer, concreteEntityAdapter }).NonNull,
Expand Down
4 changes: 4 additions & 0 deletions packages/graphql/src/schema/augment/vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ import {
withVectorWhereInputType,
} from "../generation/vector-input";
import { vectorResolver } from "../resolvers/query/vector";
import { type ComplexityEstimatorHelper } from "../../classes/ComplexityEstimatorHelper";

export function augmentVectorSchema({
composer,
concreteEntityAdapter,
complexityEstimatorHelper,
features,
}: {
composer: SchemaComposer;
concreteEntityAdapter: ConcreteEntityAdapter;
complexityEstimatorHelper: ComplexityEstimatorHelper
features?: Neo4jFeaturesSettings;
}) {
if (!concreteEntityAdapter.annotations.vector) {
Expand Down Expand Up @@ -67,6 +70,7 @@ export function augmentVectorSchema({
vectorArgs["vector"] = new GraphQLList(new GraphQLNonNull(GraphQLFloat));
}

complexityEstimatorHelper.registerField("Query", index.queryName);
composer.Query.addFields({
[index.queryName]: {
type: withVectorResultTypeConnection({ composer, concreteEntityAdapter }).NonNull,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import { withSortInputType } from "../generation/sort-and-options-input";
import { augmentUpdateInputTypeWithUpdateFieldInput, withUpdateInputType } from "../generation/update-input";
import { withSourceWhereInputType, withWhereInputType } from "../generation/where-input";
import { graphqlDirectivesToCompose } from "../to-compose";
import { type ComplexityEstimatorHelper } from "../../classes/ComplexityEstimatorHelper";

function doForRelationshipDeclaration({
relationshipDeclarationAdapter,
Expand Down Expand Up @@ -165,6 +166,7 @@ export function createRelationshipFields({
userDefinedDirectivesForNode,
userDefinedFieldDirectivesForNode,
features,
complexityEstimatorHelper,
}: {
entityAdapter: ConcreteEntityAdapter | InterfaceEntityAdapter;
schemaComposer: SchemaComposer;
Expand All @@ -175,6 +177,7 @@ export function createRelationshipFields({
userDefinedDirectivesForNode: Map<string, DirectiveNode[]>;
userDefinedFieldDirectivesForNode: Map<string, Map<string, DirectiveNode[]>>;
features?: Neo4jFeaturesSettings;
complexityEstimatorHelper: ComplexityEstimatorHelper;
}): void {
const relationships =
entityAdapter instanceof ConcreteEntityAdapter
Expand Down Expand Up @@ -243,6 +246,7 @@ export function createRelationshipFields({
userDefinedDirectivesOnTargetFields: Map<string, DirectiveNode[]> | undefined;
subgraph?: Subgraph;
features: Neo4jFeaturesSettings | undefined;
complexityEstimatorHelper: ComplexityEstimatorHelper;
} = {
relationshipAdapter,
composer: schemaComposer,
Expand All @@ -251,6 +255,7 @@ export function createRelationshipFields({
deprecatedDirectives,
userDefinedDirectivesOnTargetFields,
features,
complexityEstimatorHelper,
};

if (relationshipTarget instanceof UnionEntityAdapter) {
Expand Down Expand Up @@ -297,6 +302,7 @@ function createRelationshipFieldsForTarget({
userDefinedDirectivesOnTargetFields,
subgraph, // only for concrete targets
features,
complexityEstimatorHelper,
}: {
relationshipAdapter: RelationshipAdapter | RelationshipDeclarationAdapter;
composer: SchemaComposer;
Expand All @@ -306,6 +312,7 @@ function createRelationshipFieldsForTarget({
deprecatedDirectives: Directive[];
subgraph?: Subgraph;
features: Neo4jFeaturesSettings | undefined;
complexityEstimatorHelper: ComplexityEstimatorHelper;
}) {
withSourceWhereInputType({
relationshipAdapter,
Expand All @@ -318,6 +325,8 @@ function createRelationshipFieldsForTarget({
if (relationshipAdapter.target instanceof InterfaceEntityAdapter) {
withFieldInputType({ relationshipAdapter, composer, userDefinedFieldDirectives });
}

complexityEstimatorHelper.registerField(composeNode.getTypeName(), relationshipAdapter.name)
composeNode.addFields(
augmentObjectOrInterfaceTypeWithRelationshipField({
relationshipAdapter,
Expand All @@ -327,7 +336,8 @@ function createRelationshipFieldsForTarget({
features,
})
);


complexityEstimatorHelper.registerField(composeNode.getTypeName(), relationshipAdapter.operations.connectionFieldName)
composeNode.addFields(
augmentObjectOrInterfaceTypeWithConnectionField(relationshipAdapter, userDefinedFieldDirectives, composer, features)
);
Expand Down
13 changes: 9 additions & 4 deletions packages/graphql/src/schema/make-augmented-schema.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { gql } from "graphql-tag";
import { Node } from "../classes";
import { generateModel } from "../schema-model/generate-model";
import makeAugmentedSchema from "./make-augmented-schema";
import { ComplexityEstimatorHelper } from "../classes/ComplexityEstimatorHelper";

describe("makeAugmentedSchema", () => {
test("should be a function", () => {
Expand All @@ -52,7 +53,7 @@ describe("makeAugmentedSchema", () => {
`;

const schemaModel = generateModel(mergeTypeDefs(typeDefs));
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel });
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel, complexityEstimatorHelper: new ComplexityEstimatorHelper(false) });
const document = neoSchema.typeDefs;
const queryObject = document.definitions.find(
(x) => x.kind === Kind.OBJECT_TYPE_DEFINITION && x.name.value === "Query"
Expand Down Expand Up @@ -96,7 +97,7 @@ describe("makeAugmentedSchema", () => {
`;

const schemaModel = generateModel(mergeTypeDefs(typeDefs));
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel });
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel, complexityEstimatorHelper: new ComplexityEstimatorHelper(false) });

const document = neoSchema.typeDefs;

Expand Down Expand Up @@ -127,6 +128,7 @@ describe("makeAugmentedSchema", () => {
},
},
schemaModel,
complexityEstimatorHelper: new ComplexityEstimatorHelper(false),
});

const document = neoSchema.typeDefs;
Expand Down Expand Up @@ -159,6 +161,7 @@ describe("makeAugmentedSchema", () => {
},
},
schemaModel,
complexityEstimatorHelper: new ComplexityEstimatorHelper(false),
});

const document = neoSchema.typeDefs;
Expand Down Expand Up @@ -194,6 +197,7 @@ describe("makeAugmentedSchema", () => {
},
},
schemaModel,
complexityEstimatorHelper: new ComplexityEstimatorHelper(false),
});

const document = neoSchema.typeDefs;
Expand Down Expand Up @@ -233,6 +237,7 @@ describe("makeAugmentedSchema", () => {
},
},
schemaModel,
complexityEstimatorHelper: new ComplexityEstimatorHelper(false),
});

const document = neoSchema.typeDefs;
Expand Down Expand Up @@ -266,7 +271,7 @@ describe("makeAugmentedSchema", () => {
`;

const schemaModel = generateModel(mergeTypeDefs(typeDefs));
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel });
const neoSchema = makeAugmentedSchema({ document: typeDefs, schemaModel, complexityEstimatorHelper: new ComplexityEstimatorHelper(false) });

const document = neoSchema.typeDefs;

Expand All @@ -286,7 +291,7 @@ describe("makeAugmentedSchema", () => {
`;

const schemaModel = generateModel(mergeTypeDefs(typeDefs));
expect(() => makeAugmentedSchema({ document: typeDefs, schemaModel })).not.toThrow(
expect(() => makeAugmentedSchema({ document: typeDefs, schemaModel, complexityEstimatorHelper: new ComplexityEstimatorHelper(false) })).not.toThrow(
'Error: Type with name "ActionMapping" does not exists'
);
});
Expand Down
Loading

0 comments on commit 3d93a18

Please sign in to comment.