From 2064876a7d7fe6c33f29a50ff772d0a889da3576 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Wed, 11 Sep 2024 20:05:06 -0700 Subject: [PATCH] Fix crash in socket.upgradeTLS (#13884) --- packages/bun-usockets/src/crypto/openssl.c | 6 + packages/bun-usockets/src/libusockets.h | 2 +- src/boringssl.zig | 17 ++ src/bun.js/api/BunObject.zig | 14 +- src/bun.js/api/bun/socket.zig | 126 +++++++++++--- src/bun.js/bindings/ErrorCode.ts | 1 + src/deps/uws.zig | 6 +- test/js/bun/net/socket.test.ts | 193 ++++++++++++++++----- 8 files changed, 284 insertions(+), 81 deletions(-) diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index 06f656cea08fa..232d5f8ff9c16 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -1986,6 +1986,12 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( struct us_socket_context_t *context = us_create_bun_socket_context( 1, old_context->loop, sizeof(struct us_wrapped_socket_context_t), options); + + // Handle SSL context creation failure + if (UNLIKELY(!context)) { + return NULL; + } + struct us_internal_ssl_socket_context_t *tls_context = (struct us_internal_ssl_socket_context_t *)context; diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index b6d9733f3e07d..ae6d05d3c22b8 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -247,7 +247,7 @@ void *us_socket_context_get_native_handle(int ssl, us_socket_context_r context); struct us_socket_context_t *us_create_socket_context(int ssl, us_loop_r loop, int ext_size, struct us_socket_context_options_t options) nonnull_fn_decl; struct us_socket_context_t *us_create_bun_socket_context(int ssl, struct us_loop_t *loop, - int ext_size, struct us_bun_socket_context_options_t options) nonnull_fn_decl; + int ext_size, struct us_bun_socket_context_options_t options); /* Delete resources allocated at creation time (will call unref now and only free when ref count == 0). */ void us_socket_context_free(int ssl, us_socket_context_r context) nonnull_fn_decl; diff --git a/src/boringssl.zig b/src/boringssl.zig index 3f3877ef36348..3f8892a1aadc0 100644 --- a/src/boringssl.zig +++ b/src/boringssl.zig @@ -206,3 +206,20 @@ pub fn checkServerIdentity( } return false; } + +const JSC = bun.JSC; +pub fn ERR_toJS(globalThis: *JSC.JSGlobalObject, err_code: u32) JSC.JSValue { + var outbuf: [128 + 1 + "BoringSSL ".len]u8 = undefined; + @memset(&outbuf, 0); + outbuf[0.."BoringSSL ".len].* = "BoringSSL ".*; + const message_buf = outbuf["BoringSSL ".len..]; + + _ = boring.ERR_error_string_n(err_code, message_buf, message_buf.len); + + const error_message: []const u8 = bun.sliceTo(outbuf[0..], 0); + if (error_message.len == "BoringSSL ".len) { + return globalThis.ERR_BORINGSSL("An unknown BoringSSL error occurred: {d}", .{err_code}).toJS(); + } + + return globalThis.ERR_BORINGSSL("{s}", .{error_message}).toJS(); +} diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig index 0e8b93aeead0e..41b8affc1d822 100644 --- a/src/bun.js/api/BunObject.zig +++ b/src/bun.js/api/BunObject.zig @@ -1800,19 +1800,7 @@ pub const Crypto = struct { }; pub fn createCryptoError(globalThis: *JSC.JSGlobalObject, err_code: u32) JSValue { - var outbuf: [128 + 1 + "BoringSSL error: ".len]u8 = undefined; - @memset(&outbuf, 0); - outbuf[0.."BoringSSL error: ".len].* = "BoringSSL error: ".*; - const message_buf = outbuf["BoringSSL error: ".len..]; - - _ = BoringSSL.ERR_error_string_n(err_code, message_buf, message_buf.len); - - const error_message: []const u8 = bun.sliceTo(outbuf[0..], 0); - if (error_message.len == "BoringSSL error: ".len) { - return ZigString.static("Unknown BoringSSL error").toErrorInstance(globalThis); - } - - return ZigString.fromUTF8(error_message).toErrorInstance(globalThis); + return BoringSSL.ERR_toJS(globalThis, err_code); } const unknown_password_algorithm_message = "unknown algorithm, expected one of: \"bcrypt\", \"argon2id\", \"argon2d\", \"argon2i\" (default is \"argon2id\")"; diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index e1486a9335d3b..c613b70b7696c 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -2888,6 +2888,8 @@ fn NewSocket(comptime ssl: bool) type { callframe: *JSC.CallFrame, ) JSValue { JSC.markBinding(@src()); + const this_js = callframe.this(); + if (comptime ssl) { return JSValue.jsUndefined(); } @@ -2904,6 +2906,7 @@ fn NewSocket(comptime ssl: bool) type { } var exception: JSC.C.JSValueRef = null; + var success = false; const opts = args.ptr[0]; if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) { @@ -2915,13 +2918,34 @@ fn NewSocket(comptime ssl: bool) type { globalObject.throw("Expected \"socket\" option", .{}); return .zero; }; + if (globalObject.hasException()) { + return .zero; + } + + const handlers = Handlers.fromJS(globalObject, socket_obj, &exception) orelse { + if (!globalObject.hasException() and exception != null) { + globalObject.throwValue(exception.?.value()); + } - var handlers = Handlers.fromJS(globalObject, socket_obj, &exception) orelse { - globalObject.throwValue(exception.?.value()); return .zero; }; + if (!globalObject.hasException() and exception != null) { + globalObject.throwValue(exception.?.value()); + } + + if (globalObject.hasException()) { + return .zero; + } + var ssl_opts: ?JSC.API.ServerConfig.SSLConfig = null; + defer { + if (!success) { + if (ssl_opts) |*ssl_config| { + ssl_config.deinit(); + } + } + } if (opts.getTruthy(globalObject, "tls")) |tls| { if (tls.isBoolean()) { @@ -2931,12 +2955,20 @@ fn NewSocket(comptime ssl: bool) type { } else { if (JSC.API.ServerConfig.SSLConfig.inJS(JSC.VirtualMachine.get(), globalObject, tls, &exception)) |ssl_config| { ssl_opts = ssl_config; - } else if (exception != null) { - return .zero; } } } + if (exception != null) { + if (!globalObject.hasException()) { + globalObject.throwValue(exception.?.value()); + } + } + + if (globalObject.hasException()) { + return .zero; + } + if (ssl_opts == null) { globalObject.throw("Expected \"tls\" option", .{}); return .zero; @@ -2947,8 +2979,12 @@ fn NewSocket(comptime ssl: bool) type { default_data = default_data_value; default_data.ensureStillAlive(); } + if (globalObject.hasException()) { + return .zero; + } var socket_config = ssl_opts.?; + ssl_opts = null; defer socket_config.deinit(); const options = socket_config.asUSockets(); @@ -2959,7 +2995,7 @@ fn NewSocket(comptime ssl: bool) type { const is_server = this.handlers.is_server; - var handlers_ptr = handlers.vm.allocator.create(Handlers) catch bun.outOfMemory(); + var handlers_ptr = bun.default_allocator.create(Handlers) catch bun.outOfMemory(); handlers_ptr.* = handlers; handlers_ptr.is_server = is_server; handlers_ptr.protect(); @@ -2972,11 +3008,11 @@ fn NewSocket(comptime ssl: bool) type { .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch bun.outOfMemory()) else null, .server_name = if (socket_config.server_name) |server_name| (bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch bun.outOfMemory()) else null, .socket_context = null, // only set after the wrapTLS + .flags = .{ + .is_active = false, + }, }); - const tls_js_value = tls.getThisValue(globalObject); - TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); - const TCPHandler = NewWrappedHandler(false); // reconfigure context to use the new wrapper handlers @@ -2989,19 +3025,62 @@ fn NewSocket(comptime ssl: bool) type { WrappedSocket, TLSHandler, ) orelse { + const err = BoringSSL.ERR_get_error(); + defer { + if (err != 0) { + BoringSSL.ERR_clear_error(); + } + } + tls.wrapped = .none; + + // Reset config to TCP + uws.NewSocketHandler(false).configure( + this.socket.context().?, + true, + *TCPSocket, + struct { + pub const onOpen = NewSocket(false).onOpen; + pub const onClose = NewSocket(false).onClose; + pub const onData = NewSocket(false).onData; + pub const onWritable = NewSocket(false).onWritable; + pub const onTimeout = NewSocket(false).onTimeout; + pub const onConnectError = NewSocket(false).onConnectError; + pub const onEnd = NewSocket(false).onEnd; + pub const onHandshake = NewSocket(false).onHandshake; + }, + ); + + tls.deref(); + handlers_ptr.unprotect(); - handlers.vm.allocator.destroy(handlers_ptr); - bun.default_allocator.destroy(tls); + bun.default_allocator.destroy(handlers_ptr); + + // If BoringSSL gave us an error code, let's use it. + if (err != 0 and !globalObject.hasException()) { + globalObject.throwValue(BoringSSL.ERR_toJS(globalObject, err)); + } + + // If BoringSSL did not give us an error code, let's throw a generic error. + if (!globalObject.hasException()) { + globalObject.throw("Failed to upgrade socket from TCP -> TLS. Is the TLS config correct?", .{}); + } + return JSValue.jsUndefined(); }; + // Do not create the JS Wrapper object until _after_ we've validated the TLS config. + // Otherwise, JSC will GC it and the lifetime gets very complicated. + const tls_js_value = tls.getThisValue(globalObject); + TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); + tls.socket = new_socket; tls.socket_context = new_socket.context(); // owns the new tls context that have a ref from the old one tls.ref(); + const vm = handlers.vm; - var raw_handlers_ptr = handlers.vm.allocator.create(Handlers) catch bun.outOfMemory(); + var raw_handlers_ptr = bun.default_allocator.create(Handlers) catch bun.outOfMemory(); raw_handlers_ptr.* = .{ - .vm = globalObject.bunVM(), + .vm = vm, .globalObject = globalObject, .onOpen = this.handlers.onOpen, .onClose = this.handlers.onClose, @@ -3030,10 +3109,11 @@ fn NewSocket(comptime ssl: bool) type { raw.ref(); const raw_js_value = raw.getThisValue(globalObject); - if (JSSocketType(ssl).dataGetCached(this.getThisValue(globalObject))) |raw_default_data| { + if (JSSocketType(ssl).dataGetCached(this_js)) |raw_default_data| { raw_default_data.ensureStillAlive(); TLSSocket.dataSetCached(raw_js_value, globalObject, raw_default_data); } + // marks both as active raw.markActive(); // this will keep tls alive until socket.open() is called to start TLS certificate and the handshake process @@ -3048,25 +3128,29 @@ fn NewSocket(comptime ssl: bool) type { ctx.* = .{ .tcp = raw, .tls = tls }; } - // start TLS handshake after we set ext - new_socket.startTLS(!this.handlers.is_server); - - //detach and invalidate the old instance - this.socket.detach(); - this.deref(); if (this.flags.is_active) { - const vm = this.handlers.vm; + this.poll_ref.disable(); this.flags.is_active = false; // will free handlers when hits 0 active connections // the connection can be upgraded inside a handler call so we need to garantee that it will be still alive this.handlers.markInactive(); - this.poll_ref.unref(vm); + this.has_pending_activity.store(false, .release); } const array = JSC.JSValue.createEmptyArray(globalObject, 2); array.putIndex(globalObject, 0, raw_js_value); array.putIndex(globalObject, 1, tls_js_value); + + defer this.deref(); + + // detach and invalidate the old instance + this.socket.detach(); + + // start TLS handshake after we set extension on the socket + new_socket.startTLS(!is_server); + + success = true; return array; } }; diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index 3a42a53c0e7c0..c8519af883efb 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -47,4 +47,5 @@ export default [ ["ERR_FORMDATA_PARSE_ERROR", TypeError, "TypeError"], ["ERR_BODY_ALREADY_USED", Error, "Error"], ["ERR_STREAM_WRAP", Error, "Error"], + ["ERR_BORINGSSL", Error, "Error"], ] as ErrorCodeMapping; diff --git a/src/deps/uws.zig b/src/deps/uws.zig index c661645e8d761..00cd6924e2757 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -1938,11 +1938,11 @@ pub const us_bun_socket_context_options_t = extern struct { ca_file_name: [*c]const u8 = null, ssl_ciphers: [*c]const u8 = null, ssl_prefer_low_memory_usage: i32 = 0, - key: [*c][*c]const u8 = null, + key: ?[*]?[*:0]const u8 = null, key_count: u32 = 0, - cert: [*c][*c]const u8 = null, + cert: ?[*]?[*:0]const u8 = null, cert_count: u32 = 0, - ca: [*c][*c]const u8 = null, + ca: ?[*]?[*:0]const u8 = null, ca_count: u32 = 0, secure_options: u32 = 0, reject_unauthorized: i32 = 0, diff --git a/test/js/bun/net/socket.test.ts b/test/js/bun/net/socket.test.ts index abe30419a13e1..f3d114667d7cf 100644 --- a/test/js/bun/net/socket.test.ts +++ b/test/js/bun/net/socket.test.ts @@ -1,6 +1,7 @@ import type { Socket } from "bun"; import { connect, fileURLToPath, SocketHandler, spawn } from "bun"; -import { expect, it } from "bun:test"; +import { heapStats } from "bun:jsc"; +import { expect, it, jest } from "bun:test"; import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows, tls } from "harness"; it("should coerce '0' to 0", async () => { const listener = Bun.listen({ @@ -495,18 +496,17 @@ it("should not call drain before handshake", async () => { await promise; expect(socket.authorized).toBe(true); }); -it("should be able to upgrade to TLS", async () => { +it("upgradeTLS handles errors", async () => { using server = Bun.serve({ tls, async fetch(req) { return new Response("Hello World"); }, }); - const { promise: tlsSocketPromise, resolve, reject } = Promise.withResolvers(); - const { promise: rawSocketPromise, resolve: rawSocketResolve, reject: rawSocketReject } = Promise.withResolvers(); - { - let body = ""; - let rawBody = Buffer.alloc(0); + let body = ""; + let rawBody = Buffer.alloc(0); + + for (let i = 0; i < 100; i++) { const socket = await Bun.connect({ hostname: "localhost", port: server.port, @@ -514,52 +514,159 @@ it("should be able to upgrade to TLS", async () => { data(socket, data) { rawBody = Buffer.concat([rawBody, data]); }, - close() { - rawSocketResolve(rawBody); - }, - error(err) { - rawSocketReject(err); - }, + close() {}, + error(err) {}, }, }); - const result = socket.upgradeTLS({ + + const handlers = { data: Buffer.from("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"), - tls, socket: { - data(socket, data) { - body += data.toString("utf8"); - if (body.includes("\r\n\r\n")) { - socket.end(); - } + data: jest.fn(), + close: jest.fn(), + drain: jest.fn(), + error: jest.fn(), + open: jest.fn(), + }, + }; + expect(() => + socket.upgradeTLS({ + ...handlers, + tls: { + ca: "invalid certificate!", }, - close() { - resolve(body); + }), + ).toThrow( + expect.objectContaining({ + code: "ERR_BORINGSSL", + }), + ); + + expect(() => + socket.upgradeTLS({ + ...handlers, + tls: { + cert: "invalid certificate!", }, - drain(socket) { - while (socket.data.byteLength > 0) { - const written = socket.write(socket.data); - if (written === 0) { - break; - } - socket.data = socket.data.slice(written); - } - socket.flush(); + }), + ).toThrow( + expect.objectContaining({ + code: "ERR_BORINGSSL", + }), + ); + + expect(() => + socket.upgradeTLS({ + ...handlers, + tls: { + ...tls, + key: "invalid key!", }, - error(err) { - reject(err); + }), + ).toThrow( + expect.objectContaining({ + code: "ERR_BORINGSSL", + }), + ); + + expect(() => + socket.upgradeTLS({ + ...handlers, + tls: { + ...tls, + key: "invalid key!", + cert: "invalid cert!", }, - }, - }); + }), + ).toThrow( + expect.objectContaining({ + code: "ERR_BORINGSSL", + }), + ); + + expect(() => + socket.upgradeTLS({ + ...handlers, + tls: {}, + }), + ).toThrow(); + + expect(handlers.socket.close).not.toHaveBeenCalled(); + expect(handlers.socket.error).not.toHaveBeenCalled(); + expect(handlers.socket.data).not.toHaveBeenCalled(); + expect(handlers.socket.drain).not.toHaveBeenCalled(); + expect(handlers.socket.open).not.toHaveBeenCalled(); + socket.end(); + } + Bun.gc(true); +}); +it("should be able to upgrade to TLS", async () => { + using server = Bun.serve({ + tls, + async fetch(req) { + return new Response("Hello World"); + }, + }); + for (let i = 0; i < 50; i++) { + const { promise: tlsSocketPromise, resolve, reject } = Promise.withResolvers(); + const { promise: rawSocketPromise, resolve: rawSocketResolve, reject: rawSocketReject } = Promise.withResolvers(); + { + let body = ""; + let rawBody = Buffer.alloc(0); + const socket = await Bun.connect({ + hostname: "localhost", + port: server.port, + socket: { + data(socket, data) { + rawBody = Buffer.concat([rawBody, data]); + }, + close() { + rawSocketResolve(rawBody); + }, + error(err) { + rawSocketReject(err); + }, + }, + }); + const result = socket.upgradeTLS({ + data: Buffer.from("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"), + tls, + socket: { + data(socket, data) { + body += data.toString("utf8"); + if (body.includes("\r\n\r\n")) { + socket.end(); + } + }, + close() { + resolve(body); + }, + drain(socket) { + while (socket.data.byteLength > 0) { + const written = socket.write(socket.data); + if (written === 0) { + break; + } + socket.data = socket.data.slice(written); + } + socket.flush(); + }, + error(err) { + reject(err); + }, + }, + }); - const [raw, tls_socket] = result; - expect(raw).toBeDefined(); - expect(tls_socket).toBeDefined(); + const [raw, tls_socket] = result; + expect(raw).toBeDefined(); + expect(tls_socket).toBeDefined(); + } + const [tlsData, rawData] = await Promise.all([tlsSocketPromise, rawSocketPromise]); + expect(tlsData).toContain("HTTP/1.1 200 OK"); + expect(tlsData).toContain("Content-Length: 11"); + expect(tlsData).toContain("\r\nHello World"); + expect(rawData.byteLength).toBeGreaterThanOrEqual(1980); } - const [tlsData, rawData] = await Promise.all([tlsSocketPromise, rawSocketPromise]); - expect(tlsData).toContain("HTTP/1.1 200 OK"); - expect(tlsData).toContain("Content-Length: 11"); - expect(tlsData).toContain("\r\nHello World"); - expect(rawData.byteLength).toBeGreaterThanOrEqual(1980); }); it("should not leak memory", async () => {