Skip to content

Commit

Permalink
Add devices to accounts transactionally
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal authored Dec 7, 2023
1 parent e084a9f commit 50d9226
Show file tree
Hide file tree
Showing 10 changed files with 527 additions and 275 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
Expand Down Expand Up @@ -403,60 +404,63 @@ private Pair<Account, Device> createDevice(final String password,
throw new WebApplicationException(Response.status(409).build());
}

final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());

maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey());

deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});

deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
});

final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());

final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()));

messages.clear(a.getUuid(), device.getId()).join();

deleteKeysFuture.join();

maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey())))
.join());

a.addDevice(device);
});

if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}

return new Pair<>(updatedAccount, device);
return maybeDeviceActivationRequest.map(deviceActivationRequest -> {
final String signalAgent;

if (deviceActivationRequest.apnToken().isPresent()) {
signalAgent = "OWP";
} else if (deviceActivationRequest.gcmToken().isPresent()) {
signalAgent = "OWA";
} else {
signalAgent = "OWD";
}

return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(),
password,
signalAgent,
capabilities,
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()))
.thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async()
.set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)))
.thenApply(ignored -> a))
.join();
})
.orElseGet(() -> {
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());

final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());

CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()),
messages.clear(a.getUuid(), device.getId()))
.join();

a.addDevice(device);
});

usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));

return new Pair<>(updatedAccount, device);
});
}

private static String getUsedTokenKey(final String token) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;

Expand Down Expand Up @@ -140,18 +141,24 @@ public AccountIdentityResponse register(
}

final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken());
new DeviceSpec(
registrationRequest.accountAttributes().getName(),
password,
signalAgent,
registrationRequest.accountAttributes().getCapabilities(),
registrationRequest.accountAttributes().getRegistrationId(),
registrationRequest.accountAttributes().getPhoneNumberIdentityRegistrationId(),
registrationRequest.accountAttributes().getFetchesMessages(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey()));

Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
Expand All @@ -68,6 +66,7 @@
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux;
Expand Down Expand Up @@ -132,11 +131,6 @@ public class AccountsManager {

private static final int MAX_UPDATE_ATTEMPTS = 10;

@FunctionalInterface
private interface AccountPersister {
void persistAccount(Account account) throws UsernameHashNotAvailableException;
}

public enum DeletionReason {
ADMIN_DELETED("admin"),
EXPIRED ("expired"),
Expand Down Expand Up @@ -181,46 +175,18 @@ public AccountsManager(final Accounts accounts,
this.clock = requireNonNull(clock);
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
final DeviceSpec primaryDeviceSpec) throws InterruptedException {

try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account();

accountLockManager.withLock(List.of(number), () -> {
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);

maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});

maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
final Device device = primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock);

account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));

Expand All @@ -245,10 +211,10 @@ public Account create(final String number,
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey),
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()),
(aci, pni) -> CompletableFuture.allOf(
keysManager.delete(aci),
keysManager.delete(pni),
Expand Down Expand Up @@ -299,6 +265,42 @@ public Account create(final String number,
}
}

public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec) {
return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS);
}

private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) {
return accounts.getByAccountIdentifierAsync(accountIdentifier)
.thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new))
.thenCompose(account -> {
final byte nextDeviceId = account.getNextDeviceId();
account.addDevice(deviceSpec.toDevice(nextDeviceId, clock));

final List<TransactWriteItem> additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys(
account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
nextDeviceId,
deviceSpec.aciSignedPreKey(),
deviceSpec.pniSignedPreKey(),
deviceSpec.aciPqLastResortPreKey(),
deviceSpec.pniPqLastResortPreKey());

return CompletableFuture.allOf(
keysManager.delete(account.getUuid(), nextDeviceId),
keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId),
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems))
.thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow()));
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) {
return addDevice(accountIdentifier, deviceSpec, retries - 1);
}

return CompletableFuture.failedFuture(throwable);
});
}

public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
Expand Down Expand Up @@ -705,19 +707,6 @@ private Account updateWithRetries(Account account,
final Consumer<Account> persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) {
try {
return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator);
} catch (UsernameHashNotAvailableException e) {
// not possible
throw new IllegalStateException(e);
}
}

private Account failableUpdateWithRetries(Account account,
final Function<Account, Boolean> updater,
final AccountPersister persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {

Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);

Expand All @@ -731,7 +720,7 @@ private Account failableUpdateWithRetries(Account account,
while (tries < maxTries) {

try {
persister.persistAccount(account);
persister.accept(account);

final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale();
Expand Down
Loading

0 comments on commit 50d9226

Please sign in to comment.