From 5121c86b1a73563452ab9ec1082541a9066caf65 Mon Sep 17 00:00:00 2001 From: Joannis Orlandos Date: Thu, 6 Jun 2024 19:16:07 +0200 Subject: [PATCH] Fixes for Sendable and strict concurrency (#340) --- Package@swift-5.9.swift | 115 ++++++++++++++ Sources/Meow/Error.swift | 4 +- .../Authenticate+ChallengeResponse.swift | 6 +- Sources/MongoClient/Authenticate+SASL.swift | 2 +- Sources/MongoClient/Cluster.swift | 4 +- Sources/MongoClient/Connection.swift | 140 +++++++++++------- Sources/MongoClient/ConnectionPool.swift | 2 +- Sources/MongoClient/Cursor.swift | 43 ++++-- Sources/MongoClient/Error.swift | 4 +- .../InternalCommands/GetMore.swift | 4 +- Sources/MongoClient/SASL.swift | 4 +- Sources/MongoCore/Error.swift | 10 +- .../Primitives/ConnectionSettings.swift | 4 +- Sources/MongoKitten/Aggregate.swift | 2 +- Sources/MongoKitten/AggregateStage.swift | 2 +- Sources/MongoKitten/ChangeStream.swift | 6 +- .../CollectionHelpers/Collection+Find.swift | 11 +- Sources/MongoKitten/Cursor.swift | 12 +- .../MongoKitten/GridFS/GridFSFileWriter.swift | 2 +- Sources/MongoKitten/MongoCollection.swift | 18 ++- Sources/MongoKitten/MongoDatabase.swift | 97 +++++++++--- .../MongoTransactionDatabase.swift | 1 + .../Commands/CreateIndexes.swift | 4 +- .../MongoKittenCore/Commands/Distinct.swift | 4 +- Sources/MongoKittenCore/Commands/Find.swift | 2 +- Sources/MongoKittenCore/Commands/Insert.swift | 4 +- Sources/_MongoKittenCrypto/Hash.swift | 2 +- Tests/MongoKittenTests/CRUDTests.swift | 21 +++ 28 files changed, 392 insertions(+), 138 deletions(-) create mode 100644 Package@swift-5.9.swift diff --git a/Package@swift-5.9.swift b/Package@swift-5.9.swift new file mode 100644 index 00000000..eff7814b --- /dev/null +++ b/Package@swift-5.9.swift @@ -0,0 +1,115 @@ +// swift-tools-version:5.9 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "MongoKitten", + platforms: [ + .macOS(.v13), + .iOS(.v13) + ], + products: [ + // Products define the executables and libraries produced by a package, and make them visible to other packages. + .library( + name: "MongoKitten", + targets: ["MongoKitten"]), + .library( + name: "Meow", + targets: ["Meow"]), + .library( + name: "MongoClient", + targets: ["MongoClient"]), + .library( + name: "MongoCore", + targets: ["MongoCore"]), + ], + dependencies: [ + // ✏️ + .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), + + // 📈 + .package(url: "https://github.com/apple/swift-metrics.git", "1.0.0" ..< "3.0.0"), + + // ✅ + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"), + + // 💾 + .package(url: "https://github.com/orlandos-nl/BSON.git", from: "8.0.9"), + + // 🚀 + .package(url: "https://github.com/apple/swift-nio.git", from: "2.43.0"), + + // 📚 + .package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.2.1"), + + // 🔑 + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), + + // 🔍 + .package(url: "https://github.com/apple/swift-distributed-tracing.git", from: "1.0.0"), + ], + targets: [ + .target( + name: "_MongoKittenCrypto", + dependencies: [], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ]), + .target( + name: "MongoCore", + dependencies: [ + .product(name: "BSON", package: "BSON"), + .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "Logging", package: "swift-log"), + .product(name: "Metrics", package: "swift-metrics"), + .product(name: "Atomics", package: "swift-atomics"), + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ]), + .target( + name: "MongoKittenCore", + dependencies: ["MongoClient"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ]), + .target( + name: "MongoKitten", + dependencies: ["MongoClient", "MongoKittenCore"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ]), + .target( + name: "Meow", + dependencies: ["MongoKitten"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ]), + .target( + name: "MongoClient", + dependencies: [ + "MongoCore", + "_MongoKittenCrypto", + .product(name: "DNSClient", package: "DNSClient"), + .product(name: "Tracing", package: "swift-distributed-tracing") + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency=complete"), + ] + ), + .testTarget( + name: "MongoCoreTests", + dependencies: ["MongoCore"]), + .testTarget( + name: "MongoKittenTests", + dependencies: [ + "MongoKitten", + ]), + .testTarget( + name: "MeowTests", + dependencies: ["Meow"]), + ] +) diff --git a/Sources/Meow/Error.swift b/Sources/Meow/Error.swift index 1646e171..d9f174d4 100644 --- a/Sources/Meow/Error.swift +++ b/Sources/Meow/Error.swift @@ -1,7 +1,7 @@ import MongoKitten /// Generic errors thrown by the generator -public enum MeowModelError: Swift.Error { +public enum MeowModelError: Swift.Error, @unchecked Sendable { /// The value for the given key is missing, or invalid case missingOrInvalidValue(key: String, expected: Any.Type, got: Any?) @@ -29,7 +29,7 @@ public enum MeowModelError: Swift.Error { case brokenFileReference(ObjectId) } -enum MeowError: Swift.Error { +enum MeowError: Swift.Error, @unchecked Sendable { /// A reference to `type` with id `id` cannot be resolved case referenceError(id: Any, type: Any.Type) diff --git a/Sources/MongoClient/Authenticate+ChallengeResponse.swift b/Sources/MongoClient/Authenticate+ChallengeResponse.swift index a98dc6d0..2333e0a6 100644 --- a/Sources/MongoClient/Authenticate+ChallengeResponse.swift +++ b/Sources/MongoClient/Authenticate+ChallengeResponse.swift @@ -27,14 +27,14 @@ fileprivate struct AuthenticateCR: Encodable { extension MongoConnection { internal func authenticateCR(_ username: String, password: String, namespace: MongoNamespace) async throws { - try await InstrumentationSystem.tracer.withSpan("MongoKitten.AuthenticateCR", ofKind: .client) { span in + try await _withSpan("MongoKitten.AuthenticateCR", context: .current ?? .topLevel, ofKind: .client) { spanContext in let nonceReply = try await self.executeCodable( GetNonce(), decodeAs: GetNonceResult.self, namespace: namespace, sessionId: nil, traceLabel: "AuthenticateCR.Initiate", - serviceContext: span.context + serviceContext: spanContext ) let nonce = nonceReply.nonce @@ -53,7 +53,7 @@ extension MongoConnection { namespace: namespace, sessionId: nil, traceLabel: "AuthenticateCR.Finalize", - serviceContext: span.context + serviceContext: spanContext ) try authenticationReply.assertOK( diff --git a/Sources/MongoClient/Authenticate+SASL.swift b/Sources/MongoClient/Authenticate+SASL.swift index 09cb6295..e4c72fd7 100644 --- a/Sources/MongoClient/Authenticate+SASL.swift +++ b/Sources/MongoClient/Authenticate+SASL.swift @@ -136,7 +136,7 @@ extension MongoConnection { /// /// The Hasher `H` specifies the hashing algorithm used with SCRAM. func authenticateSASL(hasher: H, namespace: MongoNamespace, username: String, password: String) async throws { - try await InstrumentationSystem.tracer.withSpan("MongoKitten.AuthenticateSASL", ofKind: .client) { span in + try await _withSpan("MongoKitten.AuthenticateSASL", ofKind: .client) { @Sendable span in let context = SCRAM(hasher) let rawRequest = try context.authenticationString(forUser: username) diff --git a/Sources/MongoClient/Cluster.swift b/Sources/MongoClient/Cluster.swift index 289cc949..58c20d15 100644 --- a/Sources/MongoClient/Cluster.swift +++ b/Sources/MongoClient/Cluster.swift @@ -185,7 +185,7 @@ public final class MongoCluster: MongoConnectionPool, @unchecked Sendable { // Kick off the connection process try await resolveSettings() - await scheduleDiscovery() + scheduleDiscovery() self.completedInitialDiscovery = true } } @@ -617,7 +617,7 @@ public final class MongoCluster: MongoConnectionPool, @unchecked Sendable { _ = try await self.next(for: .writable) await rediscover() self.completedInitialDiscovery = true - await scheduleDiscovery() + scheduleDiscovery() } } diff --git a/Sources/MongoClient/Connection.swift b/Sources/MongoClient/Connection.swift index 1df4eb4c..f8ca8e4b 100644 --- a/Sources/MongoClient/Connection.swift +++ b/Sources/MongoClient/Connection.swift @@ -31,7 +31,7 @@ public struct MongoHandshakeResult { public var interval: Double { received.timeIntervalSince(sent) } - + init(sentAt sent: Date, handshake: ServerHandshake) { self.sent = sent self.received = Date() @@ -39,7 +39,7 @@ public struct MongoHandshakeResult { } } -/// A connection to a MongoDB server. +/// A connection to a MongoDB server. public final actor MongoConnection: Sendable { /// The NIO channel used for communication internal let channel: Channel @@ -54,7 +54,7 @@ public final actor MongoConnection: Sendable { public var queryTimeout: TimeAmount? = .seconds(30) internal var lastServerActivity: Date? - + /// Whether metrics are enabled. When enabled, metrics will be collected for queries using the `Metrics` library. public var isMetricsEnabled = false { didSet { @@ -65,7 +65,7 @@ public final actor MongoConnection: Sendable { } } } - + /// A LIFO (Last In, First Out) holder for sessions public let sessionManager: MongoSessionManager @@ -78,7 +78,7 @@ public final actor MongoConnection: Sendable { public nonisolated var implicitSessionId: SessionIdentifier { return implicitSession.sessionId } - + /// The current request ID, used to generate unique identifiers for MongoDB commands private var currentRequestId = ManagedAtomic(0) internal let context: MongoClientContext @@ -87,28 +87,28 @@ public final actor MongoConnection: Sendable { public var serverHandshake: ServerHandshake? { get async { await context.serverHandshake } } - + public nonisolated var closeFuture: EventLoopFuture { return channel.closeFuture } - + public nonisolated var eventLoop: EventLoop { return channel.eventLoop } public var allocator: ByteBufferAllocator { return channel.allocator } - + /// Whether this connection is a slaveOk connection, meaning it can read from secondaries public let slaveOk = ManagedAtomic(false) - + internal func nextRequestId() -> Int32 { return currentRequestId.loadThenWrappingIncrement(ordering: .relaxed) } - + /// Creates a connection that can communicate with MongoDB over a channel public init(channel: Channel, context: MongoClientContext, sessionManager: MongoSessionManager = .init()) { self.sessionManager = sessionManager self.channel = channel self.context = context } - + /// Registers MongoKitten's handlers on the channel public static func addHandlers(to channel: Channel, context: MongoClientContext) -> EventLoopFuture { let parser = ClientConnectionParser(context: context) @@ -123,11 +123,11 @@ public final actor MongoConnection: Sendable { sessionId: implicitSessionId ) } - + /// Connects to a MongoDB server using the given settings. - /// + /// /// let connection = try await MongoConnection.connect(to: ConnectionSettings("mongodb://localhost:27017")) - /// + /// /// - Parameters: /// - settings: The settings to use for connecting /// - logger: The logger to use for logging @@ -140,13 +140,13 @@ public final actor MongoConnection: Sendable { resolver: Resolver? = nil, clientDetails: MongoClientDetails? = nil ) async throws -> MongoConnection { - #if canImport(NIOTransportServices) && os(iOS) +#if canImport(NIOTransportServices) && os(iOS) return try await connect(settings: settings, logger: logger, onGroup: NIOTSEventLoopGroup(loopCount: 1), resolver: resolver, clientDetails: clientDetails) - #else +#else return try await connect(settings: settings, logger: logger, onGroup: MultiThreadedEventLoopGroup(numberOfThreads: 1), resolver: resolver, clientDetails: clientDetails) - #endif +#endif } - + internal static func connect( settings: ConnectionSettings, logger: Logger = Logger(label: "org.orlandos-nl.mongokitten.connection"), @@ -169,33 +169,33 @@ public final actor MongoConnection: Sendable { logger[metadataKey: "mongo-port"] = .string(String(host.port)) let context = MongoClientContext(logger: logger) - - #if canImport(NIOTransportServices) && os(iOS) + +#if canImport(NIOTransportServices) && os(iOS) var bootstrap = NIOTSConnectionBootstrap(group: group) - + if settings.useSSL { bootstrap = bootstrap.tlsOptions(NWProtocolTLS.Options()) } - #else +#else let bootstrap = ClientBootstrap(group: group) .resolver(resolver) - #endif - +#endif + let channel = try await bootstrap .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .channelInitializer { channel in - #if canImport(NIOTransportServices) && os(iOS) - #else +#if canImport(NIOTransportServices) && os(iOS) +#else if settings.useSSL { do { var configuration = TLSConfiguration.clientDefault - + if let caCert = settings.sslCaCertificate { configuration.trustRoots = NIOSSLTrustRoots.certificates([caCert]) } else if let caCertPath = settings.sslCaCertificatePath { configuration.trustRoots = NIOSSLTrustRoots.file(caCertPath) } - + let handler = try NIOSSLClientHandler(context: NIOSSLContext(configuration: configuration), serverHostname: host.hostname) return channel.pipeline.addHandler(handler).flatMap { return MongoConnection.addHandlers(to: channel, context: context) @@ -204,26 +204,26 @@ public final actor MongoConnection: Sendable { return channel.eventLoop.makeFailedFuture(error) } } - #endif - +#endif + return MongoConnection.addHandlers(to: channel, context: context) }.connect(host: host.hostname, port: host.port).get() - + let connection = MongoConnection( channel: channel, context: context, sessionManager: sessionManager ) - + try await connection.authenticate( clientDetails: clientDetails, using: settings.authentication, to: settings.authenticationSource ?? "admin" ) - + return connection } - + /// Executes a MongoDB `isMaster` /// /// - SeeAlso: https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst @@ -233,17 +233,17 @@ public final actor MongoConnection: Sendable { authenticationDatabase: String = "admin" ) async throws -> ServerHandshake { let userNamespace: String? - + if case .auto(let user, _) = credentials { userNamespace = "\(authenticationDatabase).\(user)" } else { userNamespace = nil } - + // NO session must be used here: https://github.com/mongodb/specifications/blob/master/source/sessions/driver-sessions.rst#when-opening-and-authenticating-a-connection // Forced on the current connection let sent = Date() - + let result = try await executeCodable( IsMaster( clientDetails: clientDetails, @@ -254,13 +254,13 @@ public final actor MongoConnection: Sendable { sessionId: nil, traceLabel: "Handshake" ) - + await context.setServerHandshake(to: result) self.lastHeartbeat = MongoHandshakeResult(sentAt: sent, handshake: result) return result } - - // `@inline(never)` needed due to the llvm coroutine splitting issue + + // `@inline(never)` needed due to the llvm coroutine splitting issue // `https://github.com/apple/swift/issues/60380`. @inline(never) public func authenticate( @@ -273,11 +273,32 @@ public final actor MongoConnection: Sendable { credentials: credentials, authenticationDatabase: authenticationDatabase ) - + await self.context.setServerHandshake(to: handshake) try await self.authenticate(to: authenticationDatabase, serverHandshake: handshake, with: credentials) } - + + @Sendable nonisolated package func _withSpan( + _ label: String, + context: ServiceContext? = nil, + ofKind kind: SpanKind, + perform: @Sendable (ServiceContext) async throws -> T + ) async throws -> T { + let context = context ?? .current ?? .topLevel + +#if swift(<5.10) + return try await withSpan( + label, + context: context, + ofKind: kind + ) { _ in + try await perform(context) + } +#else + return try await perform(context) +#endif + } + func executeMessage( _ message: Request, logMetadata: Logger.Metadata? = nil, @@ -293,33 +314,42 @@ public final actor MongoConnection: Sendable { let promise = self.eventLoop.makePromise(of: MongoServerReply.self) await self.context.setReplyCallback(forRequestId: message.header.requestId, completing: promise) - return try await InstrumentationSystem.tracer.withSpan( + return try await _withSpan( "MongoKitten.\(traceLabel)", - context: context ?? .current ?? .topLevel, + context: context, ofKind: .client - ) { span in + ) { [queryTimeout] _ in var buffer = self.channel.allocator.buffer(capacity: Int(message.header.messageLength)) message.write(to: &buffer) try await self.channel.writeAndFlush(buffer) - if let queryTimeout = queryTimeout { - Task { - try await Task.sleep(nanoseconds: UInt64(queryTimeout.nanoseconds)) - promise.fail(MongoError(.queryTimeout, reason: nil)) + return try await withThrowingTaskGroup(of: MongoServerReply.self) { taskGroup in + if let queryTimeout { + taskGroup.addTask { + try await Task.sleep(nanoseconds: UInt64(queryTimeout.nanoseconds)) + let error = MongoError(.queryTimeout, reason: nil) + promise.fail(error) + throw error + } } - } - let result = try await promise.futureResult.get() - lastServerActivity = Date() - return result + let result = try await promise.futureResult.get() + taskGroup.cancelAll() + await self.logActivity() + return result + } } } - + + private func logActivity() { + self.lastServerActivity = Date() + } + /// Close the connection to the MongoDB server public func close() async { _ = try? await self.channel.close() } - + deinit { channel.close(mode: .all, promise: nil) } diff --git a/Sources/MongoClient/ConnectionPool.swift b/Sources/MongoClient/ConnectionPool.swift index cc2b1cb7..935aacb3 100644 --- a/Sources/MongoClient/ConnectionPool.swift +++ b/Sources/MongoClient/ConnectionPool.swift @@ -35,7 +35,7 @@ public struct ConnectionPoolRequest: Sendable, ExpressibleByArrayLiteral { } /// A connection pool that can be used to get connections from -public protocol MongoConnectionPool { +public protocol MongoConnectionPool: Sendable { /// Gets a connection from the pool that matches the provided `request` func next(for request: ConnectionPoolRequest) async throws -> MongoConnection diff --git a/Sources/MongoClient/Cursor.swift b/Sources/MongoClient/Cursor.swift index f58e1826..80aacee6 100644 --- a/Sources/MongoClient/Cursor.swift +++ b/Sources/MongoClient/Cursor.swift @@ -1,6 +1,7 @@ import Tracing import BSON import NIO +import NIOConcurrencyHelpers import MongoCore /// A cursor returned from a query, used to iterate over the results. @@ -10,11 +11,24 @@ import MongoCore /// while let doc = try await cursor.next() { /// print(doc) /// } -public final class MongoCursor { +public final class MongoCursor: Sendable { + private let _id: NIOLockedValueBox + private let _initialBatch: NIOLockedValueBox<[Document]?> + private let _hoppedEventLoop: NIOLockedValueBox + private let _maxTimeMS: NIOLockedValueBox + private let _readConcern: NIOLockedValueBox + /// The id of the cursor, used for `getMore` requests - public private(set) var id: Int64 + public private(set) var id: Int64 { + get { _id.withLockedValue { $0 } } + set { _id.withLockedValue { $0 = newValue } } + } + + private var initialBatch: [Document]? { + get { _initialBatch.withLockedValue { $0 } } + set { _initialBatch.withLockedValue { $0 = newValue } } + } - private var initialBatch: [Document]? internal let closePromise: EventLoopPromise /// A future that will be completed when the cursor is closed @@ -29,7 +43,10 @@ public final class MongoCursor { public let namespace: MongoNamespace /// The event loop this cursor is bound to and will return results on - public var hoppedEventLoop: EventLoop? + public var hoppedEventLoop: EventLoop? { + get { _hoppedEventLoop.withLockedValue { $0 } } + set { _hoppedEventLoop.withLockedValue { $0 = newValue } } + } /// The transaction this cursor is associated with, if any public let transaction: MongoTransaction? @@ -38,10 +55,16 @@ public final class MongoCursor { public let session: MongoClientSession? /// The maximum amount of time to allow the server to spend on this cursor - public var maxTimeMS: Int32? + public var maxTimeMS: Int32? { + get { _maxTimeMS.withLockedValue { $0 } } + set { _maxTimeMS.withLockedValue { $0 = newValue } } + } /// The read concern to use for this cursor - public var readConcern: ReadConcern? + public var readConcern: ReadConcern? { + get { _readConcern.withLockedValue { $0 } } + set { _readConcern.withLockedValue { $0 = newValue } } + } /// The connection this cursor is using to communicate with the server public let connection: MongoConnection @@ -59,16 +82,18 @@ public final class MongoCursor { traceLabel: String? = nil, context: ServiceContext? = nil ) { - self.id = reply.id - self.initialBatch = reply.firstBatch + self._id = NIOLockedValueBox(reply.id) + self._initialBatch = NIOLockedValueBox(reply.firstBatch) self.namespace = namespace - self.hoppedEventLoop = hoppedEventLoop + self._hoppedEventLoop = NIOLockedValueBox(hoppedEventLoop) self.connection = connection self.session = session self.transaction = transaction self.closePromise = connection.eventLoop.makePromise() self.traceLabel = traceLabel self.context = context + self._maxTimeMS = NIOLockedValueBox(nil) + self._readConcern = NIOLockedValueBox(nil) } /// Performs a `GetMore` command on the database, requesting the next batch of items diff --git a/Sources/MongoClient/Error.swift b/Sources/MongoClient/Error.swift index 321b91cb..3fe18d04 100644 --- a/Sources/MongoClient/Error.swift +++ b/Sources/MongoClient/Error.swift @@ -1,6 +1,6 @@ /// An error that occurs when a connection fails during authentication -public struct MongoAuthenticationError: Error, CustomStringConvertible { - public enum Reason: String, Codable, CustomStringConvertible, Equatable { +public struct MongoAuthenticationError: Error, CustomStringConvertible, Sendable { + public enum Reason: String, Codable, CustomStringConvertible, Equatable, Sendable { case missingServerHandshake case anyAuthenticationFailure case unsupportedAuthenticationMechanism diff --git a/Sources/MongoClient/InternalCommands/GetMore.swift b/Sources/MongoClient/InternalCommands/GetMore.swift index 8b642502..1a96fd67 100644 --- a/Sources/MongoClient/InternalCommands/GetMore.swift +++ b/Sources/MongoClient/InternalCommands/GetMore.swift @@ -16,8 +16,8 @@ public struct GetMore: Codable, Sendable { } } -public struct GetMoreReply: Codable { - public struct CursorDetails: Codable { +public struct GetMoreReply: Codable, Sendable { + public struct CursorDetails: Codable, Sendable { public var id: Int64 public var ns: String public var nextBatch: [Document] diff --git a/Sources/MongoClient/SASL.swift b/Sources/MongoClient/SASL.swift index 8a7a5ca8..0fbcd766 100644 --- a/Sources/MongoClient/SASL.swift +++ b/Sources/MongoClient/SASL.swift @@ -12,7 +12,7 @@ fileprivate enum ProgressState { /// A thread-safe global cache that all MongoDB clients can use to reduce computational cost of authentication /// /// By caching the proof of being auhtenticated. -fileprivate final class MongoCredentialsCache { +fileprivate final class MongoCredentialsCache: @unchecked Sendable { fileprivate static let `default` = MongoCredentialsCache() private init() {} @@ -51,7 +51,7 @@ fileprivate final class MongoCredentialsCache { } /// This type contains all information needed to reduce the computational weight of authentication -struct Credentials: Codable { +struct Credentials: Codable, Sendable { let saltedPassword: Data let clientKey: Data let serverKey: Data diff --git a/Sources/MongoCore/Error.swift b/Sources/MongoCore/Error.swift index 08629e87..31708e37 100644 --- a/Sources/MongoCore/Error.swift +++ b/Sources/MongoCore/Error.swift @@ -1,6 +1,6 @@ /// An error that occurs when a connection fails due to a protocol error. public struct MongoProtocolParsingError: Error, Codable, CustomStringConvertible { - public enum Reason: String, Codable, CustomStringConvertible, Equatable { + public enum Reason: String, Codable, CustomStringConvertible, Equatable, Sendable { case unsupportedOpCode case unexpectedValue case missingDocumentBody @@ -31,7 +31,7 @@ public struct MongoProtocolParsingError: Error, Codable, CustomStringConvertible /// An error that occurs when a connection fails during protocol serialization public struct MongoProtocolSerializationError: Error, Codable, CustomStringConvertible { /// The reason why the error occurred - public enum Reason: String, Codable, CustomStringConvertible, Equatable { + public enum Reason: String, Codable, CustomStringConvertible, Equatable, Sendable { case commandSizeTooLarge case unexpectedOpCode case unsupportedOpCode @@ -62,7 +62,7 @@ public struct MongoProtocolSerializationError: Error, Codable, CustomStringConve } public struct MongoInvalidUriError: Error, Codable, CustomStringConvertible { - public enum Reason: String, Codable, CustomStringConvertible, Equatable { + public enum Reason: String, Codable, CustomStringConvertible, Equatable, Sendable { case srvCannotSpecifyPort case missingMongoDBScheme case uriIsMalformed @@ -101,7 +101,7 @@ internal struct MongoOptionalUnwrapFailure: Error, CustomStringConvertible { /// A regular error that occurs when a query fails. public struct MongoError: Error, CustomStringConvertible, CustomDebugStringConvertible { - public enum Kind: String, Codable, CustomStringConvertible, Equatable { + public enum Kind: String, Codable, CustomStringConvertible, Equatable, Sendable { case authenticationFailure case cannotGetMore case cannotConnect @@ -123,7 +123,7 @@ public struct MongoError: Error, CustomStringConvertible, CustomDebugStringConve } } - public enum Reason: String, Codable, CustomStringConvertible, Equatable { + public enum Reason: String, Codable, CustomStringConvertible, Equatable, Sendable { case internalError case unexpectedSASLPhase case scramFailure diff --git a/Sources/MongoCore/Primitives/ConnectionSettings.swift b/Sources/MongoCore/Primitives/ConnectionSettings.swift index e52bd0af..6392756a 100644 --- a/Sources/MongoCore/Primitives/ConnectionSettings.swift +++ b/Sources/MongoCore/Primitives/ConnectionSettings.swift @@ -15,7 +15,7 @@ fileprivate extension Bool { } /// Describes the settings for a MongoDB connection, most of which can be represented in a connection string -public struct ConnectionSettings: Equatable { +public struct ConnectionSettings: Equatable, Sendable { /// The authentication details to use with the database public enum Authentication: Equatable, Sendable { /// Unauthenticated @@ -35,7 +35,7 @@ public struct ConnectionSettings: Equatable { } /// Defines a MongoDB host - public struct Host: Hashable { + public struct Host: Hashable, Sendable { /// The hostname, like "localhost", "example.com" or "127.0.0.1" public var hostname: String diff --git a/Sources/MongoKitten/Aggregate.swift b/Sources/MongoKitten/Aggregate.swift index b2f936d5..29a51417 100644 --- a/Sources/MongoKitten/Aggregate.swift +++ b/Sources/MongoKitten/Aggregate.swift @@ -74,7 +74,7 @@ public struct AggregateBuilderPipeline: CountableCursor { } /// Executes the pipeline and returns a cursor - public func execute() async throws -> FinalizedCursor { + @Sendable public func execute() async throws -> FinalizedCursor { let command = makeCommand() let connection = try await getConnection() diff --git a/Sources/MongoKitten/AggregateStage.swift b/Sources/MongoKitten/AggregateStage.swift index 613a36d7..c72269ce 100644 --- a/Sources/MongoKitten/AggregateStage.swift +++ b/Sources/MongoKitten/AggregateStage.swift @@ -4,7 +4,7 @@ import MongoKittenCore import Foundation /// A stage in an aggregation pipeline -public protocol AggregateBuilderStage { +public protocol AggregateBuilderStage: Sendable { /// The stage as a document to be sent to the server var stage: Document { get } diff --git a/Sources/MongoKitten/ChangeStream.swift b/Sources/MongoKitten/ChangeStream.swift index e1ed3bb3..3fae8367 100644 --- a/Sources/MongoKitten/ChangeStream.swift +++ b/Sources/MongoKitten/ChangeStream.swift @@ -2,14 +2,14 @@ import MongoClient import NIO /// The options for a change stream -public struct ChangeStreamOptions: Encodable { +public struct ChangeStreamOptions: Encodable, Sendable { private enum CodingKeys: String, CodingKey { case batchSize case collation case fullDocument } - public enum FullDocument: String, Encodable { + public enum FullDocument: String, Encodable, Sendable { case `default`, updateLookup, whenAvailable, required } @@ -116,7 +116,7 @@ extension MongoCollection { } /// A change stream is a stream of change notifications for a collection or database -public struct ChangeStream: AsyncSequence { +public struct ChangeStream: AsyncSequence, Sendable { public typealias Notification = ChangeStreamNotification public typealias Element = Notification typealias InputCursor = FinalizedCursor> diff --git a/Sources/MongoKitten/CollectionHelpers/Collection+Find.swift b/Sources/MongoKitten/CollectionHelpers/Collection+Find.swift index 9bc9e07d..8dc3fd83 100644 --- a/Sources/MongoKitten/CollectionHelpers/Collection+Find.swift +++ b/Sources/MongoKitten/CollectionHelpers/Collection+Find.swift @@ -1,5 +1,6 @@ import Tracing import NIO +import NIOConcurrencyHelpers import MongoClient import MongoKittenCore @@ -110,12 +111,16 @@ public final class FindQueryBuilder: CountableCursor, PaginatableCursor { /// The collection this cursor applies to private let makeConnection: @Sendable () async throws -> MongoConnection - public var command: FindCommand + private let _command: NIOLockedValueBox + public var command: FindCommand { + get { _command.withLockedValue { $0} } + set { _command.withLockedValue { $0 = newValue } } + } private let collection: MongoCollection public var isDrained: Bool { false } init(command: FindCommand, collection: MongoCollection, makeConnection: @Sendable @escaping () async throws -> MongoConnection, transaction: MongoTransaction? = nil) { - self.command = command + self._command = NIOLockedValueBox(command) self.makeConnection = makeConnection self.collection = collection } @@ -124,7 +129,7 @@ public final class FindQueryBuilder: CountableCursor, PaginatableCursor { return try await makeConnection() } - public func execute() async throws -> FinalizedCursor { + @Sendable public func execute() async throws -> FinalizedCursor { let connection = try await getConnection() let findSpan: any Span if let context = collection.context { diff --git a/Sources/MongoKitten/Cursor.swift b/Sources/MongoKitten/Cursor.swift index 1cb059e9..07d0c8d9 100644 --- a/Sources/MongoKitten/Cursor.swift +++ b/Sources/MongoKitten/Cursor.swift @@ -78,7 +78,7 @@ fileprivate extension CursorBatch where Element == Document { } /// A cursor with results from a query. Implemented by `FindCursor` and `AggregateCursor`. -public protocol QueryCursor { +public protocol QueryCursor: Sendable { /// The Element type of the cursor associatedtype Element @@ -86,7 +86,7 @@ public protocol QueryCursor { func getConnection() async throws -> MongoConnection /// Executes the cursor, returning a `FinalizedCursor` after the operation has completed. - func execute() async throws -> FinalizedCursor + @Sendable func execute() async throws -> FinalizedCursor /// Transforms a given `Document` to the cursor `Element` type func transformElement(_ element: Document) async throws -> Element @@ -95,7 +95,7 @@ public protocol QueryCursor { /// A protocol for cursors that can quickly count their results, without iterating over them public protocol CountableCursor: QueryCursor { /// Counts the number of results in the cursor - func count() async throws -> Int + @Sendable func count() async throws -> Int } /// A protocol for cursors that can be paginated using `skip` and `limit` @@ -134,7 +134,7 @@ extension QueryCursor { /// Returns a new cursor with the results of mapping the given closure over the cursor's elements. This operation is lazy. /// /// - parameter transform: A mapping closure. `transform` accepts an element of this cursor as its parameter and returns a transformed value of the same or of a different type. - public func map(transform: @escaping (Element) async throws -> E) -> MappedCursor { + public func map(transform: @escaping @Sendable (Element) async throws -> E) -> MappedCursor { return MappedCursor(underlyingCursor: self, transform: transform, failable: false) } @@ -168,7 +168,7 @@ extension QueryCursor { /// /// This cursor is used to iterate over the results of a query, and is not obtained directly. /// Instead, you can execute a `find` or `aggregate` query to obtain this instance. -public final class FinalizedCursor { +public final class FinalizedCursor: Sendable { let base: Base /// The underlying server-side cursor @@ -232,7 +232,7 @@ extension QueryCursor where Element == Document { /// A cursor that is the result of mapping another cursor public struct MappedCursor: QueryCursor { - internal typealias Transform = (Base.Element) async throws -> E + internal typealias Transform = @Sendable (Base.Element) async throws -> E /// Gets the connection associated with this cursor public func getConnection() async throws -> MongoConnection { diff --git a/Sources/MongoKitten/GridFS/GridFSFileWriter.swift b/Sources/MongoKitten/GridFS/GridFSFileWriter.swift index f44e792a..decfa8d6 100644 --- a/Sources/MongoKitten/GridFS/GridFSFileWriter.swift +++ b/Sources/MongoKitten/GridFS/GridFSFileWriter.swift @@ -4,7 +4,7 @@ import NIO /// A GridFS file writer that can be used to upload a file to GridFS. This writer is not thread-safe. public final class GridFSFileWriter { static let allocator = ByteBufferAllocator() - static let encoder = BSONEncoder() + static var encoder: BSONEncoder { BSONEncoder() } let fs: GridFSBucket let fileId: Primitive diff --git a/Sources/MongoKitten/MongoCollection.swift b/Sources/MongoKitten/MongoCollection.swift index 04740891..4f7f72b2 100644 --- a/Sources/MongoKitten/MongoCollection.swift +++ b/Sources/MongoKitten/MongoCollection.swift @@ -6,13 +6,13 @@ import Foundation /// A reference to a collection in a `MongoDatabase`. /// /// MongoDB stores documents in collections. Collections are analogous to tables in relational databases. -public final class MongoCollection { +public final class MongoCollection: Sendable { // MARK: Properties - internal var context: ServiceContext? - internal var transaction: MongoTransaction? + internal let context: ServiceContext? + internal let transaction: MongoTransaction? /// The session this collection is bound to. This is used for creating database commands. - public internal(set) var session: MongoClientSession? + public let session: MongoClientSession? /// The `SessionIdentifier` of the session this collection is bound to. /// If `nil`, this collection is not bound to a session. @@ -43,10 +43,18 @@ public final class MongoCollection { } /// Initializes this collection with by the database it's in and the collection name - internal init(named name: String, in database: MongoDatabase, context: ServiceContext?) { + internal init( + named name: String, + in database: MongoDatabase, + context: ServiceContext?, + transaction: MongoTransaction?, + session: MongoClientSession? + ) { self.name = name self.database = database self.context = context + self.transaction = transaction + self.session = session } /// Drops this collection from the database it's in and removes all documents from it. diff --git a/Sources/MongoKitten/MongoDatabase.swift b/Sources/MongoKitten/MongoDatabase.swift index 683266b6..c9f80759 100644 --- a/Sources/MongoKitten/MongoDatabase.swift +++ b/Sources/MongoKitten/MongoDatabase.swift @@ -12,22 +12,25 @@ import NIOTransportServices /// A reference to a MongoDB database, over a `MongoConnectionPool`. /// /// Databases hold collections of documents. -public class MongoDatabase { - internal var transaction: MongoTransaction? +public class MongoDatabase: @unchecked Sendable { + internal let transaction: MongoTransaction? public var activeTransaction: MongoTransaction? { transaction } - public internal(set) var session: MongoClientSession? + public let session: MongoClientSession? - internal var span: (any Span)? + private let _span: NIOLockedValueBox<(any Span)?> + internal var span: (any Span)? { + get { _span.withLockedValue { $0 } } + set { _span.withLockedValue { $0 = newValue } } + } internal var context: ServiceContext? { span?.context } - private let lock = NIOLock() - private var _logMetadata: Logger.Metadata? + private let _logMetadata: NIOLockedValueBox public var logMetadata: Logger.Metadata? { - get { lock.withLock { _logMetadata } } - set { lock.withLock { _logMetadata = newValue } } + get { _logMetadata.withLockedValue { $0 } } + set { _logMetadata.withLockedValue { $0 = newValue } } } public var sessionId: SessionIdentifier? { @@ -53,15 +56,29 @@ public class MongoDatabase { } public func adoptingLogMetadata(_ metadata: Logger.Metadata) -> MongoDatabase { - let copy = MongoDatabase(named: name, pool: pool) + let copy = MongoDatabase( + named: name, + pool: pool, + transaction: transaction, + session: session + ) copy.logMetadata = metadata return copy } - internal init(named name: String, pool: MongoConnectionPool, span: (any Span)? = nil) { + internal init( + named name: String, + pool: MongoConnectionPool, + span: (any Span)? = nil, + transaction: MongoTransaction?, + session: MongoClientSession? + ) { self.name = name self.pool = pool - self.span = span + self._span = NIOLockedValueBox(span) + self._logMetadata = NIOLockedValueBox(nil) + self.transaction = transaction + self.session = session } deinit { span?.end() } @@ -115,7 +132,12 @@ public class MongoDatabase { } let cluster = try MongoCluster(lazyConnectingTo: settings, logger: logger) - return MongoDatabase(named: targetDatabase, pool: cluster) + return MongoDatabase( + named: targetDatabase, + pool: cluster, + transaction: nil, + session: nil + ) } /// Connect to the database with the given settings. You can also use `connect(_:on:)` to connect by using a connection string. @@ -133,7 +155,12 @@ public class MongoDatabase { } let cluster = try await MongoCluster(connectingTo: settings, logger: logger) - return MongoDatabase(named: targetDatabase, pool: cluster) + return MongoDatabase( + named: targetDatabase, + pool: cluster, + transaction: nil, + session: nil + ) } /// Creates a new tranasction provided the SessionOptions and optional TransactionOptions @@ -165,10 +192,13 @@ public class MongoDatabase { let transaction = newSession.startTransaction(autocommit: autoCommit) let span = InstrumentationSystem.tracer.startAnySpan("Transaction(\(transaction.number))") - let db = MongoTransactionDatabase(named: name, pool: connection, span: span) - db.transaction = transaction - db.session = newSession - return db + return MongoTransactionDatabase( + named: name, + pool: connection, + span: span, + transaction: transaction, + session: newSession + ) } /// Get a `MongoCollection` by providing a collection name as a `String` @@ -177,10 +207,13 @@ public class MongoDatabase { /// /// - returns: The requested collection in this database public subscript(collection: String) -> MongoCollection { - let collection = MongoCollection(named: collection, in: self, context: context) - collection.session = self.session - collection.transaction = self.transaction - return collection + MongoCollection( + named: collection, + in: self, + context: context, + transaction: self.transaction, + session: self.session + ) } /// Drops the current database, deleting the associated data files @@ -237,7 +270,13 @@ public class MongoDatabase { ).decode(CollectionDescription.self) return try await cursor.drain().map { description in - return MongoCollection(named: description.name, in: self, context: context) + return MongoCollection( + named: description.name, + in: self, + context: context, + transaction: self.transaction, + session: self.session + ) } } } @@ -267,7 +306,12 @@ extension EventLoopFuture where Value == Optional { extension MongoConnectionPool { public subscript(db: String) -> MongoDatabase { - return MongoDatabase(named: db, pool: self) + return MongoDatabase( + named: db, + pool: self, + transaction: nil, + session: nil + ) } /// Lists all databases your user has knowledge of in this cluster @@ -282,7 +326,12 @@ extension MongoConnectionPool { ) return response.databases.map { description in - return MongoDatabase(named: description.name, pool: self) + return MongoDatabase( + named: description.name, + pool: self, + transaction: nil, + session: nil + ) } } } diff --git a/Sources/MongoKitten/MongoTransactionDatabase.swift b/Sources/MongoKitten/MongoTransactionDatabase.swift index 41353ce3..578f0aaf 100644 --- a/Sources/MongoKitten/MongoTransactionDatabase.swift +++ b/Sources/MongoKitten/MongoTransactionDatabase.swift @@ -1,4 +1,5 @@ import Tracing +import MongoCore public final class MongoTransactionDatabase: MongoDatabase { /// Commits the transaction and ends the session. diff --git a/Sources/MongoKittenCore/Commands/CreateIndexes.swift b/Sources/MongoKittenCore/Commands/CreateIndexes.swift index aff2cc7b..d31e14a8 100644 --- a/Sources/MongoKittenCore/Commands/CreateIndexes.swift +++ b/Sources/MongoKittenCore/Commands/CreateIndexes.swift @@ -1,8 +1,8 @@ import BSON import MongoCore -public struct CreateIndexes: Encodable { - public struct Index: Encodable { +public struct CreateIndexes: Encodable, Sendable { + public struct Index: Encodable, Sendable { private enum CodingKeys: String, CodingKey { case name, key, unique, partialFilterExpression, sparse case expireAfterSeconds, storageEngine, weights diff --git a/Sources/MongoKittenCore/Commands/Distinct.swift b/Sources/MongoKittenCore/Commands/Distinct.swift index 94b357c0..cf0def1b 100644 --- a/Sources/MongoKittenCore/Commands/Distinct.swift +++ b/Sources/MongoKittenCore/Commands/Distinct.swift @@ -1,7 +1,7 @@ import BSON import MongoCore -public struct DistinctCommand: Codable { +public struct DistinctCommand: Codable, Sendable { private let distinct: String public var key: String public var query: Document? @@ -14,7 +14,7 @@ public struct DistinctCommand: Codable { } } -public struct DistinctReply: Decodable { +public struct DistinctReply: Decodable, Sendable { public let ok: Int private let values: Document public var distinctValues: [Primitive] { diff --git a/Sources/MongoKittenCore/Commands/Find.swift b/Sources/MongoKittenCore/Commands/Find.swift index 7ad67c59..c84545b5 100644 --- a/Sources/MongoKittenCore/Commands/Find.swift +++ b/Sources/MongoKittenCore/Commands/Find.swift @@ -1,7 +1,7 @@ import BSON import MongoCore -public struct FindCommand: Codable { +public struct FindCommand: Codable, Sendable { /// This variable _must_ be the first encoded value, so keep it above all others private let find: String public var collection: String { return find } diff --git a/Sources/MongoKittenCore/Commands/Insert.swift b/Sources/MongoKittenCore/Commands/Insert.swift index 22e52600..7253e889 100644 --- a/Sources/MongoKittenCore/Commands/Insert.swift +++ b/Sources/MongoKittenCore/Commands/Insert.swift @@ -1,7 +1,7 @@ import BSON import MongoCore -public struct InsertCommand: Codable { +public struct InsertCommand: Codable, Sendable { /// This variable _must_ be the first encoded value, so keep it above all others private let insert: String public var collection: String { return insert } @@ -17,7 +17,7 @@ public struct InsertCommand: Codable { } } -public struct InsertReply: Decodable, Error, CustomDebugStringConvertible { +public struct InsertReply: Decodable, Error, CustomDebugStringConvertible, Sendable { private enum CodingKeys: String, CodingKey { case ok, writeErrors, writeConcernError case insertCount = "n" diff --git a/Sources/_MongoKittenCrypto/Hash.swift b/Sources/_MongoKittenCrypto/Hash.swift index 3ebc73a1..012b02e6 100644 --- a/Sources/_MongoKittenCrypto/Hash.swift +++ b/Sources/_MongoKittenCrypto/Hash.swift @@ -1,4 +1,4 @@ -public protocol Hash { +public protocol Hash: Sendable { static var littleEndian: Bool { get } static var chunkSize: Int { get } static var digestSize: Int { get } diff --git a/Tests/MongoKittenTests/CRUDTests.swift b/Tests/MongoKittenTests/CRUDTests.swift index 32a49ff7..26dd8256 100644 --- a/Tests/MongoKittenTests/CRUDTests.swift +++ b/Tests/MongoKittenTests/CRUDTests.swift @@ -635,4 +635,25 @@ class CrudTests : XCTestCase { XCTAssertEqual(count, 1) } + + func testIteratorInTaskGroup() async throws { + try await testBulkCreateDummyAccounts() + let schema = mongo[DummyAccount.collectionName] + + let docs = try await withThrowingTaskGroup(of: Document.self) { taskGroup in + for try await document in schema.find() { + taskGroup.addTask { + document + } + } + + return try await taskGroup.reduce(into: [Document]()) { partialResult, doc in + partialResult.append(doc) + } + } + + let count = try await schema.count() + XCTAssertNotEqual(count, 0) + XCTAssertEqual(docs.count, count) + } }