Skip to content

Commit

Permalink
Merge pull request #2855 from murgatroid99/grpc-js_credentials_secure…
Browse files Browse the repository at this point in the history
…_connector

grpc-js: Add security connector, rework connection establishment
  • Loading branch information
murgatroid99 authored Nov 23, 2024
2 parents f5133e4 + 1657324 commit 8f08bbe
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 345 deletions.
221 changes: 154 additions & 67 deletions packages/grpc-js/src/channel-credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ import {
createSecureContext,
PeerCertificate,
SecureContext,
checkServerIdentity,
connect as tlsConnect
} from 'tls';

import { CallCredentials } from './call-credentials';
import { CIPHER_SUITES, getDefaultRootsData } from './tls-helpers';
import { CaCertificateUpdate, CaCertificateUpdateListener, CertificateProvider, IdentityCertificateUpdate, IdentityCertificateUpdateListener } from './certificate-provider';
import { Socket } from 'net';
import { ChannelOptions } from './channel-options';
import { GrpcUri, parseUri, splitHostPort } from './uri-parser';
import { getDefaultAuthority } from './resolver';

// eslint-disable-next-line @typescript-eslint/no-explicit-any
function verifyIsBufferOrNull(obj: any, friendlyName: string): void {
Expand Down Expand Up @@ -57,6 +63,11 @@ export interface VerifyOptions {
rejectUnauthorized?: boolean;
}

export interface SecureConnector {
connect(socket: Socket): Promise<Socket>;
destroy(): void;
}

/**
* A class that contains credentials for communicating over a channel, as well
* as a set of per-call credentials, which are applied to every method call made
Expand All @@ -83,13 +94,6 @@ export abstract class ChannelCredentials {
return this.callCredentials;
}

/**
* Gets a SecureContext object generated from input parameters if this
* instance was created with createSsl, or null if this instance was created
* with createInsecure.
*/
abstract _getConnectionOptions(): ConnectionOptions | null;

/**
* Indicates whether this credentials object creates a secure channel.
*/
Expand All @@ -102,13 +106,7 @@ export abstract class ChannelCredentials {
*/
abstract _equals(other: ChannelCredentials): boolean;

_ref(): void {
// Do nothing by default
}

_unref(): void {
// Do nothing by default
}
abstract _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector;

/**
* Return a new ChannelCredentials instance with a given set of credentials.
Expand Down Expand Up @@ -180,51 +178,111 @@ class InsecureChannelCredentialsImpl extends ChannelCredentials {
compose(callCredentials: CallCredentials): never {
throw new Error('Cannot compose insecure credentials');
}

_getConnectionOptions(): ConnectionOptions | null {
return {};
}
_isSecure(): boolean {
return false;
}
_equals(other: ChannelCredentials): boolean {
return other instanceof InsecureChannelCredentialsImpl;
}
_createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector {
return {
connect(socket) {
return Promise.resolve(socket);
},
destroy() {}
}
}
}

class SecureChannelCredentialsImpl extends ChannelCredentials {
connectionOptions: ConnectionOptions;
function getConnectionOptions(secureContext: SecureContext, verifyOptions: VerifyOptions, channelTarget: GrpcUri, options: ChannelOptions): ConnectionOptions {
const connectionOptions: ConnectionOptions = {
secureContext: secureContext
};
if (verifyOptions.checkServerIdentity) {
connectionOptions.checkServerIdentity = verifyOptions.checkServerIdentity;
}
if (verifyOptions.rejectUnauthorized !== undefined) {
connectionOptions.rejectUnauthorized = verifyOptions.rejectUnauthorized;
}
connectionOptions.ALPNProtocols = ['h2'];
if (options['grpc.ssl_target_name_override']) {
const sslTargetNameOverride = options['grpc.ssl_target_name_override']!;
const originalCheckServerIdentity =
connectionOptions.checkServerIdentity ?? checkServerIdentity;
connectionOptions.checkServerIdentity = (
host: string,
cert: PeerCertificate
): Error | undefined => {
return originalCheckServerIdentity(sslTargetNameOverride, cert);
};
connectionOptions.servername = sslTargetNameOverride;
} else {
if ('grpc.http_connect_target' in options) {
/* This is more or less how servername will be set in createSession
* if a connection is successfully established through the proxy.
* If the proxy is not used, these connectionOptions are discarded
* anyway */
const targetPath = getDefaultAuthority(
parseUri(options['grpc.http_connect_target'] as string) ?? {
path: 'localhost',
}
);
const hostPort = splitHostPort(targetPath);
connectionOptions.servername = hostPort?.host ?? targetPath;
}
}
if (options['grpc-node.tls_enable_trace']) {
connectionOptions.enableTrace = true;
}

let realTarget: GrpcUri = channelTarget;
if ('grpc.http_connect_target' in options) {
const parsedTarget = parseUri(options['grpc.http_connect_target']!);
if (parsedTarget) {
realTarget = parsedTarget;
}
}
const targetPath = getDefaultAuthority(realTarget);
const hostPort = splitHostPort(targetPath);
const remoteHost = hostPort?.host ?? targetPath;
connectionOptions.host = remoteHost;
connectionOptions.servername = remoteHost;
return connectionOptions;
}

class SecureConnectorImpl implements SecureConnector {
constructor(private connectionOptions: ConnectionOptions) {
}
connect(socket: Socket): Promise<Socket> {
const tlsConnectOptions: ConnectionOptions = {
socket: socket,
...this.connectionOptions
};
return new Promise<Socket>((resolve, reject) => {
const tlsSocket = tlsConnect(tlsConnectOptions, () => {
resolve(tlsSocket)
});
tlsSocket.on('error', (error: Error) => {
reject(error);
});
});
}
destroy() {}
}

class SecureChannelCredentialsImpl extends ChannelCredentials {
constructor(
private secureContext: SecureContext,
private verifyOptions: VerifyOptions
) {
super();
this.connectionOptions = {
secureContext,
};
// Node asserts that this option is a function, so we cannot pass undefined
if (verifyOptions?.checkServerIdentity) {
this.connectionOptions.checkServerIdentity =
verifyOptions.checkServerIdentity;
}

if (verifyOptions?.rejectUnauthorized !== undefined) {
this.connectionOptions.rejectUnauthorized =
verifyOptions.rejectUnauthorized;
}
}

compose(callCredentials: CallCredentials): ChannelCredentials {
const combinedCallCredentials =
this.callCredentials.compose(callCredentials);
return new ComposedChannelCredentialsImpl(this, combinedCallCredentials);
}

_getConnectionOptions(): ConnectionOptions | null {
// Copy to prevent callers from mutating this.connectionOptions
return { ...this.connectionOptions };
}
_isSecure(): boolean {
return true;
}
Expand All @@ -242,6 +300,10 @@ class SecureChannelCredentialsImpl extends ChannelCredentials {
return false;
}
}
_createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector {
const connectionOptions = getConnectionOptions(this.secureContext, this.verifyOptions, channelTarget, options);
return new SecureConnectorImpl(connectionOptions);
}
}

class CertificateProviderChannelCredentialsImpl extends ChannelCredentials {
Expand All @@ -250,10 +312,38 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials {
private latestIdentityUpdate: IdentityCertificateUpdate | null = null;
private caCertificateUpdateListener: CaCertificateUpdateListener = this.handleCaCertificateUpdate.bind(this);
private identityCertificateUpdateListener: IdentityCertificateUpdateListener = this.handleIdentityCertitificateUpdate.bind(this);
private static SecureConnectorImpl = class implements SecureConnector {
constructor(private parent: CertificateProviderChannelCredentialsImpl, private channelTarget: GrpcUri, private options: ChannelOptions) {}

connect(socket: Socket): Promise<Socket> {
return new Promise((resolve, reject) => {
const secureContext = this.parent.getLatestSecureContext();
if (!secureContext) {
reject(new Error('Credentials not loaded'));
return;
}
const connnectionOptions = getConnectionOptions(secureContext, this.parent.verifyOptions, this.channelTarget, this.options);
const tlsConnectOptions: ConnectionOptions = {
socket: socket,
...connnectionOptions
}
const tlsSocket = tlsConnect(tlsConnectOptions, () => {
resolve(tlsSocket)
});
tlsSocket.on('error', (error: Error) => {
reject(error);
});
});
}

destroy() {
this.parent.unref();
}
}
constructor(
private caCertificateProvider: CertificateProvider,
private identityCertificateProvider: CertificateProvider | null,
private verifyOptions: VerifyOptions | null
private verifyOptions: VerifyOptions
) {
super();
}
Expand All @@ -265,27 +355,6 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials {
combinedCallCredentials
);
}
_getConnectionOptions(): ConnectionOptions | null {
if (this.latestCaUpdate === null) {
return null;
}
if (this.identityCertificateProvider !== null && this.latestIdentityUpdate === null) {
return null;
}
const secureContext: SecureContext = createSecureContext({
ca: this.latestCaUpdate.caCertificate,
key: this.latestIdentityUpdate?.privateKey,
cert: this.latestIdentityUpdate?.certificate,
ciphers: CIPHER_SUITES
});
const options: ConnectionOptions = {
secureContext: secureContext
};
if (this.verifyOptions?.checkServerIdentity) {
options.checkServerIdentity = this.verifyOptions.checkServerIdentity;
}
return options;
}
_isSecure(): boolean {
return true;
}
Expand All @@ -301,20 +370,24 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials {
return false;
}
}
_ref(): void {
private ref(): void {
if (this.refcount === 0) {
this.caCertificateProvider.addCaCertificateListener(this.caCertificateUpdateListener);
this.identityCertificateProvider?.addIdentityCertificateListener(this.identityCertificateUpdateListener);
}
this.refcount += 1;
}
_unref(): void {
private unref(): void {
this.refcount -= 1;
if (this.refcount === 0) {
this.caCertificateProvider.removeCaCertificateListener(this.caCertificateUpdateListener);
this.identityCertificateProvider?.removeIdentityCertificateListener(this.identityCertificateUpdateListener);
}
}
_createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector {
this.ref();
return new CertificateProviderChannelCredentialsImpl.SecureConnectorImpl(this, channelTarget, options);
}

private handleCaCertificateUpdate(update: CaCertificateUpdate | null) {
this.latestCaUpdate = update;
Expand All @@ -323,10 +396,25 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials {
private handleIdentityCertitificateUpdate(update: IdentityCertificateUpdate | null) {
this.latestIdentityUpdate = update;
}

private getLatestSecureContext(): SecureContext | null {
if (this.latestCaUpdate === null) {
return null;
}
if (this.identityCertificateProvider !== null && this.latestIdentityUpdate === null) {
return null;
}
return createSecureContext({
ca: this.latestCaUpdate.caCertificate,
key: this.latestIdentityUpdate?.privateKey,
cert: this.latestIdentityUpdate?.certificate,
ciphers: CIPHER_SUITES
});
}
}

export function createCertificateProviderChannelCredentials(caCertificateProvider: CertificateProvider, identityCertificateProvider: CertificateProvider | null, verifyOptions?: VerifyOptions) {
return new CertificateProviderChannelCredentialsImpl(caCertificateProvider, identityCertificateProvider, verifyOptions ?? null);
return new CertificateProviderChannelCredentialsImpl(caCertificateProvider, identityCertificateProvider, verifyOptions ?? {});
}

class ComposedChannelCredentialsImpl extends ChannelCredentials {
Expand All @@ -347,10 +435,6 @@ class ComposedChannelCredentialsImpl extends ChannelCredentials {
combinedCallCredentials
);
}

_getConnectionOptions(): ConnectionOptions | null {
return this.channelCredentials._getConnectionOptions();
}
_isSecure(): boolean {
return true;
}
Expand All @@ -367,4 +451,7 @@ class ComposedChannelCredentialsImpl extends ChannelCredentials {
return false;
}
}
_createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector {
return this.channelCredentials._createSecureConnector(channelTarget, options);
}
}
Loading

0 comments on commit 8f08bbe

Please sign in to comment.