Skip to content

Commit

Permalink
Use all devices when checking limit
Browse files Browse the repository at this point in the history
  • Loading branch information
eager-signal committed Oct 30, 2023
1 parent 38b581a commit ba139dd
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ public VerificationCode createDeviceToken(@Auth AuthenticatedAccount auth)
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
}

if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
if (account.getDevices().size() >= maxDeviceLimit) {
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
}

if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
Expand Down Expand Up @@ -386,8 +386,8 @@ private Pair<Account, Device> createDevice(final String password,
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
}

if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
if (account.getDevices().size() >= maxDeviceLimit) {
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
}

final DeviceCapabilities capabilities = accountAttributes.getCapabilities();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
}

boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1;
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice();

// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,12 @@ public long getNextDeviceId() {
return candidateId;
}

public int getEnabledDeviceCount() {
public boolean hasEnabledLinkedDevice() {
requireNotStale();

int count = 0;

for (final Device device : devices) {
if (device.isEnabled()) count++;
}

return count;
return devices.stream()
.filter(d -> Device.PRIMARY_ID != d.getId())
.anyMatch(Device::isEnabled);
}

public void setIdentityKey(final IdentityKey identityKey) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
Expand All @@ -44,6 +45,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand Down Expand Up @@ -107,6 +109,9 @@ class DeviceControllerTest {
deviceConfiguration,
testClock);

@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();

private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
Expand Down Expand Up @@ -630,10 +635,17 @@ void oldDeviceRegisterTest() {

@Test
void maxDevicesTest() {
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();

final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
.mapToObj(i -> mock(Device.class))
.toList();
when(testAccount.account.getDevices()).thenReturn(devices);

Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.header("Authorization", testAccount.getAuthHeader())
.get();

assertEquals(411, response.getStatus());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
Expand Down Expand Up @@ -380,4 +384,49 @@ public void testAccountClassJsonFilterIdMatchesClassName() throws Exception {
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
}

@ParameterizedTest
@MethodSource
public void testHasEnabledLinkedDevice(final Account account, final boolean expect) {
assertEquals(expect, account.hasEnabledLinkedDevice());
}

static Stream<Arguments> testHasEnabledLinkedDevice() {
final Device enabledPrimary = mock(Device.class);
when(enabledPrimary.isEnabled()).thenReturn(true);
when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);

final Device disabledPrimary = mock(Device.class);
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);

final long linked1DeviceId = Device.PRIMARY_ID + 1;
final Device enabledLinked1 = mock(Device.class);
when(enabledLinked1.isEnabled()).thenReturn(true);
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);

final Device disabledLinked1 = mock(Device.class);
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);

final long linked2DeviceId = Device.PRIMARY_ID + 2;
final Device enabledLinked2 = mock(Device.class);
when(enabledLinked2.isEnabled()).thenReturn(true);
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);

final Device disabledLinked2 = mock(Device.class);
when(disabledLinked2.getId()).thenReturn(linked2DeviceId);

return Stream.of(
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary)), false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary, disabledLinked1)),
false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, disabledLinked1, disabledLinked2)), false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, enabledLinked1, disabledLinked2)), true),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, disabledLinked1, enabledLinked2)), true),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(disabledLinked2, enabledLinked1, enabledLinked2)), true)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private static Account copyAndMarkStale(Account account) throws IOException {
case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
case "isPniSupported" -> when(updatedAccount.isPniSupported()).thenAnswer(stubbing);
case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing);
case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
case "hasEnabledLinkedDevice" -> when(updatedAccount.hasEnabledLinkedDevice()).thenAnswer(stubbing);
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
case "getIdentityKey" ->
when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableMap;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import io.dropwizard.auth.AuthFilter;
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.auth.basic.BasicCredentials;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
Expand Down Expand Up @@ -90,6 +97,8 @@ public class AuthHelper {
private static SaltedTokenHash DISABLED_CREDENTIALS = mock(SaltedTokenHash.class);
private static SaltedTokenHash UNDISCOVERABLE_CREDENTIALS = mock(SaltedTokenHash.class);

private static final Collection<TestAccount> EXTENSION_TEST_ACCOUNTS = new HashSet<>();

public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter() {
when(VALID_CREDENTIALS.verify("foo")).thenReturn(true);
when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true);
Expand Down Expand Up @@ -138,7 +147,7 @@ public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter()
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));

when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6);
when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true);

when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
Expand Down Expand Up @@ -261,6 +270,11 @@ private void setup(final AccountsManager accountsManager) {
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}

private void teardown(final AccountsManager accountsManager) {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.empty());
when(accountsManager.getByE164(number)).thenReturn(Optional.empty());
}
}

private static TestAccount[] generateTestAccounts() {
Expand All @@ -272,4 +286,35 @@ private static TestAccount[] generateTestAccounts() {
}
return testAccounts;
}

/**
* JUnit 5 extension for creating {@link TestAccount}s scoped to a single test
*/
public static class AuthFilterExtension implements AfterEachCallback {

public TestAccount createTestAccount() {
final UUID uuid = UUID.randomUUID();
final String region = new ArrayList<>((PhoneNumberUtil.getInstance().getSupportedRegions())).get(
EXTENSION_TEST_ACCOUNTS.size());
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().getExampleNumber(region);

final TestAccount testAccount = new TestAccount(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164), uuid,
"extension-password-" + region);
testAccount.setup(ACCOUNTS_MANAGER);

EXTENSION_TEST_ACCOUNTS.add(testAccount);

return testAccount;
}

@Override
public void afterEach(final ExtensionContext context) {
EXTENSION_TEST_ACCOUNTS.forEach(testAccount -> {
testAccount.teardown(ACCOUNTS_MANAGER);
});

EXTENSION_TEST_ACCOUNTS.clear();
}
}
}

0 comments on commit ba139dd

Please sign in to comment.