Skip to content

Commit

Permalink
Merge pull request from GHSA-7v5v-9h63-cj86
Browse files Browse the repository at this point in the history
grpc-js: Avoid buffering significantly more than max_receive_message_size per received message (1.9.x)
  • Loading branch information
murgatroid99 authored Jun 10, 2024
2 parents 02d0344 + c75e048 commit 08b0422
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 146 deletions.
2 changes: 1 addition & 1 deletion packages/grpc-js/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@grpc/grpc-js",
"version": "1.9.14",
"version": "1.9.15",
"description": "gRPC Library for Node - pure JS implementation",
"homepage": "https://grpc.io/",
"repository": "https://github.com/grpc/grpc-node/tree/master/packages/grpc-js",
Expand Down
67 changes: 51 additions & 16 deletions packages/grpc-js/src/compression-filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { WriteObject, WriteFlags } from './call-interface';
import { Channel } from './channel';
import { ChannelOptions } from './channel-options';
import { CompressionAlgorithms } from './compression-algorithms';
import { LogVerbosity } from './constants';
import { DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH, LogVerbosity, Status } from './constants';
import { BaseFilter, Filter, FilterFactory } from './filter';
import * as logging from './logging';
import { Metadata, MetadataValue } from './metadata';
Expand Down Expand Up @@ -98,6 +98,10 @@ class IdentityHandler extends CompressionHandler {
}

class DeflateHandler extends CompressionHandler {
constructor(private maxRecvMessageLength: number) {
super();
}

compressMessage(message: Buffer) {
return new Promise<Buffer>((resolve, reject) => {
zlib.deflate(message, (err, output) => {
Expand All @@ -112,18 +116,34 @@ class DeflateHandler extends CompressionHandler {

decompressMessage(message: Buffer) {
return new Promise<Buffer>((resolve, reject) => {
zlib.inflate(message, (err, output) => {
if (err) {
reject(err);
} else {
resolve(output);
let totalLength = 0;
const messageParts: Buffer[] = [];
const decompresser = zlib.createInflate();
decompresser.on('data', (chunk: Buffer) => {
messageParts.push(chunk);
totalLength += chunk.byteLength;
if (this.maxRecvMessageLength !== -1 && totalLength > this.maxRecvMessageLength) {
decompresser.destroy();
reject({
code: Status.RESOURCE_EXHAUSTED,
details: `Received message that decompresses to a size larger than ${this.maxRecvMessageLength}`
});
}
});
decompresser.on('end', () => {
resolve(Buffer.concat(messageParts));
});
decompresser.write(message);
decompresser.end();
});
}
}

class GzipHandler extends CompressionHandler {
constructor(private maxRecvMessageLength: number) {
super();
}

compressMessage(message: Buffer) {
return new Promise<Buffer>((resolve, reject) => {
zlib.gzip(message, (err, output) => {
Expand All @@ -138,13 +158,25 @@ class GzipHandler extends CompressionHandler {

decompressMessage(message: Buffer) {
return new Promise<Buffer>((resolve, reject) => {
zlib.unzip(message, (err, output) => {
if (err) {
reject(err);
} else {
resolve(output);
let totalLength = 0;
const messageParts: Buffer[] = [];
const decompresser = zlib.createGunzip();
decompresser.on('data', (chunk: Buffer) => {
messageParts.push(chunk);
totalLength += chunk.byteLength;
if (this.maxRecvMessageLength !== -1 && totalLength > this.maxRecvMessageLength) {
decompresser.destroy();
reject({
code: Status.RESOURCE_EXHAUSTED,
details: `Received message that decompresses to a size larger than ${this.maxRecvMessageLength}`
});
}
});
decompresser.on('end', () => {
resolve(Buffer.concat(messageParts));
});
decompresser.write(message);
decompresser.end();
});
}
}
Expand All @@ -169,14 +201,14 @@ class UnknownHandler extends CompressionHandler {
}
}

function getCompressionHandler(compressionName: string): CompressionHandler {
function getCompressionHandler(compressionName: string, maxReceiveMessageSize: number): CompressionHandler {
switch (compressionName) {
case 'identity':
return new IdentityHandler();
case 'deflate':
return new DeflateHandler();
return new DeflateHandler(maxReceiveMessageSize);
case 'gzip':
return new GzipHandler();
return new GzipHandler(maxReceiveMessageSize);
default:
return new UnknownHandler(compressionName);
}
Expand All @@ -186,6 +218,7 @@ export class CompressionFilter extends BaseFilter implements Filter {
private sendCompression: CompressionHandler = new IdentityHandler();
private receiveCompression: CompressionHandler = new IdentityHandler();
private currentCompressionAlgorithm: CompressionAlgorithm = 'identity';
private maxReceiveMessageLength: number;

constructor(
channelOptions: ChannelOptions,
Expand All @@ -195,6 +228,7 @@ export class CompressionFilter extends BaseFilter implements Filter {

const compressionAlgorithmKey =
channelOptions['grpc.default_compression_algorithm'];
this.maxReceiveMessageLength = channelOptions['grpc.max_receive_message_length'] ?? DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH
if (compressionAlgorithmKey !== undefined) {
if (isCompressionAlgorithmKey(compressionAlgorithmKey)) {
const clientSelectedEncoding = CompressionAlgorithms[
Expand All @@ -215,7 +249,8 @@ export class CompressionFilter extends BaseFilter implements Filter {
) {
this.currentCompressionAlgorithm = clientSelectedEncoding;
this.sendCompression = getCompressionHandler(
this.currentCompressionAlgorithm
this.currentCompressionAlgorithm,
-1
);
}
} else {
Expand Down Expand Up @@ -247,7 +282,7 @@ export class CompressionFilter extends BaseFilter implements Filter {
if (receiveEncoding.length > 0) {
const encoding: MetadataValue = receiveEncoding[0];
if (typeof encoding === 'string') {
this.receiveCompression = getCompressionHandler(encoding);
this.receiveCompression = getCompressionHandler(encoding, this.maxReceiveMessageLength);
}
}
metadata.remove('grpc-encoding');
Expand Down
2 changes: 0 additions & 2 deletions packages/grpc-js/src/internal-channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import {
} from './resolver';
import { trace } from './logging';
import { SubchannelAddress } from './subchannel-address';
import { MaxMessageSizeFilterFactory } from './max-message-size-filter';
import { mapProxyName } from './http_proxy';
import { GrpcUri, parseUri, uriToString } from './uri-parser';
import { ServerSurfaceCall } from './server-call';
Expand Down Expand Up @@ -393,7 +392,6 @@ export class InternalChannel {
}
);
this.filterStackFactory = new FilterStackFactory([
new MaxMessageSizeFilterFactory(this.options),
new CompressionFilterFactory(this, this.options),
]);
this.trace(
Expand Down
88 changes: 0 additions & 88 deletions packages/grpc-js/src/max-message-size-filter.ts

This file was deleted.

87 changes: 56 additions & 31 deletions packages/grpc-js/src/server-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { EventEmitter } from 'events';
import * as http2 from 'http2';
import { Duplex, Readable, Writable } from 'stream';
import * as zlib from 'zlib';
import { promisify } from 'util';

import {
Status,
Expand All @@ -38,8 +37,6 @@ import { Deadline } from './deadline';
import { getErrorCode, getErrorMessage } from './error';

const TRACER_NAME = 'server_call';
const unzip = promisify(zlib.unzip);
const inflate = promisify(zlib.inflate);

function trace(text: string): void {
logging.trace(LogVerbosity.DEBUG, TRACER_NAME, text);
Expand Down Expand Up @@ -480,19 +477,42 @@ export class Http2ServerCallStream<
private getDecompressedMessage(
message: Buffer,
encoding: string
): Buffer | Promise<Buffer> {
if (encoding === 'deflate') {
return inflate(message.subarray(5));
} else if (encoding === 'gzip') {
return unzip(message.subarray(5));
} else if (encoding === 'identity') {
return message.subarray(5);
): Buffer | Promise<Buffer> { const messageContents = message.subarray(5);
if (encoding === 'identity') {
return messageContents;
} else if (encoding === 'deflate' || encoding === 'gzip') {
let decompresser: zlib.Gunzip | zlib.Deflate;
if (encoding === 'deflate') {
decompresser = zlib.createInflate();
} else {
decompresser = zlib.createGunzip();
}
return new Promise((resolve, reject) => {
let totalLength = 0
const messageParts: Buffer[] = [];
decompresser.on('data', (chunk: Buffer) => {
messageParts.push(chunk);
totalLength += chunk.byteLength;
if (this.maxReceiveMessageSize !== -1 && totalLength > this.maxReceiveMessageSize) {
decompresser.destroy();
reject({
code: Status.RESOURCE_EXHAUSTED,
details: `Received message that decompresses to a size larger than ${this.maxReceiveMessageSize}`
});
}
});
decompresser.on('end', () => {
resolve(Buffer.concat(messageParts));
});
decompresser.write(messageContents);
decompresser.end();
});
} else {
return Promise.reject({
code: Status.UNIMPLEMENTED,
details: `Received message compressed with unsupported encoding "${encoding}"`,
});
}

return Promise.reject({
code: Status.UNIMPLEMENTED,
details: `Received message compressed with unsupported encoding "${encoding}"`,
});
}

sendMetadata(customMetadata?: Metadata) {
Expand Down Expand Up @@ -816,7 +836,7 @@ export class Http2ServerCallStream<
| ServerDuplexStream<RequestType, ResponseType>,
encoding: string
) {
const decoder = new StreamDecoder();
const decoder = new StreamDecoder(this.maxReceiveMessageSize);

let readsDone = false;

Expand All @@ -832,29 +852,34 @@ export class Http2ServerCallStream<
};

this.stream.on('data', async (data: Buffer) => {
const messages = decoder.write(data);
let messages: Buffer[];
try {
messages = decoder.write(data);
} catch (e) {
this.sendError({
code: Status.RESOURCE_EXHAUSTED,
details: (e as Error).message
});
return;
}

pendingMessageProcessing = true;
this.stream.pause();
for (const message of messages) {
if (
this.maxReceiveMessageSize !== -1 &&
message.length > this.maxReceiveMessageSize
) {
this.sendError({
code: Status.RESOURCE_EXHAUSTED,
details: `Received message larger than max (${message.length} vs. ${this.maxReceiveMessageSize})`,
});
return;
}
this.emit('receiveMessage');

const compressed = message.readUInt8(0) === 1;
const compressedMessageEncoding = compressed ? encoding : 'identity';
const decompressedMessage = await this.getDecompressedMessage(
message,
compressedMessageEncoding
);
let decompressedMessage: Buffer;
try {
decompressedMessage = await this.getDecompressedMessage(
message,
compressedMessageEncoding
);
} catch (e) {
this.sendError(e as Partial<StatusObject>);
return;
}

// Encountered an error with decompression; it'll already have been propogated back
// Just return early
Expand Down
Loading

0 comments on commit 08b0422

Please sign in to comment.