Skip to content

Commit

Permalink
Perform cleanup operations before overwriting an existing account record
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal authored Dec 5, 2023
1 parent 331bbdd commit 5f0726a
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
.minThreads(8)
.maxThreads(8)
.build();
ExecutorService clientPresenceExecutor = environment.lifecycle()
.executorService(name(getClass(), "clientPresence-%d"))
.minThreads(8)
.maxThreads(8)
.build();
ScheduledExecutorService subscriptionProcessorRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "subscriptionProcessorRetry-%d")).threads(1).build();

Expand Down Expand Up @@ -540,7 +545,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client,
clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock);
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clientPresenceExecutor,
clock);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials().value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
Expand Down Expand Up @@ -184,7 +186,9 @@ public Accounts(
deletedAccountsTableName);
}

public boolean create(final Account account, final Function<Account, Collection<TransactWriteItem>> additionalWriteItemsFunction) {
public boolean create(final Account account,
final Function<Account, Collection<TransactWriteItem>> additionalWriteItemsFunction,
final BiFunction<UUID, UUID, CompletableFuture<Void>> existingAccountCleanupOperation) {

return CREATE_TIMER.record(() -> {
try {
Expand Down Expand Up @@ -239,7 +243,10 @@ public boolean create(final Account account, final Function<Account, Collection<
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier());
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account)));

existingAccountCleanupOperation.apply(existingAccount.getIdentifier(IdentityType.ACI), existingAccount.getIdentifier(IdentityType.PNI))
.thenCompose(ignored -> reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account)))
.join();

return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public class AccountsManager {
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final Executor accountLockExecutor;
private final Executor clientPresenceExecutor;
private final Clock clock;

private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
Expand Down Expand Up @@ -159,6 +160,7 @@ public AccountsManager(final Accounts accounts,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final Executor accountLockExecutor,
final Executor clientPresenceExecutor,
final Clock clock) {
this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers;
Expand All @@ -173,6 +175,7 @@ public AccountsManager(final Accounts accounts,
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.accountLockExecutor = accountLockExecutor;
this.clientPresenceExecutor = clientPresenceExecutor;
this.clock = requireNonNull(clock);
}

Expand Down Expand Up @@ -243,14 +246,33 @@ public Account create(final String number,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));

// create() sometimes updates the UUID, if there was a number conflict.
// for metrics, we want secondary to run with the same original UUID
final UUID actualUuid = account.getUuid();
pniPqLastResortPreKey),
(aci, pni) -> CompletableFuture.allOf(
keysManager.delete(aci),
keysManager.delete(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci)
).thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor));

if (!account.getUuid().equals(originalUuid)) {
// If the UUID changed, then we overwrote an existing account. We should have cleared all messages before
// overwriting the old account, but more may have arrived while we were working. Similarly, the old account
// holder could have added keys or profiles. We'll largely repeat the cleanup process after creating the
// account to make sure we really REALLY got everything.
//
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as
// part of the account creation process, and we don't want to delete the keys that just got added.
CompletableFuture.allOf(keysManager.delete(account.getIdentifier(IdentityType.ACI), true),
keysManager.delete(account.getIdentifier(IdentityType.PNI), true),
messagesManager.clear(account.getIdentifier(IdentityType.ACI)),
profilesManager.deleteAll(account.getIdentifier(IdentityType.ACI)))
.join();
}

redisSet(account);

final Tags tags;

// In terms of previously-existing accounts, there are three possible cases:
//
// 1. This is a completely new account; there was no pre-existing account and no recently-deleted account
Expand All @@ -259,27 +281,6 @@ public Account create(final String number,
// instance to match the stored account record (i.e. originalUuid != actualUuid).
// 3. This is a re-registration of a recently-deleted account, in which case maybeRecentlyDeletedUuid is
// present.
//
// All cases are mutually-exclusive. In the first case, we don't need to do anything. In the third, we can be
// confident that everything has already been deleted. In the second case, though, we're taking over an existing
// account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) {
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as part of
// the account creation process, and we don't want to delete the keys that just got added.
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keysManager.delete(actualUuid, true),
keysManager.delete(account.getPhoneNumberIdentifier(), true));

messagesManager.clear(actualUuid).join();
profilesManager.deleteAll(actualUuid).join();

deleteKeysFuture.join();

clientPresenceManager.disconnectAllPresencesForUuid(actualUuid);
}

final Tags tags;

if (freshUser) {
tags = Tags.of("type", maybeRecentlyDeletedAccountIdentifier.isPresent() ? "recently-deleted" : "new");
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ protected void run(Environment environment, Namespace namespace,
.executorService(name(getClass(), "storageService-%d")).maxThreads(1).minThreads(1).build();
ExecutorService accountLockExecutor = environment.lifecycle()
.executorService(name(getClass(), "accountLock-%d")).minThreads(1).maxThreads(1).build();
ExecutorService clientPresenceExecutor = environment.lifecycle()
.executorService(name(getClass(), "clientPresence-%d")).minThreads(1).maxThreads(1).build();
ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "secureValueRecoveryServiceRetry-%d")).threads(1).build();
ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle()
Expand Down Expand Up @@ -202,7 +204,8 @@ protected void run(Environment environment, Namespace namespace,
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, Clock.systemUTC());
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clientPresenceExecutor,
Clock.systemUTC());

final String usernameHash = namespace.getString("usernameHash");
final String encryptedUsername = namespace.getString("encryptedUsername");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ static CommandDependencies build(
.executorService(name(name, "storageService-%d")).maxThreads(8).minThreads(8).build();
ExecutorService accountLockExecutor = environment.lifecycle()
.executorService(name(name, "accountLock-%d")).minThreads(8).maxThreads(8).build();
ExecutorService clientPresenceExecutor = environment.lifecycle()
.executorService(name(name, "clientPresence-%d")).minThreads(8).maxThreads(8).build();

ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build();
Expand Down Expand Up @@ -177,7 +179,8 @@ static CommandDependencies build(
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock);
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clientPresenceExecutor,
clock);

environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class AccountCreationIntegrationTest {
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());

private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor;

private AccountsManager accountsManager;
private KeysManager keysManager;
Expand Down Expand Up @@ -99,6 +100,7 @@ void setUp() {
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName());

accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor();

final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());
Expand Down Expand Up @@ -139,15 +141,20 @@ void setUp() {
mock(ExperimentEnrollmentManager.class),
registrationRecoveryPasswordsManager,
accountLockExecutor,
clientPresenceExecutor,
CLOCK);
}

@AfterEach
void tearDown() throws InterruptedException {
accountLockExecutor.shutdown();
clientPresenceExecutor.shutdown();

//noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS);

//noinspection ResultOfMethodCallIgnored
clientPresenceExecutor.awaitTermination(1, TimeUnit.SECONDS);
}

@CartesianTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AccountsManagerChangeNumberIntegrationTest {

private ClientPresenceManager clientPresenceManager;
private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor;

private AccountsManager accountsManager;

Expand Down Expand Up @@ -95,6 +96,7 @@ void setup() throws InterruptedException {
Tables.DELETED_ACCOUNTS.tableName());

accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor();

final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.DELETED_ACCOUNTS_LOCK.tableName());
Expand Down Expand Up @@ -136,16 +138,21 @@ void setup() throws InterruptedException {
mock(ExperimentEnrollmentManager.class),
registrationRecoveryPasswordsManager,
accountLockExecutor,
clientPresenceExecutor,
mock(Clock.class));
}
}

@AfterEach
void tearDown() throws InterruptedException {
accountLockExecutor.shutdown();
clientPresenceExecutor.shutdown();

//noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS);

//noinspection ResultOfMethodCallIgnored
clientPresenceExecutor.awaitTermination(1, TimeUnit.SECONDS);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ void setup() throws InterruptedException {
mock(ExperimentEnrollmentManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(Executor.class),
mock(Executor.class),
mock(Clock.class)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -134,6 +135,15 @@ void setup() throws InterruptedException {
profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);

final Executor clientPresenceExecutor = mock(Executor.class);

doAnswer(invocation -> {
final Runnable runnable = invocation.getArgument(0);
runnable.run();

return null;
}).when(clientPresenceExecutor).execute(any());

//noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class);

Expand Down Expand Up @@ -224,6 +234,7 @@ void setup() throws InterruptedException {
enrollmentManager,
registrationRecoveryPasswordsManager,
mock(Executor.class),
clientPresenceExecutor,
mock(Clock.class));
}

Expand Down Expand Up @@ -856,7 +867,7 @@ void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
when(accounts.create(any(), any(), any())).thenThrow(ContestedOptimisticLockException.class);

accountsManager.update(account, a -> {
});
Expand Down Expand Up @@ -974,14 +985,14 @@ void testRemovePrimaryDevice() {

@Test
void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any(), any())).thenReturn(true);
when(accounts.create(any(), any(), any())).thenReturn(true);

final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null);

createAccount(e164, attributes);

verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any(), any());

verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
Expand All @@ -991,25 +1002,31 @@ void testCreateFreshAccount() throws InterruptedException {
void testReregisterAccount() throws InterruptedException {
final UUID existingUuid = UUID.randomUUID();

when(accounts.create(any(), any())).thenAnswer(invocation -> {
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null);

when(accounts.create(any(), any(), any())).thenAnswer(invocation -> {
invocation.getArgument(0, Account.class).setUuid(existingUuid);

final BiFunction<UUID, UUID, CompletableFuture<Void>> cleanupOperation = invocation.getArgument(2);
cleanupOperation.apply(existingUuid, phoneNumberIdentifiersByE164.get(e164));

return false;
});

final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null);

createAccount(e164, attributes);

assertTrue(phoneNumberIdentifiersByE164.containsKey(e164));

verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any());
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any(), any());

verify(keysManager).delete(existingUuid);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164));
verify(keysManager).delete(existingUuid, true);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164), true);
verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid);
verify(messagesManager, times(2)).clear(existingUuid);
verify(profilesManager, times(2)).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
}

Expand All @@ -1018,7 +1035,7 @@ void testCreateAccountRecentlyDeleted() throws InterruptedException {
final UUID recentlyDeletedUuid = UUID.randomUUID();

when(accounts.findRecentlyDeletedAccountIdentifier(anyString())).thenReturn(Optional.of(recentlyDeletedUuid));
when(accounts.create(any(), any())).thenReturn(true);
when(accounts.create(any(), any(), any())).thenReturn(true);

final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null);
Expand All @@ -1027,6 +1044,7 @@ void testCreateAccountRecentlyDeleted() throws InterruptedException {

verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())),
any(),
any());

verifyNoInteractions(keysManager);
Expand Down
Loading

0 comments on commit 5f0726a

Please sign in to comment.