diff --git a/.changeset/tiny-rivers-build.md b/.changeset/tiny-rivers-build.md new file mode 100644 index 000000000..07bb7fefe --- /dev/null +++ b/.changeset/tiny-rivers-build.md @@ -0,0 +1,8 @@ +--- +'@graphql-tools/federation': patch +--- + +Handle shared subscription root fields correctly + +In case of conflicting subscription root fields coming from different subgraphs or different entry points(multiple keys), +subscription was failing. diff --git a/packages/federation/src/supergraph.ts b/packages/federation/src/supergraph.ts index 421200d27..c393c59d1 100644 --- a/packages/federation/src/supergraph.ts +++ b/packages/federation/src/supergraph.ts @@ -41,6 +41,7 @@ import { EnumValueDefinitionNode, FieldDefinitionNode, FieldNode, + GraphQLFieldResolver, GraphQLInterfaceType, GraphQLOutputType, GraphQLSchema, @@ -1216,9 +1217,8 @@ export function getStitchingOptionsFromSupergraphSdl( } if (operationType) { const defaultMergedField = defaultMerger(candidates); - return { - ...defaultMergedField, - resolve(_root, _args, context, info) { + const mergedResolver: GraphQLFieldResolver<{}, {}> = + function mergedResolver(_root, _args, context, info) { const originalSelectionSet: SelectionSetNode = { kind: Kind.SELECTION_SET, selections: info.fieldNodes, @@ -1361,7 +1361,19 @@ export function getStitchingOptionsFromSupergraphSdl( return Promise.all(jobs).then((results) => mergeResults(results)); } return mergeResults(jobs); - }, + }; + if (operationType === 'subscription') { + return { + ...defaultMergedField, + subscribe: mergedResolver, + resolve: function identityFn(payload) { + return payload; + }, + }; + } + return { + ...defaultMergedField, + resolve: mergedResolver, }; } const filteredCandidates = candidates.filter((candidate) => { diff --git a/packages/federation/tests/getStitchedSchemaFromLocalSchemas.ts b/packages/federation/tests/getStitchedSchemaFromLocalSchemas.ts index b29895b31..63c8cb295 100644 --- a/packages/federation/tests/getStitchedSchemaFromLocalSchemas.ts +++ b/packages/federation/tests/getStitchedSchemaFromLocalSchemas.ts @@ -2,9 +2,11 @@ import { createDefaultExecutor } from '@graphql-tools/delegate'; import { ExecutionRequest, ExecutionResult, + getDocumentNodeFromSchema, mapMaybePromise, } from '@graphql-tools/utils'; import { composeLocalSchemasWithApollo } from '@internal/testing'; +import { composeServices } from '@theguild/federation-composition'; import { GraphQLSchema } from 'graphql'; import { kebabCase } from 'lodash'; import { getStitchedSchemaFromSupergraphSdl } from '../src/supergraph'; @@ -14,21 +16,49 @@ export interface LocalSchemaItem { schema: GraphQLSchema; } -export async function getStitchedSchemaFromLocalSchemas( - localSchemas: Record, +export async function getStitchedSchemaFromLocalSchemas({ + localSchemas, + onSubgraphExecute, + composeWith = 'apollo', + ignoreRules, +}: { + localSchemas: Record; onSubgraphExecute?: ( subgraph: string, executionRequest: ExecutionRequest, result: ExecutionResult | AsyncIterable, - ) => void, -): Promise { - const supergraphSdl = await composeLocalSchemasWithApollo( - Object.entries(localSchemas).map(([name, schema]) => ({ - name, - schema, - url: `http://localhost/${name}`, - })), - ); + ) => void; + composeWith?: 'apollo' | 'guild'; + ignoreRules?: string[]; +}): Promise { + let supergraphSdl: string; + if (composeWith === 'apollo') { + supergraphSdl = await composeLocalSchemasWithApollo( + Object.entries(localSchemas).map(([name, schema]) => ({ + name, + schema, + url: `http://localhost/${name}`, + })), + ); + } else if (composeWith === 'guild') { + const result = composeServices( + Object.entries(localSchemas).map(([name, schema]) => ({ + name, + typeDefs: getDocumentNodeFromSchema(schema), + url: `http://localhost/${name}`, + })), + { disableValidationRules: ignoreRules }, + ); + result.errors?.forEach((error) => { + console.error(error); + }); + if (!result.supergraphSdl) { + throw new Error('Failed to compose services'); + } + supergraphSdl = result.supergraphSdl; + } else { + throw new Error(`Unknown composeWith ${composeWith}`); + } function createTracedExecutor(name: string, schema: GraphQLSchema) { const executor = createDefaultExecutor(schema); return function tracedExecutor(request: ExecutionRequest) { diff --git a/packages/federation/tests/optimizations.test.ts b/packages/federation/tests/optimizations.test.ts index e6adde137..dabe8a281 100644 --- a/packages/federation/tests/optimizations.test.ts +++ b/packages/federation/tests/optimizations.test.ts @@ -630,15 +630,15 @@ it('nested recursive requirements', async () => { const subgraphCalls: Record = {}; - const schema = await getStitchedSchemaFromLocalSchemas( - { + const schema = await getStitchedSchemaFromLocalSchemas({ + localSchemas: { inventory, spatial, }, - (subgraph) => { + onSubgraphExecute: (subgraph) => { subgraphCalls[subgraph] = (subgraphCalls[subgraph] || 0) + 1; }, - ); + }); expect( await normalizedExecutor({ diff --git a/packages/federation/tests/shared-root.test.ts b/packages/federation/tests/shared-root.test.ts index bb81cb9c9..2bb0162d9 100644 --- a/packages/federation/tests/shared-root.test.ts +++ b/packages/federation/tests/shared-root.test.ts @@ -1,6 +1,7 @@ import { buildSubgraphSchema } from '@apollo/subgraph'; import { normalizedExecutor } from '@graphql-tools/executor'; import { ExecutionRequest } from '@graphql-tools/utils'; +import { assertAsyncIterable } from '@internal/testing'; import { ExecutionResult, parse } from 'graphql'; import { describe, expect, it, vi } from 'vitest'; import { getStitchedSchemaFromLocalSchemas } from './getStitchedSchemaFromLocalSchemas'; @@ -99,8 +100,10 @@ describe('Shared Root Fields', () => { }); const gatewaySchema = await getStitchedSchemaFromLocalSchemas({ - subgraph1, - subgraph2, + localSchemas: { + subgraph1, + subgraph2, + }, }); const result = await normalizedExecutor({ @@ -157,13 +160,13 @@ describe('Shared Root Fields', () => { result: ExecutionResult | AsyncIterable, ) => void >(); - const gatewaySchema = await getStitchedSchemaFromLocalSchemas( - { + const gatewaySchema = await getStitchedSchemaFromLocalSchemas({ + localSchemas: { SUBGRAPHA, SUBGRAPHB, }, - onSubgraphExecuteFn, - ); + onSubgraphExecute: onSubgraphExecuteFn, + }); const result = await normalizedExecutor({ schema: gatewaySchema, @@ -245,13 +248,13 @@ describe('Shared Root Fields', () => { result: ExecutionResult | AsyncIterable, ) => void >(); - const gatewaySchema = await getStitchedSchemaFromLocalSchemas( - { + const gatewaySchema = await getStitchedSchemaFromLocalSchemas({ + localSchemas: { SUBGRAPHA, SUBGRAPHB, }, - onSubgraphExecuteFn, - ); + onSubgraphExecute: onSubgraphExecuteFn, + }); const resultA = await normalizedExecutor({ schema: gatewaySchema, @@ -296,4 +299,239 @@ describe('Shared Root Fields', () => { expect(onSubgraphExecuteFn).toHaveBeenCalledTimes(2); expect(onSubgraphExecuteFn.mock.calls[1]?.[0]).toBe('SUBGRAPHB'); }); + it('should choose the best subscription root field in case of multiple entry points(keys)', async () => { + interface Review { + id: string; + url: string; + comment: string; + } + const reviews: Review[] = [ + { + id: 'r1', + url: 'http://r1', + comment: 'Tractor 👍', + }, + { + id: 'r2', + url: 'http://r2', + comment: 'Washing machine 👎', + }, + ]; + const REVIEWS = buildSubgraphSchema({ + typeDefs: parse(/* GraphQL */ ` + type Query { + allReviews: [Review!]! + } + type Subscription { + newReview: Review! + } + type Review @key(fields: "id") @key(fields: "url") { + id: ID! + url: String! + comment: String! + } + `), + resolvers: { + Query: { + allReviews: () => reviews, + }, + Subscription: { + newReview: { + async *subscribe() { + yield { newReview: reviews[reviews.length - 1] }; + }, + }, + }, + }, + }); + + const gatewaySchema = await getStitchedSchemaFromLocalSchemas({ + localSchemas: { + REVIEWS, + }, + }); + + const newReviewSub = await normalizedExecutor({ + schema: gatewaySchema, + document: parse(/* GraphQL */ ` + subscription { + newReview { + id + } + } + `), + }); + assertAsyncIterable(newReviewSub); + const iter = newReviewSub[Symbol.asyncIterator](); + + await expect(iter.next()).resolves.toMatchInlineSnapshot(` + { + "done": false, + "value": { + "data": { + "newReview": { + "id": "r2", + }, + }, + }, + } + `); + + await expect(iter.next()).resolves.toMatchInlineSnapshot(` + { + "done": true, + "value": undefined, + } + `); + }); + it('should choose the best subscription root field in case of conflicting fields', async () => { + interface Event { + id: string; + message: string; + time: number; + } + const allEvents: Event[] = [ + { + id: 'e1', + message: 'Event 1', + time: 1, + }, + { + id: 'e2', + message: 'Event 2', + time: 2, + }, + ]; + const EVENTSWITHMESSAGES = buildSubgraphSchema({ + typeDefs: parse(/* GraphQL */ ` + schema + @link( + url: "https://specs.apollo.dev/federation/v2.5" + import: ["@key", "@shareable"] + ) { + query: Query + subscription: Subscription + } + + type Query { + allEventsWithMessage: [Event!]! + } + type Subscription { + newEvent: Event! @shareable + } + type Event @key(fields: "id") { + id: ID! + message: String! + } + `), + resolvers: { + Query: { + allEventsWithMessage: () => allEvents, + }, + Subscription: { + newEvent: { + async *subscribe() { + for (const event of allEvents) { + yield { newEvent: event }; + } + }, + }, + }, + }, + }); + const EVENTSWITHTIME = buildSubgraphSchema({ + typeDefs: parse(/* GraphQL */ ` + schema + @link( + url: "https://specs.apollo.dev/federation/v2.5" + import: ["@key", "@shareable"] + ) { + query: Query + subscription: Subscription + } + + type Query { + allEventsWithTime: [Event!]! + } + type Subscription { + newEvent: Event! @shareable + } + type Event @key(fields: "id") { + id: ID! + time: Int! + } + `), + resolvers: { + Query: { + allEventsWithTime: () => allEvents, + }, + Subscription: { + newEvent: { + async *subscribe() { + for (const event of allEvents) { + yield { newEvent: event }; + } + }, + }, + }, + }, + }); + + let subgraphCalls: Record = {}; + const gatewaySchema = await getStitchedSchemaFromLocalSchemas({ + localSchemas: { + EVENTSWITHMESSAGES, + EVENTSWITHTIME, + }, + composeWith: 'guild', + ignoreRules: ['InvalidFieldSharingRule'], + onSubgraphExecute(subgraph) { + subgraphCalls[subgraph] = (subgraphCalls[subgraph] || 0) + 1; + }, + }); + + const eventsWithMessageSub = await normalizedExecutor({ + schema: gatewaySchema, + document: parse(/* GraphQL */ ` + subscription { + newEvent { + message + } + } + `), + }); + assertAsyncIterable(eventsWithMessageSub); + const collectedEventsWithMessage: ExecutionResult[] = []; + for await (const result of eventsWithMessageSub) { + collectedEventsWithMessage.push(result); + } + expect(collectedEventsWithMessage).toEqual( + allEvents.map(({ message }) => ({ data: { newEvent: { message } } })), + ); + expect(subgraphCalls).toEqual({ + EVENTSWITHMESSAGES: 1, + }); + subgraphCalls = {}; + const eventsWithTimeSub = await normalizedExecutor({ + schema: gatewaySchema, + document: parse(/* GraphQL */ ` + subscription { + newEvent { + time + } + } + `), + }); + assertAsyncIterable(eventsWithTimeSub); + const collectedEventsWithTime: ExecutionResult[] = []; + for await (const result of eventsWithTimeSub) { + collectedEventsWithTime.push(result); + } + expect(collectedEventsWithTime).toEqual( + allEvents.map(({ time }) => ({ data: { newEvent: { time } } })), + ); + expect(subgraphCalls).toEqual({ + EVENTSWITHTIME: 1, + }); + }); });