diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index f21829a5d..968486792 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; +import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Pair; @@ -403,60 +404,63 @@ private Pair createDevice(final String password, throw new WebApplicationException(Response.status(409).build()); } - final Device device = new Device(); - device.setName(accountAttributes.getName()); - device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); - device.setFetchesMessages(accountAttributes.getFetchesMessages()); - device.setRegistrationId(accountAttributes.getRegistrationId()); - device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId()); - device.setLastSeen(Util.todayInMillis()); - device.setCreated(System.currentTimeMillis()); - device.setCapabilities(accountAttributes.getCapabilities()); - - maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { - device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey()); - device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey()); - - deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { - device.setApnId(apnRegistrationId.apnRegistrationId()); - device.setVoipApnId(apnRegistrationId.voipRegistrationId()); - }); - - deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId -> - device.setGcmId(gcmRegistrationId.gcmRegistrationId())); - }); - - final Account updatedAccount = accounts.update(account, a -> { - device.setId(a.getNextDeviceId()); - - final CompletableFuture deleteKeysFuture = CompletableFuture.allOf( - keys.delete(a.getUuid(), device.getId()), - keys.delete(a.getPhoneNumberIdentifier(), device.getId())); - - messages.clear(a.getUuid(), device.getId()).join(); - - deleteKeysFuture.join(); - - maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( - keys.storeEcSignedPreKeys(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())), - keys.storePqLastResort(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())), - keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())), - keys.storePqLastResort(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey()))) - .join()); - - a.addDevice(device); - }); - - if (maybeAciFromToken.isPresent()) { - usedTokenCluster.useCluster(connection -> - connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); - } - - return new Pair<>(updatedAccount, device); + return maybeDeviceActivationRequest.map(deviceActivationRequest -> { + final String signalAgent; + + if (deviceActivationRequest.apnToken().isPresent()) { + signalAgent = "OWP"; + } else if (deviceActivationRequest.gcmToken().isPresent()) { + signalAgent = "OWA"; + } else { + signalAgent = "OWD"; + } + + return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(), + password, + signalAgent, + capabilities, + accountAttributes.getRegistrationId(), + accountAttributes.getPhoneNumberIdentityRegistrationId(), + accountAttributes.getFetchesMessages(), + deviceActivationRequest.apnToken(), + deviceActivationRequest.gcmToken(), + deviceActivationRequest.aciSignedPreKey(), + deviceActivationRequest.pniSignedPreKey(), + deviceActivationRequest.aciPqLastResortPreKey(), + deviceActivationRequest.pniPqLastResortPreKey())) + .thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async() + .set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))) + .thenApply(ignored -> a)) + .join(); + }) + .orElseGet(() -> { + final Device device = new Device(); + device.setName(accountAttributes.getName()); + device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); + device.setFetchesMessages(accountAttributes.getFetchesMessages()); + device.setRegistrationId(accountAttributes.getRegistrationId()); + device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId()); + device.setLastSeen(Util.todayInMillis()); + device.setCreated(System.currentTimeMillis()); + device.setCapabilities(accountAttributes.getCapabilities()); + + final Account updatedAccount = accounts.update(account, a -> { + device.setId(a.getNextDeviceId()); + + CompletableFuture.allOf( + keys.delete(a.getUuid(), device.getId()), + keys.delete(a.getPhoneNumberIdentifier(), device.getId()), + messages.clear(a.getUuid(), device.getId())) + .join(); + + a.addDevice(device); + }); + + usedTokenCluster.useCluster(connection -> + connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); + + return new Pair<>(updatedAccount, device); + }); } private static String getUsedTokenKey(final String token) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 13cd2838e..4210a2523 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; @@ -140,18 +141,24 @@ public AccountIdentityResponse register( } final Account account = accounts.create(number, - password, - signalAgent, registrationRequest.accountAttributes(), existingAccount.map(Account::getBadges).orElseGet(ArrayList::new), registrationRequest.aciIdentityKey(), registrationRequest.pniIdentityKey(), - registrationRequest.deviceActivationRequest().aciSignedPreKey(), - registrationRequest.deviceActivationRequest().pniSignedPreKey(), - registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(), - registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(), - registrationRequest.deviceActivationRequest().apnToken(), - registrationRequest.deviceActivationRequest().gcmToken()); + new DeviceSpec( + registrationRequest.accountAttributes().getName(), + password, + signalAgent, + registrationRequest.accountAttributes().getCapabilities(), + registrationRequest.accountAttributes().getRegistrationId(), + registrationRequest.accountAttributes().getPhoneNumberIdentityRegistrationId(), + registrationRequest.accountAttributes().getFetchesMessages(), + registrationRequest.deviceActivationRequest().apnToken(), + registrationRequest.deviceActivationRequest().gcmToken(), + registrationRequest.deviceActivationRequest().aciSignedPreKey(), + registrationRequest.deviceActivationRequest().pniSignedPreKey(), + registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(), + registrationRequest.deviceActivationRequest().pniPqLastResortPreKey())); Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 1df84efef..0f90809b4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -53,9 +53,7 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; -import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; -import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -68,6 +66,7 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.ParallelFlux; @@ -132,11 +131,6 @@ public class AccountsManager { private static final int MAX_UPDATE_ATTEMPTS = 10; - @FunctionalInterface - private interface AccountPersister { - void persistAccount(Account account) throws UsernameHashNotAvailableException; - } - public enum DeletionReason { ADMIN_DELETED("admin"), EXPIRED ("expired"), @@ -181,46 +175,18 @@ public AccountsManager(final Accounts accounts, this.clock = requireNonNull(clock); } - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public Account create(final String number, - final String password, - final String signalAgent, final AccountAttributes accountAttributes, final List accountBadges, final IdentityKey aciIdentityKey, final IdentityKey pniIdentityKey, - final ECSignedPreKey aciSignedPreKey, - final ECSignedPreKey pniSignedPreKey, - final KEMSignedPreKey aciPqLastResortPreKey, - final KEMSignedPreKey pniPqLastResortPreKey, - final Optional maybeApnRegistrationId, - final Optional maybeGcmRegistrationId) throws InterruptedException { + final DeviceSpec primaryDeviceSpec) throws InterruptedException { try (Timer.Context ignored = createTimer.time()) { final Account account = new Account(); accountLockManager.withLock(List.of(number), () -> { - final Device device = new Device(); - device.setId(Device.PRIMARY_ID); - device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); - device.setFetchesMessages(accountAttributes.getFetchesMessages()); - device.setRegistrationId(accountAttributes.getRegistrationId()); - device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId()); - device.setName(accountAttributes.getName()); - device.setCapabilities(accountAttributes.getCapabilities()); - device.setCreated(System.currentTimeMillis()); - device.setLastSeen(Util.todayInMillis()); - device.setUserAgent(signalAgent); - device.setSignedPreKey(aciSignedPreKey); - device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey); - - maybeApnRegistrationId.ifPresent(apnRegistrationId -> { - device.setApnId(apnRegistrationId.apnRegistrationId()); - device.setVoipApnId(apnRegistrationId.voipRegistrationId()); - }); - - maybeGcmRegistrationId.ifPresent(gcmRegistrationId -> - device.setGcmId(gcmRegistrationId.gcmRegistrationId())); + final Device device = primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock); account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number)); @@ -245,10 +211,10 @@ public Account create(final String number, a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI), a.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey), + primaryDeviceSpec.aciSignedPreKey(), + primaryDeviceSpec.pniSignedPreKey(), + primaryDeviceSpec.aciPqLastResortPreKey(), + primaryDeviceSpec.pniPqLastResortPreKey()), (aci, pni) -> CompletableFuture.allOf( keysManager.delete(aci), keysManager.delete(pni), @@ -299,6 +265,42 @@ public Account create(final String number, } } + public CompletableFuture> addDevice(final Account account, final DeviceSpec deviceSpec) { + return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS); + } + + private CompletableFuture> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) { + return accounts.getByAccountIdentifierAsync(accountIdentifier) + .thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new)) + .thenCompose(account -> { + final byte nextDeviceId = account.getNextDeviceId(); + account.addDevice(deviceSpec.toDevice(nextDeviceId, clock)); + + final List additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys( + account.getIdentifier(IdentityType.ACI), + account.getIdentifier(IdentityType.PNI), + nextDeviceId, + deviceSpec.aciSignedPreKey(), + deviceSpec.pniSignedPreKey(), + deviceSpec.aciPqLastResortPreKey(), + deviceSpec.pniPqLastResortPreKey()); + + return CompletableFuture.allOf( + keysManager.delete(account.getUuid(), nextDeviceId), + keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId), + messagesManager.clear(account.getUuid(), nextDeviceId)) + .thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems)) + .thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow())); + }) + .exceptionallyCompose(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) { + return addDevice(accountIdentifier, deviceSpec, retries - 1); + } + + return CompletableFuture.failedFuture(throwable); + }); + } + public CompletableFuture removeDevice(final Account account, final byte deviceId) { if (deviceId == Device.PRIMARY_ID) { throw new IllegalArgumentException("Cannot remove primary device"); @@ -705,19 +707,6 @@ private Account updateWithRetries(Account account, final Consumer persister, final Supplier retriever, final AccountChangeValidator changeValidator) { - try { - return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator); - } catch (UsernameHashNotAvailableException e) { - // not possible - throw new IllegalStateException(e); - } - } - - private Account failableUpdateWithRetries(Account account, - final Function updater, - final AccountPersister persister, - final Supplier retriever, - final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException { Account originalAccount = AccountUtil.cloneAccountAsNotStale(account); @@ -731,7 +720,7 @@ private Account failableUpdateWithRetries(Account account, while (tries < maxTries) { try { - persister.persistAccount(account); + persister.accept(account); final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account); account.markStale(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceSpec.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceSpec.java new file mode 100644 index 000000000..8bff97f83 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceSpec.java @@ -0,0 +1,90 @@ +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; +import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.util.Util; +import java.time.Clock; +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +public record DeviceSpec( + byte[] deviceNameCiphertext, + String password, + String signalAgent, + Device.DeviceCapabilities capabilities, + int aciRegistrationId, + int pniRegistrationId, + boolean fetchesMessages, + Optional apnRegistrationId, + Optional gcmRegistrationId, + ECSignedPreKey aciSignedPreKey, + ECSignedPreKey pniSignedPreKey, + KEMSignedPreKey aciPqLastResortPreKey, + KEMSignedPreKey pniPqLastResortPreKey) { + + public Device toDevice(final byte deviceId, final Clock clock) { + final Device device = new Device(); + device.setId(deviceId); + device.setAuthTokenHash(SaltedTokenHash.generateFor(password())); + device.setFetchesMessages(fetchesMessages()); + device.setRegistrationId(aciRegistrationId()); + device.setPhoneNumberIdentityRegistrationId(pniRegistrationId()); + device.setName(deviceNameCiphertext()); + device.setCapabilities(capabilities()); + device.setCreated(clock.millis()); + device.setLastSeen(Util.todayInMillis()); + device.setUserAgent(signalAgent()); + device.setSignedPreKey(aciSignedPreKey()); + device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey()); + + apnRegistrationId().ifPresent(apnRegistrationId -> { + device.setApnId(apnRegistrationId.apnRegistrationId()); + device.setVoipApnId(apnRegistrationId.voipRegistrationId()); + }); + + gcmRegistrationId().ifPresent(gcmRegistrationId -> + device.setGcmId(gcmRegistrationId.gcmRegistrationId())); + + return device; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + final DeviceSpec that = (DeviceSpec) o; + + return aciRegistrationId == that.aciRegistrationId + && pniRegistrationId == that.pniRegistrationId + && fetchesMessages == that.fetchesMessages + && Arrays.equals(deviceNameCiphertext, that.deviceNameCiphertext) + && Objects.equals(password, that.password) + && Objects.equals(signalAgent, that.signalAgent) + && Objects.equals(capabilities, that.capabilities) + && Objects.equals(apnRegistrationId, that.apnRegistrationId) + && Objects.equals(gcmRegistrationId, that.gcmRegistrationId) + && Objects.equals(aciSignedPreKey, that.aciSignedPreKey) + && Objects.equals(pniSignedPreKey, that.pniSignedPreKey) + && Objects.equals(aciPqLastResortPreKey, that.aciPqLastResortPreKey) + && Objects.equals(pniPqLastResortPreKey, that.pniPqLastResortPreKey); + } + + @Override + public int hashCode() { + int result = Objects.hash(password, signalAgent, capabilities, aciRegistrationId, pniRegistrationId, + fetchesMessages, apnRegistrationId, gcmRegistrationId, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, + pniPqLastResortPreKey); + result = 31 * result + Arrays.hashCode(deviceNameCiphertext); + return result; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index a2e52ddf5..559268b3d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -24,6 +24,7 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.nio.charset.StandardCharsets; import java.time.Instant; @@ -38,6 +39,7 @@ import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -72,12 +74,15 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; +import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.VerificationCode; @@ -91,6 +96,7 @@ class DeviceControllerTest { private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiter rateLimiter = mock(RateLimiter.class); private static RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + private static RedisAdvancedClusterAsyncCommands asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); private static Account account = mock(Account.class); private static Account maxedAccount = mock(Account.class); private static Device primaryDevice = mock(Device.class); @@ -106,7 +112,10 @@ class DeviceControllerTest { messagesManager, keysManager, rateLimiters, - RedisClusterHelper.builder().stringCommands(commands).build(), + RedisClusterHelper.builder() + .stringCommands(commands) + .stringAsyncCommands(asyncCommands) + .build(), deviceConfiguration, testClock); @@ -114,6 +123,7 @@ class DeviceControllerTest { public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension(); private static final ResourceExtension resources = ResourceExtension.builder() + .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) @@ -166,6 +176,7 @@ void teardown() { rateLimiters, rateLimiter, commands, + asyncCommands, account, maxedAccount, primaryDevice, @@ -300,11 +311,22 @@ void linkDeviceAtomic(final boolean fetchesMessages, when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> { + final Account a = invocation.getArgument(0); + final DeviceSpec deviceSpec = invocation.getArgument(1); + + return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock))); + }); + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); + + final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null); + final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), - new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null), + accountAttributes, new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); final DeviceResponse response = resources.getJerseyTest() @@ -315,10 +337,10 @@ void linkDeviceAtomic(final boolean fetchesMessages, assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID); - final ArgumentCaptor deviceCaptor = ArgumentCaptor.forClass(Device.class); - verify(account).addDevice(deviceCaptor.capture()); + final ArgumentCaptor deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class); + verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture()); - final Device device = deviceCaptor.getValue(); + final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock); assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI)); assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI)); @@ -333,14 +355,9 @@ void linkDeviceAtomic(final boolean fetchesMessages, expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()), () -> assertNull(device.getGcmId())); - verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID)); - verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey)); - verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey)); - verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey)); - verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey)); - verify(commands).set(anyString(), anyString(), any()); + verify(asyncCommands).set(anyString(), anyString(), any()); } - + private static Stream linkDeviceAtomic() { final String apnsToken = "apns-token"; final String apnsVoipToken = "apns-voip-token"; @@ -596,9 +613,18 @@ void linkDeviceRegistrationId(final int registrationId, final int pniRegistratio when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> { + final Account a = invocation.getArgument(0); + final DeviceSpec deviceSpec = invocation.getArgument(1); + + return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock))); + }); + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); + final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty())); @@ -719,35 +745,66 @@ void longNameTest() { verifyNoMoreInteractions(messagesManager); } - @Test - void deviceDowngradePniTest() { - DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, - false, true); - AccountAttributes accountAttributes = - new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities); + @ParameterizedTest + @MethodSource + void deviceDowngradePniTest(final boolean accountSupportsPni, final boolean deviceSupportsPni, final int expectedStatus) { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); - final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID); + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice)); - Response response = resources - .getJerseyTest() - .target("/v1/devices/" + verificationToken) - .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) - .header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30") - .put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(409); + final ECSignedPreKey aciSignedPreKey; + final ECSignedPreKey pniSignedPreKey; + final KEMSignedPreKey aciPqLastResortPreKey; + final KEMSignedPreKey pniPqLastResortPreKey; - deviceCapabilities = new DeviceCapabilities(true, true, true, true); - accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities); - response = resources - .getJerseyTest() - .target("/v1/devices/" + verificationToken) + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); + pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); + aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); + pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); + + when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + when(account.isPniSupported()).thenReturn(accountSupportsPni); + + when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> { + final Account a = invocation.getArgument(0); + final DeviceSpec deviceSpec = invocation.getArgument(1); + + return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock))); + }); + + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); + + final AccountAttributes accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, deviceSupportsPni, true)); + + final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), + accountAttributes, + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") .request() - .header("Authorization", - AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) - .header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30") - .put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(200); + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(expectedStatus, response.getStatus()); + } + } + + private static List deviceDowngradePniTest() { + return List.of( + Arguments.of(true, true, 200), + Arguments.of(true, false, 409), + Arguments.of(false, true, 200), + Arguments.of(false, false, 200)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index 388981350..85bea1a31 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -9,7 +9,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -74,6 +73,7 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @@ -167,7 +167,7 @@ void invalidRegistrationId(Optional registrationId, Optional p final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(account); final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId.orElse(0)); @@ -290,7 +290,7 @@ void recoveryPasswordManagerVerificationTrue() throws InterruptedException { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -348,7 +348,7 @@ void registrationLockAndDeviceTransfer( final Account createdAccount = mock(Account.class); when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class)); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(createdAccount); expectedStatus = 200; @@ -402,7 +402,7 @@ void deviceTransferAvailable(final boolean existingAccount, final boolean transf final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -426,7 +426,7 @@ void registrationSuccess() throws Exception { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -658,16 +658,10 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { @ParameterizedTest @MethodSource - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, final IdentityKey expectedAciIdentityKey, final IdentityKey expectedPniIdentityKey, - final ECSignedPreKey expectedAciSignedPreKey, - final ECSignedPreKey expectedPniSignedPreKey, - final KEMSignedPreKey expectedAciPqLastResortPreKey, - final KEMSignedPreKey expectedPniPqLastResortPreKey, - final Optional expectedApnRegistrationId, - final Optional expectedGcmRegistrationId) throws InterruptedException { + final DeviceSpec expectedDeviceSpec) throws InterruptedException { when(registrationServiceClient.getSession(any(), any())) .thenReturn( @@ -685,7 +679,7 @@ void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, when(a.getPrimaryDevice()).thenReturn(device); }); - when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -699,18 +693,11 @@ void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, verify(accountsManager).create( eq(NUMBER), - eq(PASSWORD), - isNull(), argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())), eq(Collections.emptyList()), eq(expectedAciIdentityKey), eq(expectedPniIdentityKey), - eq(expectedAciSignedPreKey), - eq(expectedPniSignedPreKey), - eq(expectedAciPqLastResortPreKey), - eq(expectedPniPqLastResortPreKey), - eq(expectedApnRegistrationId), - eq(expectedGcmRegistrationId)); + eq(expectedDeviceSpec)); } private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) { @@ -745,11 +732,17 @@ private static Stream atomicAccountCreationSuccess() { pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); } + final byte[] deviceName = "test".getBytes(StandardCharsets.UTF_8); + final int registrationId = 1; + final int pniRegistrationId = 2; + + final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(false, false, false, false); + final AccountAttributes fetchesMessagesAccountAttributes = - new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); + new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); final AccountAttributes pushAccountAttributes = - new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); + new AccountAttributes(false, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); final String apnsToken = "apns-token"; final String apnsVoipToken = "apns-voip-token"; @@ -771,37 +764,22 @@ private static Stream atomicAccountCreationSuccess() { Optional.empty()), aciIdentityKey, pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.empty(), - Optional.empty(), - Optional.empty()), - - // Has APNs tokens - Arguments.of(new RegistrationRequest("session-id", - new byte[0], - pushAccountAttributes, + new DeviceSpec( + deviceName, + PASSWORD, + null, + deviceCapabilities, + registrationId, + pniRegistrationId, true, - aciIdentityKey, - pniIdentityKey, + Optional.empty(), + Optional.empty(), aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), - Optional.empty()), - aciIdentityKey, - pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), - Optional.empty()), + pniPqLastResortPreKey)), - // requires the request to be atomic + // Has APNs tokens Arguments.of(new RegistrationRequest("session-id", new byte[0], pushAccountAttributes, @@ -816,14 +794,22 @@ private static Stream atomicAccountCreationSuccess() { Optional.empty()), aciIdentityKey, pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), - Optional.empty()), + new DeviceSpec( + deviceName, + PASSWORD, + null, + deviceCapabilities, + registrationId, + pniRegistrationId, + false, + Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), + Optional.empty(), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)), - // Fetches messages; no push tokens + // Has GCM token Arguments.of(new RegistrationRequest("session-id", new byte[0], pushAccountAttributes, @@ -838,12 +824,21 @@ private static Stream atomicAccountCreationSuccess() { Optional.of(new GcmRegistrationId(gcmToken))), aciIdentityKey, pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.empty(), - Optional.of(new GcmRegistrationId(gcmToken)))); + new DeviceSpec( + deviceName, + PASSWORD, + null, + deviceCapabilities, + registrationId, + pniRegistrationId, + false, + Optional.empty(), + Optional.of(new GcmRegistrationId(gcmToken)), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)) + ); } /** diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java index c581d47cc..24e321582 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java @@ -211,18 +211,24 @@ void createAccount(final DeliveryChannels deliveryChannels, : Optional.empty(); final Account account = accountsManager.create(number, - password, - signalAgent, accountAttributes, badges, new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - maybeApnRegistrationId, - maybeGcmRegistrationId); + new DeviceSpec( + deviceName, + password, + signalAgent, + deviceCapabilities, + registrationId, + pniRegistrationId, + deliveryChannels.fetchesMessages(), + maybeApnRegistrationId, + maybeGcmRegistrationId, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)); assertExpectedStoredAccount(account, number, @@ -264,18 +270,23 @@ void reregisterAccount(final DeliveryChannels deliveryChannels, final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); final Account originalAccount = accountsManager.create(number, - RandomStringUtils.randomAlphanumeric(16), - "OWI", new AccountAttributes(true, 1, 1, "name".getBytes(StandardCharsets.UTF_8), "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)), Collections.emptyList(), new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.empty(), - Optional.empty()); + new DeviceSpec(null, + "password?", + "OWI", + new Device.DeviceCapabilities(false, false, false, false), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)); existingAccountUuid = originalAccount.getUuid(); } @@ -324,18 +335,23 @@ void reregisterAccount(final DeliveryChannels deliveryChannels, : Optional.empty(); final Account reregisteredAccount = accountsManager.create(number, - password, - signalAgent, accountAttributes, badges, new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - maybeApnRegistrationId, - maybeGcmRegistrationId); + new DeviceSpec(deviceName, + password, + signalAgent, + deviceCapabilities, + registrationId, + pniRegistrationId, + accountAttributes.getFetchesMessages(), + maybeApnRegistrationId, + maybeGcmRegistrationId, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)); assertExpectedStoredAccount(reregisteredAccount, number, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 6dbc936f4..efe890848 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -87,14 +87,6 @@ void setup() throws InterruptedException { mock(DynamicConfigurationManager.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); - final KeysManager keysManager = new KeysManager( - DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), - Tables.EC_KEYS.tableName(), - Tables.PQ_KEYS.tableName(), - Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), - Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(), - dynamicConfigurationManager); - accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -157,18 +149,24 @@ void testConcurrentUpdate() throws IOException, InterruptedException { final Account account = accountsManager.update( accountsManager.create("+14155551212", - "password", - null, new AccountAttributes(), new ArrayList<>(), new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - KeysHelper.signedECPreKey(1, aciKeyPair), - KeysHelper.signedECPreKey(2, pniKeyPair), - KeysHelper.signedKEMPreKey(3, aciKeyPair), - KeysHelper.signedKEMPreKey(4, pniKeyPair), - Optional.empty(), - Optional.empty()), + new DeviceSpec( + null, + "password", + null, + new Device.DeviceCapabilities(false, false, false, false), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair))), a -> { a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); a.removeDevice(Device.PRIMARY_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 2457e8ca8..48dad12f5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -31,12 +31,12 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.i18n.phonenumbers.PhoneNumberUtil; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.io.InputStream; import java.nio.charset.StandardCharsets; -import java.time.Clock; import java.time.Duration; import java.util.ArrayList; import java.util.Base64; @@ -85,6 +85,8 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.TestClock; @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class AccountsManagerTest { @@ -109,6 +111,7 @@ class AccountsManagerTest { private RedisAdvancedClusterCommands commands; private RedisAdvancedClusterAsyncCommands asyncCommands; + private TestClock clock; private AccountsManager accountsManager; private static final Answer ACCOUNT_UPDATE_ANSWER = (answer) -> { @@ -219,6 +222,8 @@ void setup() throws InterruptedException { when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); + clock = TestClock.now(); + accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, @@ -237,7 +242,7 @@ void setup() throws InterruptedException { registrationRecoveryPasswordsManager, mock(Executor.class), clientPresenceExecutor, - mock(Clock.class)); + clock); } @Test @@ -1074,6 +1079,84 @@ void testCreateWithStorageCapability(final boolean hasStorage) throws Interrupte assertEquals(hasStorage, account.isStorageSupported()); } + @Test + void testAddDevice() { + final String phoneNumber = + PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(clock.millis()))); + final UUID aci = account.getIdentifier(IdentityType.ACI); + final UUID pni = account.getIdentifier(IdentityType.PNI); + + final byte nextDeviceId = account.getNextDeviceId(); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final byte[] deviceNameCiphertext = "device-name".getBytes(StandardCharsets.UTF_8); + final String password = "password"; + final String signalAgent = "OWT"; + final DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true); + final int aciRegistrationId = 17; + final int pniRegistrationId = 19; + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); + + when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + clock.pin(clock.instant().plusSeconds(60)); + + final Pair updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec( + deviceNameCiphertext, + password, + signalAgent, + deviceCapabilities, + aciRegistrationId, + pniRegistrationId, + true, + Optional.empty(), + Optional.empty(), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)) + .join(); + + verify(keysManager).delete(aci, nextDeviceId); + verify(keysManager).delete(pni, nextDeviceId); + verify(messagesManager).clear(aci, nextDeviceId); + + verify(keysManager).buildWriteItemsForRepeatedUseKeys( + aci, + pni, + nextDeviceId, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey); + + final Device device = updatedAccountAndDevice.second(); + + assertEquals(deviceNameCiphertext, device.getName()); + assertTrue(device.getAuthTokenHash().verify(password)); + assertEquals(signalAgent, device.getUserAgent()); + assertEquals(deviceCapabilities, device.getCapabilities()); + assertEquals(aciRegistrationId, device.getRegistrationId()); + assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId().getAsInt()); + assertTrue(device.getFetchesMessages()); + assertNull(device.getApnId()); + assertNull(device.getVoipApnId()); + assertNull(device.getGcmId()); + assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI)); + assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI)); + } + @ParameterizedTest @MethodSource void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) { @@ -1649,17 +1732,23 @@ private Account createAccount(final String e164, final AccountAttributes account final ECKeyPair pniKeyPair = Curve.generateKeyPair(); return accountsManager.create(e164, - "password", - null, accountAttributes, new ArrayList<>(), new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - KeysHelper.signedECPreKey(1, aciKeyPair), - KeysHelper.signedECPreKey(2, pniKeyPair), - KeysHelper.signedKEMPreKey(3, aciKeyPair), - KeysHelper.signedKEMPreKey(4, pniKeyPair), - Optional.empty(), - Optional.empty()); + new DeviceSpec( + accountAttributes.getName(), + "password", + null, + accountAttributes.getCapabilities(), + accountAttributes.getRegistrationId(), + accountAttributes.getPhoneNumberIdentityRegistrationId(), + accountAttributes.getFetchesMessages(), + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair))); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index be10b451d..700cfc253 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -30,6 +30,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -171,17 +172,23 @@ public static Account createAccount(final AccountsManager accountsManager, final final ECKeyPair pniKeyPair = Curve.generateKeyPair(); return accountsManager.create(e164, - "password", - null, accountAttributes, new ArrayList<>(), new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()), - KeysHelper.signedECPreKey(1, aciKeyPair), - KeysHelper.signedECPreKey(2, pniKeyPair), - KeysHelper.signedKEMPreKey(3, aciKeyPair), - KeysHelper.signedKEMPreKey(4, pniKeyPair), - Optional.empty(), - Optional.empty()); + new DeviceSpec( + accountAttributes.getName(), + "password", + "OWT", + accountAttributes.getCapabilities(), + accountAttributes.getRegistrationId(), + accountAttributes.getPhoneNumberIdentityRegistrationId(), + accountAttributes.getFetchesMessages(), + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair))); } }