diff --git a/StripeCore/StripeCore/Source/API Bindings/STPAPIClient.swift b/StripeCore/StripeCore/Source/API Bindings/STPAPIClient.swift index 559da44ae5b..41dd4841c99 100644 --- a/StripeCore/StripeCore/Source/API Bindings/STPAPIClient.swift +++ b/StripeCore/StripeCore/Source/API Bindings/STPAPIClient.swift @@ -89,6 +89,8 @@ import UIKit /// Determines the `Stripe-Livemode` header value when the publishable key is a user key @_spi(DashboardOnly) public var userKeyLiveMode = true + @_spi(STP) public lazy var stripeAttest: StripeAttest = StripeAttest(apiClient: self) + // MARK: Initializers override public init() { sourcePollers = [:] diff --git a/StripeCore/StripeCore/Source/Attestation/StripeAttest.swift b/StripeCore/StripeCore/Source/Attestation/StripeAttest.swift index da6c4a4275f..56196ac6b80 100644 --- a/StripeCore/StripeCore/Source/Attestation/StripeAttest.swift +++ b/StripeCore/StripeCore/Source/Attestation/StripeAttest.swift @@ -18,18 +18,31 @@ import UIKit /// Sign an assertion. /// Will create and attest a new device key if needed. - @_spi(STP) public func assert() async throws -> Assertion { + /// Returns an AssertionHandle, which must be called after the network request completes (with success or failure) in order to unblock future assertions. + @_spi(STP) public func assert() async throws -> AssertionHandle { + // Make sure we only process one assertion at a time, until the latest + if assertionInProgress { + try await withCheckedThrowingContinuation { continuation in + assertionWaiters.append(continuation) + } + } + assertionInProgress = true + do { let assertion = try await _assert() let successAnalytic = GenericAnalytic(event: .assertionSucceeded, params: [:]) - STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient) - return assertion + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient) + } + return AssertionHandle(assertion: assertion, stripeAttest: self) } catch { let errorAnalytic = ErrorAnalytic(event: .assertionFailed, error: error) - STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient) - if apiClient.isTestmode { + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient) + } + if apiClient?.isTestmode ?? false { // In testmode, we can provide a test assertion even if the real assertion fails - return await testmodeAssertion() + return await AssertionHandle(assertion: testmodeAssertion(), stripeAttest: self) } else { throw error } @@ -56,7 +69,9 @@ import UIKit /// The key will be reset. @_spi(STP) public func receivedAssertionError(_ error: Error) { let resetKeyAnalytic = ErrorAnalytic(event: .resetKeyForAssertionError, error: error) - STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient) + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient) + } resetKey() } @@ -149,8 +164,8 @@ import UIKit /// The key to use for storing an attestation key in NSUserDefaults. func defaultsKeyForSetting(_ setting: SettingsKeys) -> String { - var key = "\(setting.rawValue):\(apiClient.publishableKey ?? "unknown")" - if let stripeAccount = apiClient.stripeAccount { + var key = "\(setting.rawValue):\(apiClient?.publishableKey ?? "unknown")" + if let stripeAccount = apiClient?.stripeAccount { key += ":\(stripeAccount)" } return key @@ -170,7 +185,7 @@ import UIKit /// A network backend for the /challenge and /attest endpoints. let appAttestBackend: StripeAttestBackend /// The API client to use for network requests - var apiClient: STPAPIClient + weak var apiClient: STPAPIClient? /// The minimum time between key generation attempts. /// This is a safeguard against generating keys too often, as each key generation @@ -190,7 +205,9 @@ import UIKit let task = Task { try await _attest() let successAnalytic = GenericAnalytic(event: .attestationSucceeded, params: [:]) - STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient) + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient) + } } attestationTask = task defer { attestationTask = nil } // Clear the task after it's done @@ -198,12 +215,17 @@ import UIKit try await task.value } catch { let errorAnalytic = ErrorAnalytic(event: .attestationFailed, error: error) - STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient) + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient) + } throw error } } private var attestationTask: Task? + private var assertionInProgress: Bool = false + private var assertionWaiters: [CheckedContinuation] = [] + func _assert() async throws -> Assertion { let keyId = try await self.getOrCreateKeyID() @@ -273,7 +295,9 @@ import UIKit if error.domain == DCErrorDomain && error.code == DCError.invalidKey.rawValue { resetKey() let resetKeyAnalytic = ErrorAnalytic(event: .resetKeyForAttestationError, error: error) - STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient) + if let apiClient { + STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient) + } } // For other errors, just report them as an analytic and throw. We'll want to retry attestation with the same key. throw error @@ -286,7 +310,7 @@ import UIKit guard appAttestService.isSupported else { throw AttestationError.attestationNotSupported } - guard apiClient.publishableKey != nil else { + guard apiClient?.publishableKey != nil else { throw AttestationError.noPublishableKey } if let keyId = storedKeyID { @@ -365,6 +389,19 @@ import UIKit } } + // MARK: Assertion concurrency + + // Called when an assertion handle is completed + private func assertionCompleted() { + assertionInProgress = false + + // Resume the next waiter if there is one + if !assertionWaiters.isEmpty { + let nextContinuation = assertionWaiters.removeFirst() + nextContinuation.resume() + } + } + private func testmodeAssertion() async -> Assertion { Assertion(assertionData: Data(bytes: [0x01, 0x02, 0x03], count: 3), deviceID: (try? await getDeviceID()) ?? "test-device-id", @@ -372,3 +409,23 @@ import UIKit keyID: "TestKeyID") } } + +extension StripeAttest { + public class AssertionHandle { + public let assertion: Assertion + private weak var stripeAttest: StripeAttest? + + init(assertion: Assertion, stripeAttest: StripeAttest) { + self.assertion = assertion + self.stripeAttest = stripeAttest + } + + // Must be called by the caller when done with the assertion + public func complete() { + guard let stripeAttest = stripeAttest else { return } + Task { + await stripeAttest.assertionCompleted() + } + } + } +} diff --git a/StripeCore/StripeCoreTests/Attestation/MockAppAttestService.swift b/StripeCore/StripeCoreTests/Attestation/MockAppAttestService.swift index 51935353839..aade85f160b 100644 --- a/StripeCore/StripeCoreTests/Attestation/MockAppAttestService.swift +++ b/StripeCore/StripeCoreTests/Attestation/MockAppAttestService.swift @@ -68,8 +68,9 @@ actor MockAppAttestService: AppAttestService { } key.counter += 1 keys[key.id] = key - // Our fake assertion is the keyID glommed onto the clientDataHash - return key.id.data(using: .utf8)! + clientDataHash + // Generate a fake assertion + let assertion = ["keyID": key.id, "counter": key.counter, "clientDataHash": clientDataHash.base64EncodedString()] as [String: Any] + return try JSONSerialization.data(withJSONObject: assertion) } @_spi(STP) public func attestKey(_ keyId: String, clientDataHash: Data) async throws -> Data { @@ -96,6 +97,7 @@ actor MockAppAttestService: AppAttestService { @_spi(STP) public class MockAttestBackend: StripeAttestBackend { var storedChallenge: String? var keyHasBeenAttested: [String: Bool] = [:] + var keyCounter: [String: Int] = [:] public func attest(appId: String, deviceId: String, keyId: String, attestation: Data) async throws { // Decode the attestation data (it's a JSON dictionary) @@ -131,11 +133,20 @@ actor MockAppAttestService: AppAttestService { print(String(data: clientDataToHash, encoding: .utf8)!) let clientDataHash = Data(SHA256.hash(data: clientDataToHash)) - // Our fake assertion is the keyID glommed onto the clientDataHash - let expectedAssertionData = assertion.keyID.data(using: .utf8)! + clientDataHash - guard expectedAssertionData == assertion.assertionData else { + // Decode assertion (it's JSON) + let assertionDict = try JSONSerialization.jsonObject(with: assertion.assertionData) as! [String: Any] + let keyID = assertionDict["keyID"] as! String + let counter = assertionDict["counter"] as! Int + let assertionClientDataHash = assertionDict["clientDataHash"] as! String + guard clientDataHash.base64EncodedString() == assertionClientDataHash else { throw NSError(domain: "com.stripe.internal-error", code: 403, userInfo: ["error": "Assertion data does not match expected data"]) } + // Confirm counter has incremented since last assertion + guard counter > (keyCounter[keyID] ?? 0) else { + throw NSError(domain: "com.stripe.internal-error", code: 403, userInfo: ["error": "Counter has not incremented"]) + } + // Update counter + keyCounter[keyID] = counter // Clean up the challenge storedChallenge = nil } diff --git a/StripeCore/StripeCoreTests/Attestation/StripeAttestTest.swift b/StripeCore/StripeCoreTests/Attestation/StripeAttestTest.swift index 5b930fc7303..fcfe7ca79aa 100644 --- a/StripeCore/StripeCoreTests/Attestation/StripeAttestTest.swift +++ b/StripeCore/StripeCoreTests/Attestation/StripeAttestTest.swift @@ -11,9 +11,9 @@ class StripeAttestTest: XCTestCase { var mockAttestService: MockAppAttestService! var mockAttestBackend: MockAttestBackend! var stripeAttest: StripeAttest! + let apiClient = STPAPIClient(publishableKey: "pk_live_abc123") override func setUp() { - let apiClient = STPAPIClient(publishableKey: "pk_live_abc123") self.mockAttestBackend = MockAttestBackend() self.mockAttestService = MockAppAttestService() self.stripeAttest = StripeAttest(appAttestService: mockAttestService, appAttestBackend: mockAttestBackend, apiClient: apiClient) @@ -30,13 +30,13 @@ class StripeAttestTest: XCTestCase { func testAppAttestService() async { try! await stripeAttest.attest() - let assertionResponse = try! await stripeAttest.assert() - try! await self.mockAttestBackend.assertionTest(assertion: assertionResponse) + let assertionHandle = try! await stripeAttest.assert() + try! await self.mockAttestBackend.assertionTest(assertion: assertionHandle.assertion) } func testCanAssertWithoutAttestation() async { - let assertionResponse = try! await stripeAttest.assert() - try! await self.mockAttestBackend.assertionTest(assertion: assertionResponse) + let assertionHandle = try! await stripeAttest.assert() + try! await self.mockAttestBackend.assertionTest(assertion: assertionHandle.assertion) } func testCanOnlyAttestOncePerDayInProd() async { @@ -92,7 +92,7 @@ class StripeAttestTest: XCTestCase { } func testNoPublishableKey() async { - await stripeAttest.apiClient.publishableKey = nil + await stripeAttest.apiClient!.publishableKey = nil do { // Create and attest a key try await stripeAttest.attest() @@ -104,23 +104,27 @@ class StripeAttestTest: XCTestCase { func testAssertionsNotRequiredInTestMode() async { // Configure a test merchant PK: - await stripeAttest.apiClient.publishableKey = "pk_test_abc123" + await stripeAttest.apiClient!.publishableKey = "pk_test_abc123" // And reset the last attestation date: await UserDefaults.standard.removeObject(forKey: self.stripeAttest.defaultsKeyForSetting(.lastAttestedDate)) // Fail the assertion, which will cause us to try to re-attest the key, but then the // assertions still won't work, so we'll send the testmode data instead. let invalidKeyError = NSError(domain: DCErrorDomain, code: DCError.invalidKey.rawValue, userInfo: nil) await mockAttestService.setShouldFailAssertionWithError(invalidKeyError) - let assertion = try! await stripeAttest.assert() - XCTAssertEqual(assertion.keyID, "TestKeyID") + let assertionHandle = try! await stripeAttest.assert() + XCTAssertEqual(assertionHandle.assertion.keyID, "TestKeyID") } - func testConcurrentAssertionsAndAttestations() async { + func testConcurrentAssertionsOccurSequentially() async { let iterations = 500 try! await withThrowingTaskGroup(of: Void.self) { group in for _ in 0..