diff --git a/Package.swift b/Package.swift index 9203c1b9c..6177e3351 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,8 @@ let package = Package( .package(name: "WebRTC", url: "https://github.com/webrtc-sdk/Specs.git", .exact("104.5112.09")), .package(name: "SwiftProtobuf", url: "https://github.com/apple/swift-protobuf.git", .upToNextMajor(from: "1.20.2")), .package(name: "Promises", url: "https://github.com/google/promises.git", .upToNextMajor(from: "2.0.0")), - .package(url: "https://github.com/apple/swift-log.git", .upToNextMajor(from: "1.4.4")) + .package(url: "https://github.com/apple/swift-log.git", .upToNextMajor(from: "1.4.4")), + .package(url: "https://github.com/apple/swift-async-algorithms.git", .upToNextMajor(from: "0.1.0")) ], targets: [ .systemLibrary(name: "CHeaders"), @@ -30,6 +31,7 @@ let package = Package( .target(name: "CHeaders"), "WebRTC", "SwiftProtobuf", "Promises", .product(name: "Logging", package: "swift-log"), + .product(name: "AsyncAlgorithms", package: "swift-async-algorithms") ], path: "Sources", swiftSettings: [ diff --git a/Sources/LiveKit/Core/DataChannelPair.swift b/Sources/LiveKit/Core/DataChannelPair.swift index 0d1ee7df6..23269d0d4 100644 --- a/Sources/LiveKit/Core/DataChannelPair.swift +++ b/Sources/LiveKit/Core/DataChannelPair.swift @@ -17,6 +17,127 @@ import Foundation import WebRTC import Promises +import AsyncAlgorithms + +class _DataChannelPair { + + public enum Errors: Error { + case decodeFailed + } + + let target: Livekit_SignalTarget + private let reliableChannel: RTCDataChannelObserver + private let lossyChannel: RTCDataChannelObserver + + var onDataPacket: OnDataPacket? + var dataPackets: AsyncStream { + AsyncStream { continuation in + //WIP: + } + } + + init(target: Livekit_SignalTarget, reliableChannel: RTCDataChannel? = nil, lossyChannel: RTCDataChannel? = nil) { + self.target = target + self.reliableChannel = RTCDataChannelObserver(channel: reliableChannel) + self.lossyChannel = RTCDataChannelObserver(channel: lossyChannel) + } + + public func set(reliable rChannel: RTCDataChannel?, lossy lChannel: RTCDataChannel?) { + self.reliableChannel.channel = rChannel + self.lossyChannel.channel = lChannel + } + + public func open() async throws { + await observeChannelStates(expectedReliableChannelState: .open, expectedLossyChannelState: .open) + } + + public func close() async { + await withTaskGroup(of: Void.self) { group in + [reliableChannel, lossyChannel].forEach { channel in + group.addTask { + channel.close() + } + } + + group.addTask { + await self.observeChannelStates(expectedReliableChannelState: .closed, expectedLossyChannelState: .closed) + } + } + } + + private func observeChannelStates(expectedReliableChannelState: RTCDataChannelState, expectedLossyChannelState: RTCDataChannelState) async { + for await (reliableState, lossyState) in combineLatest(reliableChannel.states, lossyChannel.states) { + print("got reliable state: \(reliableState), lossy state: \(lossyState)") + guard reliableState == expectedReliableChannelState, lossyState == expectedLossyChannelState else { continue } + break //terminate asyncsequence + } + } + + final class RTCDataChannelObserver: NSObject, RTCDataChannelDelegate { + + var channel: RTCDataChannel? { + willSet { + channel?.delegate = nil + } + + didSet { + channel?.delegate = self + } + } + + private var stateUpdate: ((RTCDataChannelState) -> Void)? + private var messageReceived: ((RTCDataBuffer) -> Void)? + + var states: AsyncStream { + AsyncStream { continuation in + stateUpdate = { state in + continuation.yield(state) + } + + continuation.onTermination = { @Sendable _ in + self.stateUpdate = nil + } + } + } + + var messages: AsyncStream { + AsyncStream { continuation in + messageReceived = { buffer in + continuation.yield(buffer) + } + + continuation.onTermination = { @Sendable _ in + self.messageReceived = nil + } + } + } + + init(channel: RTCDataChannel? = nil) { + super.init() + self.channel = channel + } + + deinit { + print("DEBUG: deinit \(self)") + } + + func close() { + channel?.close() + } + + func dataChannelDidChangeState(_ dataChannel: RTCDataChannel) { + assert(dataChannel == channel) + stateUpdate?(dataChannel.readyState) + } + + func dataChannel(_ dataChannel: RTCDataChannel, didReceiveMessageWith buffer: RTCDataBuffer) { + assert(dataChannel == channel) + messageReceived?(buffer) + } + } +} + +typealias OnDataPacket = (_ dataPacket: Livekit_DataPacket) -> Void internal class DataChannelPair: NSObject, Loggable { @@ -70,6 +191,18 @@ internal class DataChannelPair: NSObject, Loggable { openCompleter.fulfill(()) } } + + public func set(reliable rChannel: RTCDataChannel?, lossy lChannel: RTCDataChannel?) { + self._reliableChannel = rChannel + rChannel?.delegate = self + + self._lossyChannel = lChannel + lChannel?.delegate = self + + if isOpen { + openCompleter.fulfill(()) + } + } public func close() -> Promise { diff --git a/Sources/LiveKit/Core/Engine.swift b/Sources/LiveKit/Core/Engine.swift index d8435e72c..764e8713c 100644 --- a/Sources/LiveKit/Core/Engine.swift +++ b/Sources/LiveKit/Core/Engine.swift @@ -17,6 +17,7 @@ import Foundation import WebRTC import Promises +import Combine #if canImport(Network) import Network @@ -67,6 +68,7 @@ internal class Engine: MulticastDelegate { // MARK: - DataChannels + private var _subscriberDC = _DataChannelPair(target: .subscriber) private var subscriberDC = DataChannelPair(target: .subscriber) private var publisherDC = DataChannelPair(target: .publisher) @@ -129,12 +131,46 @@ internal class Engine: MulticastDelegate { default: return } } + + //WIP: + Task { + for await dataPacket in _subscriberDC.dataPackets { + + } + } } deinit { log() } - + + //WIP: + func _connect(url: String, token: String, connectOptions: ConnectOptions? = nil) async throws { + // update options if specified + if let connectOptions = connectOptions, connectOptions != _state.connectOptions { + _state.mutate { $0.connectOptions = connectOptions } + } + + try await _cleanUp() + self._state.mutate { $0.connectionState = .connecting } + + do { + try await self._fullConnectSequence(url: url, token: token) + } catch { + try await _cleanUp(reason: .networkError(error)) + throw error + } + + self.log("Connect sequence completed") + + // update internal vars (only if connect succeeded) + self._state.mutate { + $0.url = url + $0.token = token + $0.connectionState = .connected + } + } + // Connect sequence, resets existing state func connect(_ url: String, _ token: String, @@ -164,6 +200,11 @@ internal class Engine: MulticastDelegate { self.cleanUp(reason: .networkError(error)) } } + + func _cleanUp(reason: DisconnectReason? = nil, isFullReconnect: Bool = false) async throws { + guard let room else { throw EngineError.state(message: "Room is nil") } + await room._cleanUp(reason: reason, isFullReconnect: isFullReconnect) + } // cleanUp (reset) both Room & Engine's state @discardableResult @@ -177,6 +218,43 @@ internal class Engine: MulticastDelegate { return room.cleanUp(reason: reason, isFullReconnect: isFullReconnect) } + func _cleanupRTC() async { + await withTaskGroup(of: Void.self, body: { [_subscriberDC] group in + //TODO: - enable once publisher is ready +// group.addTask { +// await _subscriberDC.close() +// } +// group.addTask { +// await _publisherDC.close() +// } + + group.addTask { + await withUnsafeContinuation { continuation in + let closeDataChannelPromises = [ + self.publisherDC.close(), + self.subscriberDC.close() + ] + + closeDataChannelPromises.all(on: self.queue) + .then(on: self.queue) { + continuation.resume() + } + } + } + + [publisher, subscriber].forEach { transport in + guard let transport else { return } + group.addTask { + await transport._close() + } + } + }) + + self.publisher = nil + self.subscriber = nil + self._state.mutate { $0.hasPublished = false } + } + // Resets state of transports func cleanUpRTC() -> Promise { @@ -212,6 +290,20 @@ internal class Engine: MulticastDelegate { self._state.mutate { $0.hasPublished = false } } } + + func _publisherShouldNegotiate() async throws { + log() + + guard let publisher else { + throw EngineError.state(message: "self or publisher is nil") + } + + _state.mutate { $0.hasPublished = true } + + //v-- this is the offer ... do we need it? + _ = try await publisher._createAndSendOffer() + print("negotiated offer...") + } @discardableResult func publisherShouldNegotiate() -> Promise { @@ -317,6 +409,36 @@ internal extension Engine { private extension Engine { + func _fullConnectSequence(url: String, token: String) async throws { + + // this should never happen since Engine is owned by Room + guard let room else { throw EngineError.state(message: "Room is nil") } + + try await self.signalClient._connect( + urlString: url, + token: token, + connectOptions: _state.connectOptions, + reconnectMode: _state.reconnectMode, + adaptiveStream: room._state.options.adaptiveStream + ) + + assert(signalClient._state.connectionState == .connected) + + // wait for joinResponse + let joinResponse = try await signalClient._joinResponse() + + self._state.mutate { $0.connectStopwatch.split(label: "signal") } + + try await self._configureTransports(joinResponse: joinResponse) + + await self.signalClient._resumeResponseQueue() + + try await self.primaryTransportConnected() + + _state.mutate { $0.connectStopwatch.split(label: "engine") } + log("\(_state.connectStopwatch)") + } + // full connect sequence, doesn't update connection state func fullConnectSequence(_ url: String, _ token: String) -> Promise { @@ -349,6 +471,31 @@ private extension Engine { self.log("\(self._state.connectStopwatch)") } } + + //WIP: continue here! + func primaryTransportConnected() async throws { + //TODO: timeout + var primaryState: RTCPeerConnectionState? + + if #available(iOS 15.0, *) { + guard let primaryStates = primary?.connectionStatePublisher.values else { + throw TransportError.noPrimary(message: "transport is nil") + } + + for await state in primaryStates { + guard state == .connected else { continue } + primaryState = state + break + } + } else { + // Fallback on earlier versions + fatalError() + } + + guard let _ = primaryState else { + throw TransportError.noPrimary(message: "no state after waiting for primary connection") + } + } @discardableResult func startReconnect() -> Promise { @@ -623,7 +770,7 @@ extension Engine: TransportDelegate { func transport(_ transport: Transport, didUpdate pcState: RTCPeerConnectionState) { log("target: \(transport.target), state: \(pcState)") - + // primary connected if transport.primary { _state.mutate { $0.primaryTransportConnectedCompleter.set(value: .connected == pcState ? () : nil) } @@ -642,6 +789,67 @@ extension Engine: TransportDelegate { } } } + + private func _configureTransports(joinResponse: Livekit_JoinResponse) async throws { + log("configuring transports...") + + // this should never happen since Engine is owned by Room + guard let room else { throw EngineError.state(message: "Room is nil") } + + guard subscriber == nil, publisher == nil else { + log("transports already configured") + return + } + + // protocol v3 + subscriberPrimary = joinResponse.subscriberPrimary + log("subscriberPrimary: \(joinResponse.subscriberPrimary)") + + // Make a copy, instead of modifying the user-supplied RTCConfiguration object. + let rtcConfiguration = RTCConfiguration(copy: _state.connectOptions.rtcConfiguration) + + if rtcConfiguration.iceServers.isEmpty { + // Set iceServers provided by the server + rtcConfiguration.iceServers = joinResponse.iceServers.map { $0.toRTCType() } + } + + if joinResponse.clientConfiguration.forceRelay == .enabled { + rtcConfiguration.iceTransportPolicy = .relay + } + + let subscriber = try Transport(config: rtcConfiguration, + target: .subscriber, + primary: self.subscriberPrimary, + delegate: self, + reportStats: room._state.options.reportStats) + + let publisher = try Transport(config: rtcConfiguration, + target: .publisher, + primary: !self.subscriberPrimary, + delegate: self, + reportStats: room._state.options.reportStats) + + // data over pub channel for backwards compatibility + + let publisherReliableDC = publisher.dataChannel(for: RTCDataChannel.labels.reliable, + configuration: Engine.createDataChannelConfiguration()) + + let publisherLossyDC = publisher.dataChannel(for: RTCDataChannel.labels.lossy, + configuration: Engine.createDataChannelConfiguration(maxRetransmits: 0)) + + self.publisherDC.set(reliable: publisherReliableDC, lossy: publisherLossyDC) + + self.log("dataChannel.\(String(describing: publisherReliableDC?.label)) : \(String(describing: publisherReliableDC?.channelId))") + self.log("dataChannel.\(String(describing: publisherLossyDC?.label)) : \(String(describing: publisherLossyDC?.channelId))") + + if !self.subscriberPrimary { + // lazy negotiation for protocol v3+ + try await self._publisherShouldNegotiate() + } + + self.subscriber = subscriber + self.publisher = publisher + } private func configureTransports(joinResponse: Livekit_JoinResponse) -> Promise { diff --git a/Sources/LiveKit/Core/Room.swift b/Sources/LiveKit/Core/Room.swift index 918505082..8390e9b79 100644 --- a/Sources/LiveKit/Core/Room.swift +++ b/Sources/LiveKit/Core/Room.swift @@ -190,6 +190,26 @@ public class Room: NSObject, Loggable { deinit { log() } + + public func _connect(url: String, token: String, connectOptions: ConnectOptions? = nil, roomOptions: RoomOptions? = nil) async throws { + + log("connecting to room...", .info) + + let state = _state.readCopy() + + guard state.localParticipant == nil else { + log("localParticipant is not nil", .warning) + throw EngineError.state(message: "localParticipant is not nil") + } + + // update options if specified + if let roomOptions = roomOptions, roomOptions != state.options { + _state.mutate { $0.options = roomOptions } + } + + try await engine._connect(url: url, token: token, connectOptions: connectOptions) + log("connected to \(String(describing: self)) \(String(describing: state.localParticipant))", .info) + } @discardableResult public func connect(_ url: String, @@ -218,6 +238,20 @@ public class Room: NSObject, Loggable { return self } } + + public func _disconnect() async throws { + + // return if already disconnected state + guard case .connected = connectionState else { return } + + do { + try await engine.signalClient.sendLeave() + } catch { + log("Failed to send leave, error: \(error)") + } + + await _cleanUp(reason: .user) + } @discardableResult public func disconnect() -> Promise { @@ -238,6 +272,40 @@ public class Room: NSObject, Loggable { internal extension Room { + func _cleanUp(reason: DisconnectReason? = nil, isFullReconnect: Bool = false) async { + + log("reason: \(String(describing: reason))") + + engine._state.mutate { + $0.primaryTransportConnectedCompleter.reset() + $0.publisherTransportConnectedCompleter.reset() + + // if isFullReconnect, keep connection related states + $0 = isFullReconnect ? Engine.State( + connectOptions: $0.connectOptions, + url: $0.url, + token: $0.token, + nextPreferredReconnectMode: $0.nextPreferredReconnectMode, + reconnectMode: $0.reconnectMode, + connectionState: $0.connectionState + ) : Engine.State( + connectOptions: $0.connectOptions, + connectionState: .disconnected(reason: reason) + ) + } + + await engine.signalClient._cleanUp(reason: reason) + log("did cleanUp signalClient") + + await engine._cleanupRTC() + log("did cleanup RTC engine") + + await _cleanupParticipants() + log("did cleanUp participants") + + self._state.mutate { $0 = State(options: $0.options) } + } + // Resets state of Room @discardableResult func cleanUp(reason: DisconnectReason? = nil, @@ -277,6 +345,33 @@ internal extension Room { self.log("Room cleanUp failed with error: \(error)", .error) } } + + func _cleanupParticipants(notify: Bool = true) async { + + log("notify: \(notify)") + + await withTaskGroup(of: Void.self, body: { [localParticipant, _state] group in + + if let localParticipant { + group.addTask { + await localParticipant._cleanUp(notify: notify) + } + } + + // Stop all local --^ & remote tracks --v + _state.remoteParticipants.values.compactMap { $0 }.forEach { participant in + group.addTask { + await participant._cleanUp(notify: notify) + } + } + }) + + //reset state and be done + _state.mutate { + $0.localParticipant = nil + $0.remoteParticipants = [:] + } + } } // MARK: - Private diff --git a/Sources/LiveKit/Core/SignalClient.swift b/Sources/LiveKit/Core/SignalClient.swift index 3187b9ecf..32786a2d0 100644 --- a/Sources/LiveKit/Core/SignalClient.swift +++ b/Sources/LiveKit/Core/SignalClient.swift @@ -17,6 +17,7 @@ import Foundation import WebRTC import Promises +import Combine internal class SignalClient: MulticastDelegate { @@ -36,6 +37,11 @@ internal class SignalClient: MulticastDelegate { } internal var _state = StateSync(State()) + + enum Errors: Error { + case alreadyConnected + case invalidJoinResponse + } // MARK: - Private @@ -55,14 +61,36 @@ internal class SignalClient: MulticastDelegate { private var webSocket: WebSocket? private var latestJoinResponse: Livekit_JoinResponse? + + private var responseSubscriptions: Set = [] + private let receivedResponses: CurrentValueSubject + var receivedResponsePublisher: AnyPublisher { + receivedResponses + .compactMap { [weak self] signalResponse in + guard let self else { return nil } + guard let message = signalResponse?.message else { + self.log("Failed to decode SignalResponse", .warning) + return nil + } + return message + } + .eraseToAnyPublisher() + } + + private let joinResponses: CurrentValueSubject private var pingIntervalTimer: DispatchQueueTimer? private var pingTimeoutTimer: DispatchQueueTimer? init() { + receivedResponses = CurrentValueSubject(nil) + joinResponses = CurrentValueSubject(nil) + super.init() log() + + joinResponsesSubscription() // trigger events when state mutates self._state.onMutate = { [weak self] state, oldState in @@ -80,8 +108,118 @@ internal class SignalClient: MulticastDelegate { deinit { log() + responseSubscriptions.removeAll() + } + + private func joinResponsesSubscription() { + receivedResponsePublisher + .compactMap { + guard case .join(let joinResponse) = $0 else { return nil } + return joinResponse + } + .assign(to: \.value, on: joinResponses) + .store(in: &responseSubscriptions) + } + + func _joinResponse() async throws -> Livekit_JoinResponse { + if #available(iOS 15.0, *) { + var result: Livekit_JoinResponse? + for await joinResponse in joinResponses.values { + guard let joinResponse else { continue } + result = joinResponse + break // << stop at the first valid join response + } + + guard let result else { throw Errors.invalidJoinResponse } + return result + + } else { + fatalError() + } } + func _connect(urlString: String, token: String, connectOptions: ConnectOptions? = nil, reconnectMode: ReconnectMode? = nil, adaptiveStream: Bool) async throws { + + guard self.webSocket == nil else { throw Errors.alreadyConnected } + + await _cleanUp() + + log("reconnectMode: \(String(describing: reconnectMode))") + + guard let url = Utils.buildUrl(urlString, + token, + connectOptions: connectOptions, + reconnectMode: reconnectMode, + adaptiveStream: adaptiveStream) else { throw InternalError.parse(message: "Failed to parse url") } + log("Connecting with url: \(urlString)") + + _state.mutate { + $0.reconnectMode = reconnectMode + $0.connectionState = .connecting + } + + let websocket = try await socketConnect(url: url, urlString: urlString, reconnectMode: reconnectMode) { + Utils.buildUrl(urlString, + token, + connectOptions: connectOptions, + adaptiveStream: adaptiveStream, + validate: true) + } + + self._state.mutate { $0.connectionState = .connected } + self.queue.async { + self.webSocket = websocket + } + } + + private func socketConnect(url: URL, + urlString: String, + reconnectMode: ReconnectMode? = nil, + validatedURL: @escaping () -> URL?) async throws -> WebSocket { + + try await withUnsafeThrowingContinuation { continuation in + + WebSocket.connect( + url: url, + onMessage: self.onWebSocketMessage, + onDisconnect: { reason in + self.cleanUp(reason: reason) + self.queue.sync { + self.webSocket = nil + } + } + ) + .then(on: queue) { (webSocket: WebSocket) -> Void in + dispatchPrecondition(condition: .onQueue(self.queue)) + continuation.resume(returning: webSocket) + } + .recover(on: queue) { error -> Promise in + // Skip validation if reconnect mode + guard reconnectMode == nil else { throw error } + // Catch first, then throw again after getting validation response + // Re-build url with validate mode + guard let validateUrl = validatedURL() else { + return Promise(InternalError.parse(message: "Failed to parse validation url")) + } + + self.log("Validating with url: \(validateUrl)") + + return HTTP().get(on: self.queue, url: validateUrl).then(on: self.queue) { data in + guard let string = String(data: data, encoding: .utf8) else { + throw SignalClientError.connect(message: "Failed to decode string") + } + self.log("validate response: \(string)") + // re-throw with validation response + throw SignalClientError.connect(message: string) + } + } + .catch(on: queue) { error in + self.cleanUp(reason: .networkError(error)) + continuation.resume(throwing: error) + } + } + } + func connect(_ urlString: String, _ token: String, connectOptions: ConnectOptions? = nil, @@ -145,6 +283,52 @@ internal class SignalClient: MulticastDelegate { self.cleanUp(reason: .networkError(error)) } } + + func _cleanUp(reason: DisconnectReason? = nil) async { + log("reason: \(String(describing: reason))") + + _state.mutate { $0.connectionState = .disconnected(reason: reason) } + + pingIntervalTimer = nil + pingTimeoutTimer = nil + + if let socket = webSocket { + socket.cleanUp(reason: reason, notify: false) + socket.onMessage = nil + socket.onDisconnect = nil + self.webSocket = nil + } + + latestJoinResponse = nil + + _state.mutate { + for var completer in $0.completersForAddTrack.values { + completer.reset() + } + + $0.joinResponseCompleter.reset() + + // reset state + $0 = State() + } + + await withUnsafeContinuation { continuation in + requestDispatchQueue.async { [weak self] in + guard let self = self else { return } + self.requestQueue = [] + continuation.resume() + } + } + + await withUnsafeContinuation { continuation in + responseDispatchQueue.async { [weak self] in + guard let self = self else { return } + self.responseQueue = [] + self.responseQueueState = .resumed + continuation.resume() + } + } + } func cleanUp(reason: DisconnectReason? = nil) { @@ -219,6 +403,34 @@ internal class SignalClient: MulticastDelegate { private extension SignalClient { + func _sendRequest(_ request: Livekit_SignalRequest, enqueueIfReconnecting: Bool = true) async throws { + guard !(self._state.connectionState.isReconnecting && request.canEnqueue() && enqueueIfReconnecting) else { + self.log("queuing request while reconnecting, request: \(request)") + requestDispatchQueue.async { + self.requestQueue.append(request) + } + return + } + + guard case .connected = self.connectionState else { + self.log("not connected", .error) + throw SignalClientError.state(message: "Not connected") + } + + // this shouldn't happen + guard let webSocket = self.webSocket else { + self.log("webSocket is nil", .error) + throw SignalClientError.state(message: "WebSocket is nil") + } + + guard let data = try? request.serializedData() else { + self.log("could not serialize data", .error) + throw InternalError.convert(message: "Could not serialize data") + } + + try await webSocket.send(data: data) + } + // send request or enqueue while reconnecting func sendRequest(_ request: Livekit_SignalRequest, enqueueIfReconnecting: Bool = true) -> Promise { @@ -275,13 +487,15 @@ private extension SignalClient { } else { self.onSignalResponse(response) } + + self.receivedResponses.send(response) } } func onSignalResponse(_ response: Livekit_SignalResponse) { guard case .connected = connectionState else { - log("Not connected", .warning) + log("Not connected, received response though: \(response)", .warning) return } @@ -363,6 +577,35 @@ private extension SignalClient { internal extension SignalClient { + func _resumeResponseQueue() async { + + await withUnsafeContinuation({ (continuation: UnsafeContinuation) in + + responseDispatchQueue.async { + defer { self.responseQueueState = .resumed } + + // quickly return if no queued requests + guard !self.responseQueue.isEmpty else { + self.log("No queued response") + continuation.resume() + return + } + + // send requests in sequential order + let promises = self.responseQueue.reduce(into: Promise(())) { result, response in + result = result.then(on: self.queue) { self.onSignalResponse(response) } + } + + // clear the queue + self.responseQueue = [] + + promises.then(on: self.queue) { + continuation.resume() + } + } + }) + } + func resumeResponseQueue() -> Promise { log() @@ -422,6 +665,16 @@ internal extension SignalClient { return promises } } + + func _sendOffer(offer: RTCSessionDescription) async throws { + log() + + let request = Livekit_SignalRequest.with { + $0.offer = offer.toPBType() + } + + try await _sendRequest(request) + } func sendOffer(offer: RTCSessionDescription) -> Promise { log() @@ -599,6 +852,19 @@ internal extension SignalClient { return sendRequest(r) } + + func sendLeave() async throws { + log() + + let request = Livekit_SignalRequest.with { + $0.leave = Livekit_LeaveRequest.with { + $0.canReconnect = false + $0.reason = .clientInitiated + } + } + + try await _sendRequest(request) + } @discardableResult func sendLeave() -> Promise { diff --git a/Sources/LiveKit/Core/Transport.swift b/Sources/LiveKit/Core/Transport.swift index e87f1236f..881c605a2 100644 --- a/Sources/LiveKit/Core/Transport.swift +++ b/Sources/LiveKit/Core/Transport.swift @@ -18,6 +18,7 @@ import Foundation import WebRTC import Promises import SwiftProtobuf +import Combine internal typealias TransportOnOffer = (RTCSessionDescription) -> Promise @@ -36,6 +37,12 @@ internal class Transport: MulticastDelegate { public var connectionState: RTCPeerConnectionState { DispatchQueue.webRTC.sync { pc.connectionState } } + + + private var connectionStateSubject: CurrentValueSubject + public var connectionStatePublisher: AnyPublisher { + connectionStateSubject.eraseToAnyPublisher() + } public var localDescription: RTCSessionDescription? { DispatchQueue.webRTC.sync { pc.localDescription } @@ -93,6 +100,7 @@ internal class Transport: MulticastDelegate { self.target = target self.primary = primary self.pc = pc + self.connectionStateSubject = CurrentValueSubject(pc.connectionState) super.init() @@ -129,6 +137,11 @@ internal class Transport: MulticastDelegate { self.pendingCandidates.append(candidate) } } + + //WIP: + func setRemoteDescription(_ sd: RTCSessionDescription) async { + fatalError() + } @discardableResult func setRemoteDescription(_ sd: RTCSessionDescription) -> Promise { @@ -148,6 +161,30 @@ internal class Transport: MulticastDelegate { return Promise(()) } } + + func _createAndSendOffer(iceRestart: Bool = false) async throws -> RTCSessionDescription { + var constraints = [String: String]() + if iceRestart { + log("Restarting ICE...") + constraints[kRTCMediaConstraintsIceRestart] = kRTCMediaConstraintsValueTrue + restartingIce = true + } + + if signalingState == .haveLocalOffer, !(iceRestart && remoteDescription != nil) { + renegotiate = true + } + + if signalingState == .haveLocalOffer, iceRestart, let sd = remoteDescription { + await withUnsafeContinuation { continuation in + setRemoteDescriptionPromise(sd).then(on: queue) { _ in + continuation.resume() + } + } + } + + let offer = try await self.createOffer(for: constraints) + return try await self.setLocalDescription(offer) + } @discardableResult func createAndSendOffer(iceRestart: Bool = false) -> Promise { @@ -186,6 +223,26 @@ internal class Transport: MulticastDelegate { return negotiateSequence() } + + func _close() async { + // prevent debounced negotiate firing + self.debounceWorkItem?.cancel() + self.statsTimer.suspend() + + await withUnsafeContinuation({ continuation in + DispatchQueue.webRTC.async { + // Stop listening to delegate + self.pc.delegate = nil + // Remove all senders (if any) + for sender in self.pc.senders { + self.pc.removeTrack(sender) + } + + self.pc.close() + continuation.resume() + } + }) + } func close() -> Promise { @@ -264,6 +321,7 @@ extension Transport: RTCPeerConnectionDelegate { internal func peerConnection(_ peerConnection: RTCPeerConnection, didChange state: RTCPeerConnectionState) { log("did update state \(state) for \(target)") + connectionStateSubject.send(state) notify { $0.transport(self, didUpdate: state) } } @@ -321,6 +379,25 @@ extension Transport: RTCPeerConnectionDelegate { private extension Transport { + func createOffer(for constraints: [String: String]? = nil) async throws -> RTCSessionDescription { + try await withUnsafeThrowingContinuation { continuation in + DispatchQueue.webRTC.async { //<< once actor isolation is established, this becomes unnecessary + + let mediaConstraints = RTCMediaConstraints(mandatoryConstraints: constraints, optionalConstraints: nil) + self.pc.offer(for: mediaConstraints) { sd, error in + + guard let sd = sd else { + let engineError = EngineError.webRTC(message: "Failed to create offer", error) + continuation.resume(throwing: engineError) + return + } + + continuation.resume(returning: sd) + } + } + } + } + func createOffer(for constraints: [String: String]? = nil) -> Promise { Promise(on: .webRTC) { complete, fail in @@ -340,6 +417,23 @@ private extension Transport { } } + func setRemoteDescription(_ sd: RTCSessionDescription) async throws -> RTCSessionDescription { + try await withUnsafeThrowingContinuation { continuation in + DispatchQueue.webRTC.async { //<< once actor isolation is established, this becomes unnecessary + self.pc.setRemoteDescription(sd) { error in + + guard error == nil else { + let engineError = EngineError.webRTC(message: "failed to set remote description", error) + continuation.resume(throwing: engineError) + return + } + + continuation.resume(returning: sd) + } + } + } + } + func setRemoteDescriptionPromise(_ sd: RTCSessionDescription) -> Promise { Promise(on: .webRTC) { complete, fail in @@ -396,6 +490,23 @@ internal extension Transport { } } + func setLocalDescription(_ sd: RTCSessionDescription) async throws -> RTCSessionDescription { + try await withUnsafeThrowingContinuation { continuation in + DispatchQueue.webRTC.async { //<< once actor isolation is established, this becomes unnecessary + self.pc.setLocalDescription(sd) { error in + + guard error == nil else { + let engineError = EngineError.webRTC(message: "failed to set local description", error) + continuation.resume(throwing: engineError) + return + } + + continuation.resume(returning: sd) + } + } + } + } + func setLocalDescription(_ sd: RTCSessionDescription) -> Promise { Promise(on: .webRTC) { complete, fail in diff --git a/Sources/LiveKit/Errors.swift b/Sources/LiveKit/Errors.swift index 6964ed465..5edc94134 100644 --- a/Sources/LiveKit/Errors.swift +++ b/Sources/LiveKit/Errors.swift @@ -131,10 +131,12 @@ public enum NetworkError: LiveKitError { public enum TransportError: LiveKitError { case timedOut(message: String? = nil) + case noPrimary(message: String? = nil) public var description: String { switch self { case .timedOut(let message): return buildDescription("timedOut", message) + case .noPrimary(let message): return buildDescription("no primary", message) } } } diff --git a/Sources/LiveKit/Participant/LocalParticipant.swift b/Sources/LiveKit/Participant/LocalParticipant.swift index e1abd26ed..9920a3c18 100644 --- a/Sources/LiveKit/Participant/LocalParticipant.swift +++ b/Sources/LiveKit/Participant/LocalParticipant.swift @@ -202,7 +202,17 @@ public class LocalParticipant: Participant { publish(track: track, publishOptions: publishOptions) } - + + public override func _unpublishAll(notify: Bool = true) async { + await withUnsafeContinuation { continuation in + _state.tracks.values.compactMap { $0 as? LocalTrackPublication } + .map { unpublish(publication: $0, notify: notify) } + .all(on: queue).then(on: queue) { + continuation.resume() + } + } + } + public override func unpublishAll(notify _notify: Bool = true) -> Promise { // build a list of promises let promises = _state.tracks.values.compactMap { $0 as? LocalTrackPublication } diff --git a/Sources/LiveKit/Participant/Participant.swift b/Sources/LiveKit/Participant/Participant.swift index 0ce98f34f..386041c20 100644 --- a/Sources/LiveKit/Participant/Participant.swift +++ b/Sources/LiveKit/Participant/Participant.swift @@ -138,6 +138,12 @@ public class Participant: NSObject, Loggable { } } } + + func _cleanUp(notify: Bool = true) async { + await self._unpublishAll(notify: notify) + // reset state + self._state.mutate { $0 = State(identity: $0.identity, name: $0.name) } + } @discardableResult internal func cleanUp(notify _notify: Bool = true) -> Promise { @@ -147,6 +153,10 @@ public class Participant: NSObject, Loggable { self._state.mutate { $0 = State(identity: $0.identity, name: $0.name) } } } + + func _unpublishAll(notify: Bool = true) async { + fatalError("implement in subclass") + } internal func unpublishAll(notify _notify: Bool = true) -> Promise { fatalError("Unimplemented") diff --git a/Sources/LiveKit/Participant/RemoteParticipant.swift b/Sources/LiveKit/Participant/RemoteParticipant.swift index b9dab849d..9a3838fc1 100644 --- a/Sources/LiveKit/Participant/RemoteParticipant.swift +++ b/Sources/LiveKit/Participant/RemoteParticipant.swift @@ -140,7 +140,17 @@ public class RemoteParticipant: Participant { } } } - + + public override func _unpublishAll(notify: Bool = true) async { + await withUnsafeContinuation { continuation in + _state.tracks.values.compactMap { $0 as? RemoteTrackPublication } + .map { unpublish(publication: $0, notify: notify) } + .all(on: queue).then(on: queue) { + continuation.resume() + } + } + } + public override func unpublishAll(notify _notify: Bool = true) -> Promise { // build a list of promises let promises = _state.tracks.values.compactMap { $0 as? RemoteTrackPublication } diff --git a/Sources/LiveKit/Support/WebSocket.swift b/Sources/LiveKit/Support/WebSocket.swift index 994804dc8..a205462ea 100644 --- a/Sources/LiveKit/Support/WebSocket.swift +++ b/Sources/LiveKit/Support/WebSocket.swift @@ -106,6 +106,10 @@ internal class WebSocket: NSObject, URLSessionWebSocketDelegate, Loggable { onDisconnect?(reason) } } + + public func send(data: Data) async throws { + try await self.task.send(URLSessionWebSocketTask.Message.data(data)) + } public func send(data: Data) -> Promise { let message = URLSessionWebSocketTask.Message.data(data)