Skip to content

Commit

Permalink
Create accounts transactionally
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Nov 27, 2023
1 parent 07c0400 commit c8033f8
Show file tree
Hide file tree
Showing 16 changed files with 852 additions and 263 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
KeysManager keysManager = new KeysManager(
dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getEcSignedPreKeys().getTableName(),
Expand Down Expand Up @@ -525,7 +525,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager,
accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client,
clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock);
Expand Down Expand Up @@ -669,8 +669,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
.addService(new PaymentsGrpcService(currencyManager))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
Expand Down Expand Up @@ -725,7 +725,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
turnTokenGenerator,
registrationRecoveryPasswordsManager, usernameHashZkProofVerifier));

environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager));
environment.jersey().register(new KeysController(rateLimiters, keysManager, accountsManager));

boolean registeredSpamFilter = false;
ReportSpamTokenProvider reportSpamTokenProvider = null;
Expand Down Expand Up @@ -784,7 +784,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keys, rateLimiters,
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keysManager, rateLimiters,
rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
Expand All @@ -799,7 +799,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
config.getCdnConfiguration().bucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager,
keys, rateLimiters),
rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().authorizedUsers(),
config.getRemoteConfigConfiguration().requiredHostedDomain(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
Expand All @@ -45,8 +43,6 @@
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;

Expand All @@ -69,18 +65,16 @@ public class RegistrationController {
private final AccountsManager accounts;
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final KeysManager keysManager;
private final RateLimiters rateLimiters;

public RegistrationController(final AccountsManager accounts,
final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager,
final KeysManager keysManager,
final RateLimiters rateLimiters) {

this.accounts = accounts;
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters;
}

Expand Down Expand Up @@ -141,37 +135,19 @@ public AccountIdentityResponse register(
userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType);
}

Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));

account = accounts.update(account, a -> {
a.setIdentityKey(registrationRequest.aciIdentityKey());
a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey());

final Device device = a.getPrimaryDevice().orElseThrow();

device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey());

registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});

registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));

CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey())))
.join();
});
final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken());

Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -28,6 +29,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -157,6 +159,7 @@ public Accounts(
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final String deletedAccountsTableName) {

super(client);
this.clock = clock;
this.asyncClient = asyncClient;
Expand All @@ -175,12 +178,14 @@ public Accounts(
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final String deletedAccountsTableName) {

this(Clock.systemUTC(), client, asyncClient, accountsTableName,
phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName,
deletedAccountsTableName);
}

public boolean create(final Account account) {
public boolean create(final Account account, final Function<Account, Collection<TransactWriteItem>> additionalWriteItemsFunction) {

return CREATE_TIMER.record(() -> {
try {
final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid());
Expand All @@ -199,8 +204,13 @@ public boolean create(final Account account) {
// the newly-created account.
final TransactWriteItem deletedAccountDelete = buildRemoveDeletedAccount(account.getNumber());

final Collection<TransactWriteItem> writeItems = new ArrayList<>(
List.of(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete));

writeItems.addAll(additionalWriteItemsFunction.apply(account));

final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete)
.transactItems(writeItems)
.build();

try {
Expand Down Expand Up @@ -229,7 +239,8 @@ public boolean create(final Account account) {
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier());
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account));
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account)));

return false;
}

Expand All @@ -254,7 +265,7 @@ public boolean create(final Account account) {
* @param existingAccount the existing account in the accounts table
* @param accountToCreate a new account, with the same number and identifier as existingAccount
*/
private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate) {
private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate, final Collection<TransactWriteItem> additionalWriteItems) {
if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) ||
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
throw new IllegalArgumentException("reclaimed accounts must match");
Expand Down Expand Up @@ -310,6 +321,7 @@ private CompletionStage<Void> reclaimAccount(final Account existingAccount, fina
.build());
}
writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, accountToCreate).transactItem());
writeItems.addAll(additionalWriteItems);

return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder().transactItems(writeItems).build())
.thenApply(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
Expand Down Expand Up @@ -175,17 +177,26 @@ public AccountsManager(final Accounts accounts,
this.clock = requireNonNull(clock);
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges) throws InterruptedException {
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {

try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account();

accountLockManager.withLock(List.of(number), () -> {
Device device = new Device();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
Expand All @@ -196,6 +207,16 @@ public Account create(final String number,
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);

maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});

maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));

account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));

Expand All @@ -205,6 +226,8 @@ public Account create(final String number,
// Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is
// re-registering.
account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID));
account.setIdentityKey(aciIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.addDevice(device);
account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
Expand All @@ -214,7 +237,14 @@ public Account create(final String number,

final UUID originalUuid = account.getUuid();

boolean freshUser = accounts.create(account);
final boolean freshUser = accounts.create(account,
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));

// create() sometimes updates the UUID, if there was a number conflict.
// for metrics, we want secondary to run with the same original UUID
Expand All @@ -235,9 +265,11 @@ public Account create(final String number,
// confident that everything has already been deleted. In the second case, though, we're taking over an existing
// account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) {
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as part of
// the account creation process, and we don't want to delete the keys that just got added.
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keysManager.delete(actualUuid),
keysManager.delete(account.getPhoneNumberIdentifier()));
keysManager.delete(actualUuid, true),
keysManager.delete(account.getPhoneNumberIdentifier(), true));

messagesManager.clear(actualUuid).join();
profilesManager.deleteAll(actualUuid).join();
Expand Down
Loading

0 comments on commit c8033f8

Please sign in to comment.