diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index e2eefe2eb..07b56c406 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -186,9 +186,8 @@ public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAu @Context ContainerRequest containerRequest) throws RateLimitExceededException, DeviceLimitExceededException { - final Optional maybeAciFromToken = checkVerificationToken(linkDeviceRequest.verificationCode()); - - final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier) + final Account account = checkVerificationToken(linkDeviceRequest.verificationCode()) + .flatMap(accounts::getByAccountIdentifier) .orElseThrow(ForbiddenException::new); final DeviceActivationRequest deviceActivationRequest = linkDeviceRequest.deviceActivationRequest(); @@ -211,18 +210,17 @@ public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAu // active user is and what their device states look like. AuthEnablementRefreshRequirementProvider.setAccount(containerRequest, account); - int maxDeviceLimit = MAX_DEVICES; - - if (maxDeviceConfiguration.containsKey(account.getNumber())) { - maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber()); - } + final int maxDeviceLimit = maxDeviceConfiguration.getOrDefault(account.getNumber(), MAX_DEVICES); if (account.getDevices().size() >= maxDeviceLimit) { throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit); } final DeviceCapabilities capabilities = accountAttributes.getCapabilities(); - if (capabilities != null && isCapabilityDowngrade(account, capabilities)) { + + if (capabilities == null) { + throw new WebApplicationException(Response.status(422, "Missing device capabilities").build()); + } else if (isCapabilityDowngrade(account, capabilities)) { throw new WebApplicationException(Response.status(409).build()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 69cad4136..1423bf7aa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -216,7 +216,7 @@ void linkDeviceAtomic(final boolean fetchesMessages, when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); - final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null); + final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, true)); final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), accountAttributes, @@ -458,6 +458,50 @@ private static Stream linkDeviceAtomicMissingProperty() { ); } + @Test + void linkDeviceAtomicMissingCapabilities() { + final ECSignedPreKey aciSignedPreKey; + final ECSignedPreKey pniSignedPreKey; + final KEMSignedPreKey aciPqLastResortPreKey; + final KEMSignedPreKey pniPqLastResortPreKey; + + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); + pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); + aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); + pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + VerificationCode deviceCode = resources.getJerseyTest() + .target("/v1/devices/provisioning/code") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(VerificationCode.class); + + when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + + final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), + new AccountAttributes(true, 1234, 5678, null, null, true, null), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } + @ParameterizedTest @MethodSource void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey, @@ -589,7 +633,7 @@ void linkDeviceRegistrationId(final int registrationId, final int pniRegistratio when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), - new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null), + new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, new DeviceCapabilities(true, true, true, true)), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty())); try (final Response response = resources.getJerseyTest()