Skip to content

Commit

Permalink
SHARD-1185: Rate limiting improvements (#98)
Browse files Browse the repository at this point in the history
* Add inactivity timeout and check interval for WebSocket connections

* Consolidate WebSocket inactivity checks into one interval and a map, enhancing performance and cleanup

* Refactor rate limit middleware and utility functions
  • Loading branch information
S0naliThakur authored Dec 30, 2024
1 parent c2420d4 commit 38d9c89
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 57 deletions.
4 changes: 4 additions & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type Config = {
maxConnections: number // Maximum number of concurrent WebSocket connections
maxSubscriptionsPerSocket: number // Maximum number of subscriptions per socket
connectionTimeoutMs: number // Connection timeout in milliseconds (default 1 day)
inactivityTimeoutMs: number // 60 seconds inactivity timeout
inactivityCheckIntervalMs: number // Check every 10 seconds
}
trustProxy: boolean // Whether to trust the X-Forwarded-For header
log_server: {
Expand Down Expand Up @@ -122,6 +124,8 @@ export const CONFIG: Config = {
maxConnections: Number(process.env.WS_MAX_CONNECTIONS) || 1000,
maxSubscriptionsPerSocket: Number(process.env.WS_MAX_SUBSCRIPTIONS_PER_SOCKET) || 50,
connectionTimeoutMs: Number(process.env.WS_CONNECTION_TIMEOUT_MS) || 24 * 60 * 60 * 1000, // 1 day in ms
inactivityTimeoutMs: 60000, // 60 seconds inactivity timeout
inactivityCheckIntervalMs: 10000, // Check every 10 seconds
},
trustProxy: false,
log_server: {
Expand Down
8 changes: 3 additions & 5 deletions src/middlewares/rateLimit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ export async function rateLimitMiddleware(req: Request, res: Response, next: Nex
const requests: RpcRequest[] = Array.isArray(req.body) ? req.body : [req.body]

try {
const results = await Promise.all(
requests.map(request => checkRequest(ip, request))
)
const results = await Promise.all(requests.map((request) => checkRequest(ip, request)))

// If any request is not okay, reject the entire batch
if (results.some(result => !result)) {
if (results.some((result) => !result)) {
await handleRejection(res, config.rateLimitOption.softReject)
return
}
Expand All @@ -56,4 +54,4 @@ export async function rateLimitMiddleware(req: Request, res: Response, next: Nex
console.error('Rate limiting error:', error)
res.status(500).send('Internal server error')
}
}
}
26 changes: 13 additions & 13 deletions src/tests/rateLimit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { rateLimitMiddleware } from '../middlewares/rateLimit'

describe('Rate Limiting', () => {
let requestersList: RequestersList

Check warning on line 6 in src/tests/rateLimit.test.ts

View workflow job for this annotation

GitHub Actions / ci / QA merge checks

'requestersList' is assigned a value but never used

beforeEach(() => {
requestersList = new RequestersList([], [])
})
Expand All @@ -14,13 +14,13 @@ describe('Rate Limiting', () => {
socket: { remoteAddress: '127.0.0.1' },
body: {
method: 'eth_call',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -35,18 +35,18 @@ describe('Rate Limiting', () => {
body: [
{
method: 'eth_call',
params: []
params: [],
},
{
method: 'eth_getBalance',
params: []
}
]
params: [],
},
],
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -60,13 +60,13 @@ describe('Rate Limiting', () => {
socket: { remoteAddress: '127.0.0.1' },
body: {
method: 'eth_sendRawTransaction',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -79,4 +79,4 @@ describe('Rate Limiting', () => {
expect(mockRes.status).toHaveBeenCalledWith(503)
expect(mockRes.send).toHaveBeenCalledWith('Rejected by rate-limiting')
})
})
})
12 changes: 6 additions & 6 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
OriginalTxData,
AccountTypesData,
Account2,
InternalFilter
InternalFilter,
} from './types'
import Sntp from '@hapi/sntp'
import { randomBytes, createHash } from 'crypto'
Expand Down Expand Up @@ -1873,20 +1873,20 @@ export function sanitizeIpAndPort(ipPort: string): { isValid: boolean; error?: s
}

export function removeOldestFilter(filtersMap: Map<string, InternalFilter>): void {
let oldestKey: string | undefined;
let oldestTimestamp = Infinity;
let oldestKey: string | undefined
let oldestTimestamp = Infinity

// Iterate through the map to find the oldest entry
for (const [key, value] of filtersMap) {
if (value.filter.lastQueriedTimestamp < oldestTimestamp) {
oldestTimestamp = value.filter.lastQueriedTimestamp;
oldestKey = key;
oldestTimestamp = value.filter.lastQueriedTimestamp
oldestKey = key
}
}

// Remove the oldest entry
if (oldestKey !== undefined) {
filtersMap.delete(oldestKey);
filtersMap.delete(oldestKey)
}
}

Expand Down
26 changes: 23 additions & 3 deletions src/websocket/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import WebSocket from 'ws'
import EventEmitter from 'events'
import { methods, wrappedMethods } from '../api'
import { wrappedMethods } from '../api'
import { logSubscriptionList } from './clients'
import * as crypto from 'crypto'
import { CONFIG } from '../config'
Expand All @@ -24,6 +24,19 @@ interface Request {
// Add connection counter
let activeConnections = 0

const socketActivityMap = new Map<WebSocket.WebSocket, number>()

// Single interval for all connections
setInterval(() => {
const now = Date.now()
socketActivityMap.forEach((lastActivity, socket) => {
if (now - lastActivity > CONFIG.websocket.inactivityTimeoutMs) {
socket.close(1011, 'Connection inactive for too long')
socketActivityMap.delete(socket)
}
})
}, CONFIG.websocket.inactivityCheckIntervalMs)

export const onConnection = async (socket: WebSocket.WebSocket): Promise<void> => {
// Check max connections limit
if (activeConnections >= CONFIG.websocket.maxConnections) {
Expand All @@ -32,6 +45,9 @@ export const onConnection = async (socket: WebSocket.WebSocket): Promise<void> =
}
activeConnections++

// Track last activity time
socketActivityMap.set(socket, Date.now())

// Set connection timeout
const timeoutId = setTimeout(() => {
socket.close(1011, 'Connection timeout reached')
Expand All @@ -40,6 +56,9 @@ export const onConnection = async (socket: WebSocket.WebSocket): Promise<void> =
const eth_methods = Object.freeze(wrappedMethods)

socket.on('message', (message: string) => {
// Update last activity time on message received
socketActivityMap.set(socket, Date.now())

if (CONFIG.verbose) console.log(`Received message: ${message}`)
nestedCountersInstance.countEvent('websocket', 'message-received')
let request: Request = {
Expand Down Expand Up @@ -190,7 +209,8 @@ export const onConnection = async (socket: WebSocket.WebSocket): Promise<void> =
})

socket.on('close', (code, reason) => {
// Clear timeout on close
// Clean up
socketActivityMap.delete(socket)
clearTimeout(timeoutId)

// Decrement connection counter
Expand All @@ -205,7 +225,7 @@ export const onConnection = async (socket: WebSocket.WebSocket): Promise<void> =
logSubscriptionList.removeBySocket(socket)
socket.close(code, reason)
}
if(CONFIG.verbose) console.log(logSubscriptionList.getAll())
if (CONFIG.verbose) console.log(logSubscriptionList.getAll())
})
}

Expand Down
61 changes: 31 additions & 30 deletions test/unit/middlewares/rateLimit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ jest.mock('../../../src/external/Collector', () => ({
Collector: jest.fn().mockImplementation(() => ({
fetchAccount: jest.fn(),
getBlock: jest.fn(),
getTransactionByHash: jest.fn()
getTransactionByHash: jest.fn(),
})),
collectorAPI: {
fetchAccount: jest.fn(),
getBlock: jest.fn(),
getTransactionByHash: jest.fn()
}
getTransactionByHash: jest.fn(),
},
}))

// Mock the RequestersList instance
jest.mock('../../../src/utils', () => {
return {
RequestersList: jest.fn().mockImplementation(() => ({
isRequestOkay: jest.fn()
isRequestOkay: jest.fn(),
})),
sleep: jest.fn().mockImplementation(() => Promise.resolve())
sleep: jest.fn().mockImplementation(() => Promise.resolve()),
}
})

Expand All @@ -32,9 +32,9 @@ jest.mock('../../../src/config', () => ({
rateLimit: true,
rateLimitOption: {
softReject: false,
allowedTxCountInCheckInterval: 60
}
}
allowedTxCountInCheckInterval: 60,
},
},
}))

describe('Rate Limiting', () => {
Expand All @@ -53,13 +53,13 @@ describe('Rate Limiting', () => {
ip: '127.0.0.1',
body: {
method: 'eth_call',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -77,13 +77,13 @@ describe('Rate Limiting', () => {
ip: '127.0.0.1',
body: {
method: 'eth_call',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -94,7 +94,8 @@ describe('Rate Limiting', () => {
})

it('should allow valid batch requests', async () => {
const spy = jest.spyOn(requestersList, 'isRequestOkay')
const spy = jest
.spyOn(requestersList, 'isRequestOkay')
.mockResolvedValueOnce(true)
.mockResolvedValueOnce(true)

Expand All @@ -103,18 +104,18 @@ describe('Rate Limiting', () => {
body: [
{
method: 'eth_call',
params: []
params: [],
},
{
method: 'eth_getBalance',
params: []
}
]
params: [],
},
],
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -131,13 +132,13 @@ describe('Rate Limiting', () => {
ip: '127.0.0.1',
body: {
method: 'eth_sendRawTransaction',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -159,13 +160,13 @@ describe('Rate Limiting', () => {
ip: '127.0.0.1',
body: {
method: 'eth_sendRawTransaction',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -186,13 +187,13 @@ describe('Rate Limiting', () => {
ip: '127.0.0.1',
body: {
method: 'eth_call',
params: []
}
params: [],
},
} as Request

const mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn()
send: jest.fn(),
} as unknown as Response

const mockNext = jest.fn()
Expand All @@ -202,4 +203,4 @@ describe('Rate Limiting', () => {
expect(mockRes.send).toHaveBeenCalledWith('Internal server error')
expect(mockNext).not.toHaveBeenCalled()
})
})
})

0 comments on commit 38d9c89

Please sign in to comment.