diff --git a/__tests__/rpcProvider.test.ts b/__tests__/rpcProvider.test.ts index 6346fc61c..1ac69d391 100644 --- a/__tests__/rpcProvider.test.ts +++ b/__tests__/rpcProvider.test.ts @@ -339,4 +339,25 @@ describeIfRpc('RPCProvider', () => { expect(syncingStats).toMatchSchemaRef('GetSyncingStatsResponse'); }); }); + + describeIfRpc('Fallback node', () => { + beforeAll(() => {}); + test('Ensure fallback node is used when base node fails', async () => { + const provider: RpcProvider = new RpcProvider({ + nodeUrl: 'http://[1080:0:0:0:8:800:200C:417A]', + fallbackNodeUrls: [process.env.TEST_RPC_URL!], + }); + const blockNumber = await provider.getBlockNumber(); + expect(typeof blockNumber).toBe('number'); + }); + }); + + test('Ensure fallback nodes are run until any of them succeeds', async () => { + const provider: RpcProvider = new RpcProvider({ + nodeUrl: 'Incorrect URL', + fallbackNodeUrls: ['Another incorrect URL', process.env.TEST_RPC_URL!], + }); + const blockNumber = await provider.getBlockNumber(); + expect(typeof blockNumber).toBe('number'); + }); }); diff --git a/src/provider/rpc.ts b/src/provider/rpc.ts index c59b1cc58..526101a4c 100644 --- a/src/provider/rpc.ts +++ b/src/provider/rpc.ts @@ -27,6 +27,7 @@ import { getSimulateTransactionOptions, waitForTransactionOptions, } from '../types'; +import assert from '../utils/assert'; import { CallData } from '../utils/calldata'; import { getAbiContractVersion } from '../utils/calldata/cairo'; import { isSierra } from '../utils/contract'; @@ -68,8 +69,6 @@ const defaultOptions = { }; export class RpcProvider implements ProviderInterface { - public nodeUrl: string; - public headers: object; private responseParser = new RPCResponseParser(); @@ -80,37 +79,49 @@ export class RpcProvider implements ProviderInterface { private chainId?: StarknetChainId; + public nodeUrls: string[]; + constructor(optionsOrProvider?: RpcProviderOptions) { - const { nodeUrl, retries, headers, blockIdentifier, chainId, rpcVersion } = + const { nodeUrl, retries, headers, blockIdentifier, chainId, rpcVersion, fallbackNodeUrls } = optionsOrProvider || {}; + let primaryNode; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { // Network name provided for nodeUrl - this.nodeUrl = getDefaultNodeUrl( + primaryNode = getDefaultNodeUrl( nodeUrl as NetworkName, optionsOrProvider?.default, rpcVersion ); } else if (nodeUrl) { // NodeUrl provided - this.nodeUrl = nodeUrl; + primaryNode = nodeUrl; } else { // none provided fallback to default testnet - this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default, rpcVersion); + primaryNode = getDefaultNodeUrl(undefined, optionsOrProvider?.default, rpcVersion); } this.retries = retries || defaultOptions.retries; this.headers = { ...defaultOptions.headers, ...headers }; this.blockIdentifier = blockIdentifier || defaultOptions.blockIdentifier; this.chainId = chainId; // setting to a non-null value skips making a request in getChainId() + this.nodeUrls = [primaryNode, ...(fallbackNodeUrls || [])]; + } + + get nodeUrl() { + return this.nodeUrls[0]; + } + + set nodeUrl(url) { + this.nodeUrls[0] = url; } - public fetch(method: string, params?: object, id: string | number = 0) { + public fetch(url: string, method: string, params?: object, id: string | number = 0) { const rpcRequestBody: RPC.JRPC.RequestBody = { id, jsonrpc: '2.0', method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return fetch(url, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, @@ -132,13 +143,52 @@ export class RpcProvider implements ProviderInterface { } } + protected async setPrimaryNode(node: string, index: number) { + // eslint-disable-next-line prefer-destructuring + this.nodeUrls[index] = this.nodeUrls[0]; + this.nodeUrls[0] = node; + } + + protected async fetchResponse(method: string, params?: object) { + const nodes = [...this.nodeUrls]; + const lastNode = nodes.pop(); + assert(lastNode !== undefined); + let response; + for (let i = 0; i < nodes.length - 1; i += 1) { + try { + // eslint-disable-next-line no-await-in-loop + response = await this.fetch(nodes[i], method, params); + + if (response.ok) { + this.setPrimaryNode(nodes[i], i); + return response; + } + } catch (error: any) { + /* empty */ + } + } + + // If all nodes fail return anything the last one returned + try { + response = await this.fetch(lastNode, method, params); + if (response.ok) { + this.setPrimaryNode(lastNode, this.nodeUrls.length - 1); + } + return response; + } catch (error: any) { + this.errorHandler(method, params, error?.response?.data, error); + throw error; + } + } + protected async fetchEndpoint( method: T, params?: RPC.Methods[T]['params'] ): Promise { + const response = await this.fetchResponse(method, params); + try { - const rawResult = await this.fetch(method, params); - const { error, result } = await rawResult.json(); + const { error, result } = await response.json(); this.errorHandler(method, params, error); return result as RPC.Methods[T]['result']; } catch (error: any) { diff --git a/src/types/provider/configuration.ts b/src/types/provider/configuration.ts index e7640aa26..06e03cd93 100644 --- a/src/types/provider/configuration.ts +++ b/src/types/provider/configuration.ts @@ -13,6 +13,7 @@ export type RpcProviderOptions = { blockIdentifier?: BlockIdentifier; chainId?: StarknetChainId; default?: boolean; + fallbackNodeUrls?: string[]; rpcVersion?: 'v0_5' | 'v0_6'; };