diff --git a/packages/rpc-subscriptions-spec/README.md b/packages/rpc-subscriptions-spec/README.md index 2250bbd82d7..2a65e25156b 100644 --- a/packages/rpc-subscriptions-spec/README.md +++ b/packages/rpc-subscriptions-spec/README.md @@ -55,3 +55,14 @@ Subscription channels publish events on two channel names: - `error`: Fires when the channel closes unexpectedly - `message`: Fires on every message received from the remote end + +## Functions + +### `executeRpcPubSubSubscriptionPlan({ channel, responseTransformer, signal, subscribeMethodName, subscribeParams, unsubscribeMethodName })` + +Given a channel, this function executes the particular subscription plan required by the Solana JSON RPC Subscriptions API. + +1. Calls the `subscribeMethodName` on the remote RPC +2. Waits for a response containing the subscription id +3. Returns a `DataPublisher` that publishes notifications related to that subscriptions id, filtering out all others +4. Calls the `unsubscribeMethodName` on the remote RPC when the abort signal is fired. diff --git a/packages/rpc-subscriptions-spec/package.json b/packages/rpc-subscriptions-spec/package.json index 98eac9cf322..47e3da4e899 100644 --- a/packages/rpc-subscriptions-spec/package.json +++ b/packages/rpc-subscriptions-spec/package.json @@ -72,6 +72,7 @@ ], "dependencies": { "@solana/errors": "workspace:*", + "@solana/promises": "workspace:*", "@solana/rpc-spec-types": "workspace:*", "@solana/subscribable": "workspace:*" }, diff --git a/packages/rpc-subscriptions-spec/src/__tests__/rpc-subscriptions-pubsub-plan-test.ts b/packages/rpc-subscriptions-spec/src/__tests__/rpc-subscriptions-pubsub-plan-test.ts new file mode 100644 index 00000000000..995a7c34c0a --- /dev/null +++ b/packages/rpc-subscriptions-spec/src/__tests__/rpc-subscriptions-pubsub-plan-test.ts @@ -0,0 +1,316 @@ +import { + SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED, + SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID, + SolanaError, +} from '@solana/errors'; +import { DataPublisher } from '@solana/subscribable'; + +import { RpcSubscriptionChannelEvents, RpcSubscriptionsChannel } from '../rpc-subscriptions-channel'; +import { executeRpcPubSubSubscriptionPlan } from '../rpc-subscriptions-pubsub-plan'; + +let mockId = 0; +let lastMessageId: number; +jest.mock('@solana/rpc-spec-types', () => ({ + ...jest.requireActual('@solana/rpc-spec-types'), + createRpcMessage(...args: never[]) { + lastMessageId = mockId++; + return { + ...jest.requireActual('@solana/rpc-spec-types').createRpcMessage(...args), + id: lastMessageId, + }; + }, +})); + +describe('executeRpcPubSubSubscriptionPlan', () => { + let abortController: AbortController; + let mockChannel: { on: jest.Mock; send: unknown }; + let mockSend: jest.Mock; + function receiveError(err?: unknown) { + mockChannel.on.mock.calls.filter(([type]) => type === 'error').forEach(([_, listener]) => listener(err)); + } + function receiveMessage(message: unknown) { + mockChannel.on.mock.calls.filter(([type]) => type === 'message').forEach(([_, listener]) => listener(message)); + } + beforeEach(() => { + abortController = new AbortController(); + mockSend = jest.fn().mockResolvedValue(void 0); + mockChannel = { + on: jest.fn().mockReturnValue(() => {}), + send: mockSend, + }; + }); + it('rejects when already aborted', async () => { + expect.assertions(1); + const abortController = new AbortController(); + abortController.abort(); + const publisherPromise = executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + await expect(publisherPromise).rejects.toThrow(); + }); + it('subscribes to the channel for errors', () => { + executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + expect(mockChannel.on).toHaveBeenCalledWith('error', expect.any(Function), { + signal: abortController.signal, + }); + }); + it('sends the expected subscribe message', () => { + const expectedParams = [1, 2, 3]; + executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: expectedParams, + unsubscribeMethodName: 'thingUnsubscribe', + }); + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ + id: expect.any(Number), + jsonrpc: '2.0', + method: 'thingSubscribe', + params: expectedParams, + }), + ); + }); + describe('given that the subscribe message fails to send', () => { + beforeEach(() => { + mockSend.mockRejectedValue('o no'); + }); + it("rejects with the send method's rejection", async () => { + expect.assertions(1); + const publisherPromise = executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + await expect(publisherPromise).rejects.toBe('o no'); + }); + it('does not send an unsubscribe message when aborted', () => { + expect.assertions(1); + mockSend.mockClear(); + abortController.abort(); + expect(mockSend).not.toHaveBeenCalled(); + }); + }); + describe('given that the server has not yet acknowledged the subscription', () => { + let publisherPromise: ReturnType; + beforeEach(() => { + publisherPromise = executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + }); + afterEach(() => { + publisherPromise.catch(() => {}); + }); + it('rejects when aborted', async () => { + expect.assertions(1); + abortController.abort(); + await expect(publisherPromise).rejects.toThrow(); + }); + it('does not send an unsubscribe message when aborted', () => { + expect.assertions(1); + mockSend.mockClear(); + abortController.abort(); + expect(mockSend).not.toHaveBeenCalled(); + }); + }); + it("throws when the server's subscription acknowledgement does not contain a subscription id number", async () => { + expect.assertions(1); + const publisherPromise = executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + await Promise.resolve(); + receiveMessage({ id: lastMessageId, jsonrpc: '2.0', result: undefined }); + await expect(publisherPromise).rejects.toThrow( + new SolanaError(SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID), + ); + }); + describe('given that the server has already acknowledged the subscription', () => { + let expectedSubscriptionId: number; + let publisherPromise: Promise< + DataPublisher, 'message'> & { notification: unknown }> + >; + let mockResponseTransformer: jest.Mock; + beforeEach(async () => { + jest.useFakeTimers(); + mockResponseTransformer = jest.fn().mockImplementation(result => result); + publisherPromise = executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + responseTransformer: mockResponseTransformer, + signal: abortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + await jest.runAllTimersAsync(); + receiveMessage({ id: lastMessageId, jsonrpc: '2.0', result: (expectedSubscriptionId = 123) }); + }); + it('publishes errors', async () => { + expect.assertions(1); + const publisher = await publisherPromise; + const errorListener = jest.fn(); + publisher.on('error', errorListener); + receiveError('o no'); + expect(errorListener).toHaveBeenCalledWith('o no'); + }); + it('publishes notifications that match this subscription id', async () => { + expect.assertions(1); + const publisher = await publisherPromise; + const notificationListener = jest.fn(); + publisher.on('notification', notificationListener); + receiveMessage({ + jsonrpc: '2.0', + method: 'thingNotification', + params: { + result: 'hi', + subscription: expectedSubscriptionId, + }, + }); + expect(notificationListener).toHaveBeenCalledWith('hi'); + }); + it('throws when a caller tries to listen to an unsupported channel', async () => { + expect.assertions(1); + const publisher = await publisherPromise; + const badListener = jest.fn(); + expect(() => { + publisher.on( + // @ts-expect-error This test supplies a bad event name on purpose. + 'bad', + badListener, + ); + }).toThrow( + new SolanaError(SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED, { + channelName: 'bad', + supportedChannelNames: ['notification', 'error'], + }), + ); + }); + it('publishes notifications transformed by the response transformer that match this subscription id', async () => { + expect.assertions(1); + mockResponseTransformer.mockImplementation(result => `now hear this: ${result}`); + const publisher = await publisherPromise; + const notificationListener = jest.fn(); + publisher.on('notification', notificationListener); + receiveMessage({ + jsonrpc: '2.0', + method: 'thingNotification', + params: { + result: 'hi', + subscription: expectedSubscriptionId, + }, + }); + expect(notificationListener).toHaveBeenCalledWith('now hear this: hi'); + }); + it('calls the response transformer only once per notification, even when there are multiple subscribers', async () => { + expect.assertions(1); + const publisher = await publisherPromise; + const notificationListenerA = jest.fn(); + const notificationListenerB = jest.fn(); + publisher.on('notification', notificationListenerA); + publisher.on('notification', notificationListenerB); + receiveMessage({ + jsonrpc: '2.0', + method: 'thingNotification', + params: { + result: 'hi', + subscription: expectedSubscriptionId, + }, + }); + expect(mockResponseTransformer).toHaveBeenCalledTimes(1); + }); + it("does not publish notifications that don't match this subscription id", async () => { + expect.assertions(1); + const publisher = await publisherPromise; + const notificationListener = jest.fn(); + publisher.on('notification', notificationListener); + receiveMessage({ + jsonrpc: '2.0', + method: 'thingNotification', + params: { + result: 'hi', + subscription: expectedSubscriptionId + 1, + }, + }); + expect(notificationListener).not.toHaveBeenCalled(); + }); + it('sends an unsubscribe message when aborted', () => { + expect.assertions(1); + mockSend.mockClear(); + abortController.abort(); + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ + id: expect.any(Number), + jsonrpc: '2.0', + method: 'thingUnsubscribe', + params: [expectedSubscriptionId], + }), + ); + }); + describe('but then later errors', () => { + beforeEach(() => { + receiveError('o no'); + }); + it('does not send an unsubscribe message when aborted', () => { + expect.assertions(1); + mockSend.mockClear(); + abortController.abort(); + expect(mockSend).not.toHaveBeenCalled(); + }); + }); + describe('and then acknowledges a subsequent subscription with the same subscription id', () => { + let secondAbortController: AbortController; + beforeEach(async () => { + jest.useFakeTimers(); + secondAbortController = new AbortController(); + executeRpcPubSubSubscriptionPlan({ + channel: mockChannel as RpcSubscriptionsChannel, + signal: secondAbortController.signal, + subscribeMethodName: 'thingSubscribe', + subscribeParams: [], + unsubscribeMethodName: 'thingUnsubscribe', + }); + await jest.runAllTimersAsync(); + receiveMessage({ id: lastMessageId, jsonrpc: '2.0', result: (expectedSubscriptionId = 123) }); + }); + /** + * Because of https://github.com/solana-labs/solana/pull/18943, two subscriptions for + * materially the same notification will be coalesced on the server. This means they + * will be assigned the same subscription id, and will occupy one subscription slot. We + * must be careful not to send the unsubscribe message until the last subscriber aborts. + */ + it('does not send the unsubscribe message when fewer than all of the subscriptions are aborted', () => { + mockSend.mockClear(); + abortController.abort(); + expect(mockSend).not.toHaveBeenCalled(); + }); + it('sends the unsubscribe message once all of the subscriptions abort', () => { + mockSend.mockClear(); + abortController.abort(); + secondAbortController.abort(); + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ method: 'thingUnsubscribe', params: [expectedSubscriptionId] }), + ); + }); + }); + }); +}); diff --git a/packages/rpc-subscriptions-spec/src/index.ts b/packages/rpc-subscriptions-spec/src/index.ts index 33520b89898..ca4cce0ffaf 100644 --- a/packages/rpc-subscriptions-spec/src/index.ts +++ b/packages/rpc-subscriptions-spec/src/index.ts @@ -2,4 +2,5 @@ export * from './rpc-subscriptions-request'; export * from './rpc-subscriptions'; export * from './rpc-subscriptions-api'; export * from './rpc-subscriptions-channel'; +export * from './rpc-subscriptions-pubsub-plan'; export * from './rpc-subscriptions-transport'; diff --git a/packages/rpc-subscriptions-spec/src/rpc-subscriptions-pubsub-plan.ts b/packages/rpc-subscriptions-spec/src/rpc-subscriptions-pubsub-plan.ts new file mode 100644 index 00000000000..10dbe3b1c01 --- /dev/null +++ b/packages/rpc-subscriptions-spec/src/rpc-subscriptions-pubsub-plan.ts @@ -0,0 +1,217 @@ +import { + getSolanaErrorFromJsonRpcError, + SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED, + SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID, + SolanaError, +} from '@solana/errors'; +import { safeRace } from '@solana/promises'; +import { createRpcMessage, RpcResponseData } from '@solana/rpc-spec-types'; +import { DataPublisher } from '@solana/subscribable'; + +import { RpcSubscriptionChannelEvents } from './rpc-subscriptions-channel'; +import { RpcSubscriptionsChannel } from './rpc-subscriptions-channel'; +import { demultiplexDataPublisher } from './rpc-subscriptions-pubsub-demultiplex'; + +type Config = Readonly<{ + channel: RpcSubscriptionsChannel | RpcResponseData>; + responseTransformer?: (response: unknown, notificationName: string) => T; + signal: AbortSignal; + subscribeMethodName: string; + subscribeParams?: unknown[]; + unsubscribeMethodName: string; +}>; + +type RpcNotification = Readonly<{ + method: string; + params: Readonly<{ + result: TNotification; + subscription: number; + }>; +}>; + +type RpcSubscriptionId = number; + +type RpcSubscriptionNotificationEvents = Omit, 'message'> & { + notification: TNotification; +}; + +const subscriberCountBySubscriptionIdByChannel = new WeakMap(); +function decrementSubscriberCountAndReturnNewCount(channel: WeakKey, subscriptionId?: number): number | undefined { + return augmentSubscriberCountAndReturnNewCount(-1, channel, subscriptionId); +} +function incrementSubscriberCount(channel: WeakKey, subscriptionId?: number): void { + augmentSubscriberCountAndReturnNewCount(1, channel, subscriptionId); +} +function augmentSubscriberCountAndReturnNewCount( + amount: -1 | 1, + channel: WeakKey, + subscriptionId?: number, +): number | undefined { + if (subscriptionId === undefined) { + return; + } + let subscriberCountBySubscriptionId = subscriberCountBySubscriptionIdByChannel.get(channel); + if (!subscriberCountBySubscriptionId && amount > 0) { + subscriberCountBySubscriptionIdByChannel.set( + channel, + (subscriberCountBySubscriptionId = { [subscriptionId]: 0 }), + ); + } + if (subscriberCountBySubscriptionId[subscriptionId] !== undefined) { + return (subscriberCountBySubscriptionId[subscriptionId] = + amount + subscriberCountBySubscriptionId[subscriptionId]); + } +} + +const cache = new WeakMap(); +function getMemoizedDemultiplexedNotificationPublisherFromChannelAndResponseTransformer( + channel: RpcSubscriptionsChannel>, + responseTransformer?: (response: unknown, notificationName: string) => T, +): DataPublisher<{ + [channelName: `notification:${number}`]: TNotification; +}> { + let publisherByResponseTransformer = cache.get(channel); + if (!publisherByResponseTransformer) { + cache.set(channel, (publisherByResponseTransformer = new WeakMap())); + } + const responseTransformerKey = responseTransformer ?? channel; + let publisher = publisherByResponseTransformer.get(responseTransformerKey); + if (!publisher) { + publisherByResponseTransformer.set( + responseTransformerKey, + (publisher = demultiplexDataPublisher(channel, 'message', rawMessage => { + const message = rawMessage as RpcNotification | RpcResponseData; + if (!('method' in message)) { + return; + } + const transformedNotification = responseTransformer + ? responseTransformer(message.params.result, message.method) + : message.params.result; + return [`notification:${message.params.subscription}`, transformedNotification]; + })), + ); + } + return publisher; +} + +export async function executeRpcPubSubSubscriptionPlan({ + channel, + responseTransformer, + signal, + subscribeMethodName, + subscribeParams, + unsubscribeMethodName, +}: Config): Promise>> { + let subscriptionId: number | undefined; + channel.on( + 'error', + () => { + // An error on the channel indicates that the subscriptions are dead. + // There is no longer any sense hanging on to subscription ids. + // Erasing it here will prevent the unsubscribe code from running. + subscriptionId = undefined; + subscriberCountBySubscriptionIdByChannel.delete(channel); + }, + { signal }, + ); + /** + * STEP 1 + * Create a promise that rejects if this subscription is aborted and sends + * the unsubscribe message if the subscription is active at that time. + */ + const abortPromise = new Promise((_, reject) => { + function handleAbort(this: AbortSignal) { + /** + * Because of https://github.com/solana-labs/solana/pull/18943, two subscriptions for + * materially the same notification will be coalesced on the server. This means they + * will be assigned the same subscription id, and will occupy one subscription slot. We + * must be careful not to send the unsubscribe message until the last subscriber aborts. + */ + if (decrementSubscriberCountAndReturnNewCount(channel, subscriptionId) === 0) { + const unsubscribePayload = createRpcMessage(unsubscribeMethodName, [subscriptionId]); + subscriptionId = undefined; + channel.send(unsubscribePayload).catch(() => {}); + } + reject(this.reason); + } + if (signal.aborted) { + handleAbort.call(signal); + } else { + signal.addEventListener('abort', handleAbort); + } + }); + /** + * STEP 2 + * Send the subscription request. + */ + const subscribePayload = createRpcMessage(subscribeMethodName, subscribeParams); + await channel.send(subscribePayload); + /** + * STEP 3 + * Wait for the acknowledgement from the server with the subscription id. + */ + const subscriptionIdPromise = new Promise((resolve, reject) => { + const abortController = new AbortController(); + signal.addEventListener('abort', abortController.abort.bind(abortController)); + const options = { signal: abortController.signal } as const; + channel.on( + 'error', + err => { + abortController.abort(); + reject(err); + }, + options, + ); + channel.on( + 'message', + message => { + if (message && typeof message === 'object' && 'id' in message && message.id === subscribePayload.id) { + abortController.abort(); + if ('error' in message) { + reject(getSolanaErrorFromJsonRpcError(message.error)); + } else { + resolve(message.result); + } + } + }, + options, + ); + }); + subscriptionId = await safeRace([abortPromise, subscriptionIdPromise]); + if (subscriptionId == null) { + throw new SolanaError(SOLANA_ERROR__RPC_SUBSCRIPTIONS__EXPECTED_SERVER_SUBSCRIPTION_ID); + } + incrementSubscriberCount(channel, subscriptionId); + /** + * STEP 4 + * Filter out notifications unrelated to this subscription. + */ + const notificationPublisher = getMemoizedDemultiplexedNotificationPublisherFromChannelAndResponseTransformer( + channel, + responseTransformer, + ); + const notificationKey = `notification:${subscriptionId}` as const; + return { + on(type, listener, options) { + switch (type) { + case 'notification': + return notificationPublisher.on( + notificationKey, + listener as (data: RpcSubscriptionNotificationEvents['notification']) => void, + options, + ); + case 'error': + return channel.on( + 'error', + listener as (data: RpcSubscriptionNotificationEvents['error']) => void, + options, + ); + default: + throw new SolanaError(SOLANA_ERROR__INVARIANT_VIOLATION__DATA_PUBLISHER_CHANNEL_UNIMPLEMENTED, { + channelName: type, + supportedChannelNames: ['notification', 'error'], + }); + } + }, + }; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2e70bd50654..5c4251034b9 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -915,6 +915,9 @@ importers: '@solana/errors': specifier: workspace:* version: link:../errors + '@solana/promises': + specifier: workspace:* + version: link:../promises '@solana/rpc-spec-types': specifier: workspace:* version: link:../rpc-spec-types