Skip to content

Commit

Permalink
Update signed pre-keys in transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal authored Dec 5, 2023
1 parent ede9297 commit df421e0
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 254 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
Expand Down Expand Up @@ -67,6 +69,8 @@ public class KeysController {
private final AccountsManager accounts;
private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys");

private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];

public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
Expand Down Expand Up @@ -110,24 +114,51 @@ public CompletableFuture<Response> setKeys(@Auth final DisabledPermittedAuthenti
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {

Account account = disabledPermittedAuth.getAccount();
final Account account = disabledPermittedAuth.getAccount();
final Device device = disabledPermittedAuth.getAuthenticatedDevice();
final UUID identifier = account.getIdentifier(identityType);

checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType));

final CompletableFuture<Account> updateAccountFuture;

if (setKeysRequest.signedPreKey() != null &&
!setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) {

account = accounts.update(account, a -> a.getDevice(device.getId()).ifPresent(d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
}));
updateAccountFuture = accounts.updateDeviceTransactionallyAsync(account,
device.getId(),
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
},
d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), setKeysRequest.signedPreKey())
.map(List::of)
.orElseGet(Collections::emptyList))
.toCompletableFuture();
} else {
updateAccountFuture = CompletableFuture.completedFuture(account);
}

return keys.store(account.getIdentifier(identityType), device.getId(),
setKeysRequest.preKeys(), setKeysRequest.pqPreKeys(), setKeysRequest.signedPreKey(), setKeysRequest.pqLastResortPreKey())
return updateAccountFuture.thenCompose(updatedAccount -> {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(3);

if (setKeysRequest.preKeys() != null) {
storeFutures.add(keys.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}

if (setKeysRequest.pqPreKeys() != null) {
storeFutures.add(keys.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}

if (setKeysRequest.pqLastResortPreKey() != null) {
storeFutures.add(
keys.storePqLastResort(identifier, Map.of(device.getId(), setKeysRequest.pqLastResortPreKey())));
}

return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY));
})
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}

Expand Down Expand Up @@ -265,17 +296,21 @@ public CompletableFuture<Response> setSignedKey(@Auth final AuthenticatedAccount
@Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {

Device device = auth.getAuthenticatedDevice();

accounts.updateDevice(auth.getAccount(), device.getId(), d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
});

return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType),
Map.of(device.getId(), signedPreKey))
final UUID identifier = auth.getAccount().getIdentifier(identityType);
final byte deviceId = auth.getAuthenticatedDevice().getId();

return accounts.updateDeviceTransactionallyAsync(auth.getAccount(),
deviceId,
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
},
d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), signedPreKey)
.map(List::of)
.orElseGet(Collections::emptyList))
.toCompletableFuture()
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername"));
private static final Timer CLEAR_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "clearUsernameHash"));
private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update"));
private static final Timer UPDATE_TRANSACTIONALLY_TIMER = Metrics.timer(name(Accounts.class, "updateTransactionally"));
private static final Timer RECLAIM_TIMER = Metrics.timer(name(Accounts.class, "reclaim"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber"));
private static final Timer GET_BY_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameHash"));
Expand Down Expand Up @@ -277,6 +278,7 @@ private CompletionStage<Void> reclaimAccount(final Account existingAccount, fina
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
throw new IllegalArgumentException("reclaimed accounts must match");
}

return AsyncTimerUtil.record(RECLAIM_TIMER, () -> {

accountToCreate.setVersion(existingAccount.getVersion());
Expand Down Expand Up @@ -364,7 +366,8 @@ private CompletionStage<Void> reclaimAccount(final Account existingAccount, fina
public void changeNumber(final Account account,
final String number,
final UUID phoneNumberIdentifier,
final Optional<UUID> maybeDisplacedAccountIdentifier) {
final Optional<UUID> maybeDisplacedAccountIdentifier,
final Collection<TransactWriteItem> additionalWriteItems) {

CHANGE_NUMBER_TIMER.record(() -> {
final String originalNumber = account.getNumber();
Expand Down Expand Up @@ -413,6 +416,8 @@ public void changeNumber(final Account account,
.build())
.build());

writeItems.addAll(additionalWriteItems);

final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(writeItems)
.build();
Expand Down Expand Up @@ -863,6 +868,35 @@ public void update(final Account account) throws ContestedOptimisticLockExceptio
joinAndUnwrapUpdateFuture(updateAsync(account));
}

public CompletionStage<Void> updateTransactionallyAsync(final Account account,
final Collection<TransactWriteItem> additionalWriteItems) {

return AsyncTimerUtil.record(UPDATE_TRANSACTIONALLY_TIMER, () -> {
final List<TransactWriteItem> writeItems = new ArrayList<>(additionalWriteItems.size() + 1);
writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, account).transactItem());
writeItems.addAll(additionalWriteItems);

return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(writeItems)
.build())
.thenApply(response -> {
account.setVersion(account.getVersion() + 1);
return (Void) null;
})
.exceptionally(throwable -> {
final Throwable unwrapped = ExceptionUtils.unwrap(throwable);

if (unwrapped instanceof TransactionCanceledException transactionCanceledException) {
if ("ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().get(0).code())) {
throw new ContestedOptimisticLockException();
}
}

throw CompletableFutureUtils.errorAsCompletionException(throwable);
});
});
}

public CompletableFuture<Boolean> usernameHashAvailable(final byte[] username) {
return usernameHashAvailable(Optional.empty(), username);
}
Expand Down
Loading

0 comments on commit df421e0

Please sign in to comment.