Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrent assertions #4426

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ import UIKit
/// Determines the `Stripe-Livemode` header value when the publishable key is a user key
@_spi(DashboardOnly) public var userKeyLiveMode = true

@_spi(STP) public lazy var stripeAttest: StripeAttest = StripeAttest(apiClient: self)

// MARK: Initializers
override public init() {
sourcePollers = [:]
Expand Down
85 changes: 71 additions & 14 deletions StripeCore/StripeCore/Source/Attestation/StripeAttest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,31 @@ import UIKit

/// Sign an assertion.
/// Will create and attest a new device key if needed.
@_spi(STP) public func assert() async throws -> Assertion {
/// Returns an AssertionHandle, which must be called after the network request completes (with success or failure) in order to unblock future assertions.
@_spi(STP) public func assert() async throws -> AssertionHandle {
// Make sure we only process one assertion at a time, until the latest
if assertionInProgress {
try await withCheckedThrowingContinuation { continuation in
assertionWaiters.append(continuation)
}
}
assertionInProgress = true

do {
let assertion = try await _assert()
let successAnalytic = GenericAnalytic(event: .assertionSucceeded, params: [:])
STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient)
return assertion
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient)
}
return AssertionHandle(assertion: assertion, stripeAttest: self)
} catch {
let errorAnalytic = ErrorAnalytic(event: .assertionFailed, error: error)
STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient)
if apiClient.isTestmode {
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient)
}
if apiClient?.isTestmode ?? false {
// In testmode, we can provide a test assertion even if the real assertion fails
return await testmodeAssertion()
return await AssertionHandle(assertion: testmodeAssertion(), stripeAttest: self)
} else {
throw error
}
Expand All @@ -56,7 +69,9 @@ import UIKit
/// The key will be reset.
@_spi(STP) public func receivedAssertionError(_ error: Error) {
let resetKeyAnalytic = ErrorAnalytic(event: .resetKeyForAssertionError, error: error)
STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient)
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient)
}
resetKey()
}

Expand Down Expand Up @@ -149,8 +164,8 @@ import UIKit

/// The key to use for storing an attestation key in NSUserDefaults.
func defaultsKeyForSetting(_ setting: SettingsKeys) -> String {
var key = "\(setting.rawValue):\(apiClient.publishableKey ?? "unknown")"
if let stripeAccount = apiClient.stripeAccount {
var key = "\(setting.rawValue):\(apiClient?.publishableKey ?? "unknown")"
if let stripeAccount = apiClient?.stripeAccount {
key += ":\(stripeAccount)"
}
return key
Expand All @@ -170,7 +185,7 @@ import UIKit
/// A network backend for the /challenge and /attest endpoints.
let appAttestBackend: StripeAttestBackend
/// The API client to use for network requests
var apiClient: STPAPIClient
weak var apiClient: STPAPIClient?

/// The minimum time between key generation attempts.
/// This is a safeguard against generating keys too often, as each key generation
Expand All @@ -190,20 +205,27 @@ import UIKit
let task = Task<Void, Error> {
try await _attest()
let successAnalytic = GenericAnalytic(event: .attestationSucceeded, params: [:])
STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient)
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: successAnalytic, apiClient: apiClient)
}
}
attestationTask = task
defer { attestationTask = nil } // Clear the task after it's done
do {
try await task.value
} catch {
let errorAnalytic = ErrorAnalytic(event: .attestationFailed, error: error)
STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient)
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: errorAnalytic, apiClient: apiClient)
}
throw error
}
}
private var attestationTask: Task<Void, Error>?

private var assertionInProgress: Bool = false
private var assertionWaiters: [CheckedContinuation<Void, Error>] = []

func _assert() async throws -> Assertion {
let keyId = try await self.getOrCreateKeyID()

Expand Down Expand Up @@ -273,7 +295,9 @@ import UIKit
if error.domain == DCErrorDomain && error.code == DCError.invalidKey.rawValue {
resetKey()
let resetKeyAnalytic = ErrorAnalytic(event: .resetKeyForAttestationError, error: error)
STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient)
if let apiClient {
STPAnalyticsClient.sharedClient.log(analytic: resetKeyAnalytic, apiClient: apiClient)
}
}
// For other errors, just report them as an analytic and throw. We'll want to retry attestation with the same key.
throw error
Expand All @@ -286,7 +310,7 @@ import UIKit
guard appAttestService.isSupported else {
throw AttestationError.attestationNotSupported
}
guard apiClient.publishableKey != nil else {
guard apiClient?.publishableKey != nil else {
throw AttestationError.noPublishableKey
}
if let keyId = storedKeyID {
Expand Down Expand Up @@ -365,10 +389,43 @@ import UIKit
}
}

// MARK: Assertion concurrency

// Called when an assertion handle is completed
private func assertionCompleted() {
assertionInProgress = false

// Resume the next waiter if there is one
if !assertionWaiters.isEmpty {
let nextContinuation = assertionWaiters.removeFirst()
nextContinuation.resume()
}
}

private func testmodeAssertion() async -> Assertion {
Assertion(assertionData: Data(bytes: [0x01, 0x02, 0x03], count: 3),
deviceID: (try? await getDeviceID()) ?? "test-device-id",
appID: (try? getAppID()) ?? "com.example.test",
keyID: "TestKeyID")
}
}

extension StripeAttest {
public class AssertionHandle {
public let assertion: Assertion
private weak var stripeAttest: StripeAttest?

init(assertion: Assertion, stripeAttest: StripeAttest) {
self.assertion = assertion
self.stripeAttest = stripeAttest
}

// Must be called by the caller when done with the assertion
public func complete() {
guard let stripeAttest = stripeAttest else { return }
Task {
await stripeAttest.assertionCompleted()
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ actor MockAppAttestService: AppAttestService {
}
key.counter += 1
keys[key.id] = key
// Our fake assertion is the keyID glommed onto the clientDataHash
return key.id.data(using: .utf8)! + clientDataHash
// Generate a fake assertion
let assertion = ["keyID": key.id, "counter": key.counter, "clientDataHash": clientDataHash.base64EncodedString()] as [String: Any]
return try JSONSerialization.data(withJSONObject: assertion)
}

@_spi(STP) public func attestKey(_ keyId: String, clientDataHash: Data) async throws -> Data {
Expand All @@ -96,6 +97,7 @@ actor MockAppAttestService: AppAttestService {
@_spi(STP) public class MockAttestBackend: StripeAttestBackend {
var storedChallenge: String?
var keyHasBeenAttested: [String: Bool] = [:]
var keyCounter: [String: Int] = [:]

public func attest(appId: String, deviceId: String, keyId: String, attestation: Data) async throws {
// Decode the attestation data (it's a JSON dictionary)
Expand Down Expand Up @@ -131,11 +133,20 @@ actor MockAppAttestService: AppAttestService {
print(String(data: clientDataToHash, encoding: .utf8)!)
let clientDataHash = Data(SHA256.hash(data: clientDataToHash))

// Our fake assertion is the keyID glommed onto the clientDataHash
let expectedAssertionData = assertion.keyID.data(using: .utf8)! + clientDataHash
guard expectedAssertionData == assertion.assertionData else {
// Decode assertion (it's JSON)
let assertionDict = try JSONSerialization.jsonObject(with: assertion.assertionData) as! [String: Any]
let keyID = assertionDict["keyID"] as! String
let counter = assertionDict["counter"] as! Int
let assertionClientDataHash = assertionDict["clientDataHash"] as! String
guard clientDataHash.base64EncodedString() == assertionClientDataHash else {
throw NSError(domain: "com.stripe.internal-error", code: 403, userInfo: ["error": "Assertion data does not match expected data"])
}
// Confirm counter has incremented since last assertion
guard counter > (keyCounter[keyID] ?? 0) else {
throw NSError(domain: "com.stripe.internal-error", code: 403, userInfo: ["error": "Counter has not incremented"])
}
// Update counter
keyCounter[keyID] = counter
// Clean up the challenge
storedChallenge = nil
}
Expand Down
26 changes: 15 additions & 11 deletions StripeCore/StripeCoreTests/Attestation/StripeAttestTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class StripeAttestTest: XCTestCase {
var mockAttestService: MockAppAttestService!
var mockAttestBackend: MockAttestBackend!
var stripeAttest: StripeAttest!
let apiClient = STPAPIClient(publishableKey: "pk_live_abc123")

override func setUp() {
let apiClient = STPAPIClient(publishableKey: "pk_live_abc123")
self.mockAttestBackend = MockAttestBackend()
self.mockAttestService = MockAppAttestService()
self.stripeAttest = StripeAttest(appAttestService: mockAttestService, appAttestBackend: mockAttestBackend, apiClient: apiClient)
Expand All @@ -30,13 +30,13 @@ class StripeAttestTest: XCTestCase {

func testAppAttestService() async {
try! await stripeAttest.attest()
let assertionResponse = try! await stripeAttest.assert()
try! await self.mockAttestBackend.assertionTest(assertion: assertionResponse)
let assertionHandle = try! await stripeAttest.assert()
try! await self.mockAttestBackend.assertionTest(assertion: assertionHandle.assertion)
}

func testCanAssertWithoutAttestation() async {
let assertionResponse = try! await stripeAttest.assert()
try! await self.mockAttestBackend.assertionTest(assertion: assertionResponse)
let assertionHandle = try! await stripeAttest.assert()
try! await self.mockAttestBackend.assertionTest(assertion: assertionHandle.assertion)
}

func testCanOnlyAttestOncePerDayInProd() async {
Expand Down Expand Up @@ -92,7 +92,7 @@ class StripeAttestTest: XCTestCase {
}

func testNoPublishableKey() async {
await stripeAttest.apiClient.publishableKey = nil
await stripeAttest.apiClient!.publishableKey = nil
do {
// Create and attest a key
try await stripeAttest.attest()
Expand All @@ -104,23 +104,27 @@ class StripeAttestTest: XCTestCase {

func testAssertionsNotRequiredInTestMode() async {
// Configure a test merchant PK:
await stripeAttest.apiClient.publishableKey = "pk_test_abc123"
await stripeAttest.apiClient!.publishableKey = "pk_test_abc123"
// And reset the last attestation date:
await UserDefaults.standard.removeObject(forKey: self.stripeAttest.defaultsKeyForSetting(.lastAttestedDate))
// Fail the assertion, which will cause us to try to re-attest the key, but then the
// assertions still won't work, so we'll send the testmode data instead.
let invalidKeyError = NSError(domain: DCErrorDomain, code: DCError.invalidKey.rawValue, userInfo: nil)
await mockAttestService.setShouldFailAssertionWithError(invalidKeyError)
let assertion = try! await stripeAttest.assert()
XCTAssertEqual(assertion.keyID, "TestKeyID")
let assertionHandle = try! await stripeAttest.assert()
XCTAssertEqual(assertionHandle.assertion.keyID, "TestKeyID")
}

func testConcurrentAssertionsAndAttestations() async {
func testConcurrentAssertionsOccurSequentially() async {
let iterations = 500
try! await withThrowingTaskGroup(of: Void.self) { group in
for _ in 0..<iterations {
group.addTask {
try await self.stripeAttest.assert()
let assertionHandle = try! await self.stripeAttest.assert()
// Check the assertion against the mock backend (which will enforce that the counter value has incremented since the last assertion)
try! await self.mockAttestBackend.assertionTest(assertion: assertionHandle.assertion)
// Then complete the assertion
assertionHandle.complete()
}
}
try await group.waitForAll()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,19 @@ extension STPAPIClient {
return
}

if useMobileEndpoints {
do {
let attest = StripeAttest(apiClient: self)
let assertion = try await attest.assert()
parameters = parameters.merging(assertion.requestFields) { (_, new) in new }
} catch {
// If we can't get an assertion, we'll try the request anyway. It may fail.
let requestAssertionHandle: StripeAttest.AssertionHandle? = await {
if useMobileEndpoints {
do {
let assertionHandle = try await stripeAttest.assert()
parameters = parameters.merging(assertionHandle.assertion.requestFields) { (_, new) in new }
return assertionHandle
} catch {
// If we can't get an assertion, we'll try the request anyway. It may fail.
}
}
}
return nil
}()

post(
resource: useMobileEndpoints ? mobileEndpoint : legacyEndpoint,
parameters: parameters,
Expand All @@ -61,8 +65,10 @@ extension STPAPIClient {
if useMobileEndpoints,
case .failure(let error) = result,
Self.isLinkAssertionError(error: error) {
await StripeAttest(apiClient: self).receivedAssertionError(error)
await self.stripeAttest.receivedAssertionError(error)
}
// Mark the assertion handle as completed
requestAssertionHandle?.complete()
completion(result)
}
}
Expand Down Expand Up @@ -103,15 +109,18 @@ extension STPAPIClient {
parameters["consent_action"] = consentAction
}

if useMobileEndpoints {
do {
let attest = StripeAttest(apiClient: self)
let assertion = try await attest.assert()
parameters = parameters.merging(assertion.requestFields) { (_, new) in new }
} catch {
// If we can't get an assertion, we'll try the request anyway. It may fail.
let requestAssertionHandle: StripeAttest.AssertionHandle? = await {
if useMobileEndpoints {
do {
let assertionHandle = try await stripeAttest.assert()
parameters = parameters.merging(assertionHandle.assertion.requestFields) { (_, new) in new }
return assertionHandle
} catch {
// If we can't get an assertion, we'll try the request anyway. It may fail.
}
}
}
return nil
}()

post(
resource: useMobileEndpoints ? modernEndpoint : legacyEndpoint,
Expand All @@ -122,9 +131,10 @@ extension STPAPIClient {
if useMobileEndpoints,
case .failure(let error) = result,
Self.isLinkAssertionError(error: error) {
await StripeAttest(apiClient: self).receivedAssertionError(error)
await self.stripeAttest.receivedAssertionError(error)
}

// Mark the assertion handle as completed
requestAssertionHandle?.complete()
completion(result)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func prepareNativeLink(elementsSession: STPElementsSession, configuration: Payme
return true
}

let stripeAttest = StripeAttest(apiClient: configuration.apiClient)
let stripeAttest = configuration.apiClient.stripeAttest
return await stripeAttest.prepareAttestation()
}

Expand All @@ -197,5 +197,5 @@ func deviceCanUseNativeLink(elementsSession: STPElementsSession, configuration:
return true
}

return StripeAttest(apiClient: configuration.apiClient).isSupported
return configuration.apiClient.stripeAttest.isSupported
}
Loading