Skip to content

Commit

Permalink
Revert "Retire the "migrate signed pre-keys" command"
Browse files Browse the repository at this point in the history
This reverts commit c7cc300.
  • Loading branch information
jon-signal committed Dec 13, 2023
1 parent 3f9edfe commit f738bc9
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@
import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand;
import org.whispersystems.textsecuregcm.workers.MigrateSignedECPreKeysCommand;
import org.whispersystems.textsecuregcm.workers.ProcessPushNotificationFeedbackCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand;
Expand Down Expand Up @@ -273,6 +274,7 @@ public void initialize(final Bootstrap<WhisperServerConfiguration> bootstrap) {
bootstrap.addCommand(new UnlinkDeviceCommand());
bootstrap.addCommand(new ScheduledApnPushNotificationSenderServiceCommand());
bootstrap.addCommand(new MessagePersisterServiceCommand());
bootstrap.addCommand(new MigrateSignedECPreKeysCommand());
bootstrap.addCommand(new RemoveExpiredAccountsCommand(Clock.systemUTC()));
bootstrap.addCommand(new ProcessPushNotificationFeedbackCommand(Clock.systemUTC()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final
}
}

public CompletableFuture<Boolean> storeEcSignedPreKeyIfAbsent(final UUID identifier, final byte deviceId,
final ECSignedPreKey signedPreKey) {
return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey);
}

public CompletableFuture<Void> storePqLastResort(final UUID identifier, final Map<Byte, KEMSignedPreKey> keys) {
return pqLastResortKeys.store(identifier, keys);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,27 @@

import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;

public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore<ECSignedPreKey> {

private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;

public RepeatedUseECSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);

this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
}

@Override
Expand All @@ -43,4 +53,21 @@ protected ECSignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> ite
throw new IllegalArgumentException(e);
}
}

public CompletableFuture<Boolean> storeIfAbsent(final UUID identifier, final byte deviceId, final ECSignedPreKey signedPreKey) {
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))
.conditionExpression("attribute_not_exists(#public_key)")
.expressionAttributeNames(Map.of("#public_key", ATTR_PUBLIC_KEY))
.build())
.thenApply(ignored -> true)
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) {
return false;
}

throw ExceptionUtils.wrap(throwable);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.workers;

import io.micrometer.core.instrument.Metrics;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple3;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;

public class MigrateSignedECPreKeysCommand extends AbstractSinglePassCrawlAccountsCommand {

private static final String STORE_KEY_ATTEMPT_COUNTER_NAME =
MetricsUtil.name(MigrateSignedECPreKeysCommand.class, "storeKeyAttempt");

// It's tricky to find, but the default connection count for the AWS SDK's async DynamoDB client is 50. As long as
// we stay below that, we should be fine.
private static final int DEFAULT_MAX_CONCURRENCY = 32;

private static final String BUFFER_ARGUMENT = "buffer";
private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";

private static final Logger logger = LoggerFactory.getLogger(MigrateSignedECPreKeysCommand.class);

public MigrateSignedECPreKeysCommand() {
super("migrate-signed-ec-pre-keys", "Migrate signed EC pre-keys from Account records to a dedicated table");
}

@Override
public void configure(final Subparser subparser) {
super.configure(subparser);

subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");

subparser.addArgument("--buffer")
.type(Integer.class)
.dest(BUFFER_ARGUMENT)
.setDefault(16_384)
.help("Devices to buffer");
}

@Override
protected void crawlAccounts(final Flux<Account> accounts) {
final KeysManager keysManager = getCommandDependencies().keysManager();
final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT);
final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT);

accounts
.flatMap(account -> Flux.fromIterable(account.getDevices())
.flatMap(device -> Flux.fromArray(IdentityType.values())
.filter(identityType -> device.getSignedPreKey(identityType) != null)
.map(identityType -> Tuples.of(account.getIdentifier(identityType), device.getId(), device.getSignedPreKey(identityType)))))
.buffer(bufferSize)
.map(source -> {
final List<Tuple3<UUID, Byte, ECSignedPreKey>> shuffled = new ArrayList<>(source);
Collections.shuffle(shuffled);
return shuffled;
})
.flatMapIterable(Function.identity())
.flatMap(keyTuple -> {
final UUID identifier = keyTuple.getT1();
final byte deviceId = keyTuple.getT2();
final ECSignedPreKey signedPreKey = keyTuple.getT3();

return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeyIfAbsent(identifier, deviceId, signedPreKey))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).onRetryExhaustedThrow((spec, rs) -> rs.failure()))
.onErrorResume(throwable -> {
logger.warn("Failed to migrate key for UUID {}, device {}", identifier, deviceId);
return Mono.just(false);
})
.doOnSuccess(keyStored -> Metrics.counter(STORE_KEY_ATTEMPT_COUNTER_NAME, "stored", String.valueOf(keyStored)).increment());
}, maxConcurrency)
.then()
.block();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

package org.whispersystems.textsecuregcm.storage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
Expand Down Expand Up @@ -39,4 +46,21 @@ protected RepeatedUseSignedPreKeyStore<ECSignedPreKey> getKeyStore() {
protected ECSignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR);
}

@Test
void storeIfAbsent() {
final UUID identifier = UUID.randomUUID();
final byte deviceIdWithExistingKey = 1;
final byte deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1;

final ECSignedPreKey originalSignedPreKey = generateSignedPreKey();

keyStore.store(identifier, deviceIdWithExistingKey, originalSignedPreKey).join();

assertFalse(keyStore.storeIfAbsent(identifier, deviceIdWithExistingKey, generateSignedPreKey()).join());
assertTrue(keyStore.storeIfAbsent(identifier, deviceIdWithoutExistingKey, generateSignedPreKey()).join());

assertEquals(Optional.of(originalSignedPreKey), keyStore.find(identifier, deviceIdWithExistingKey).join());
assertTrue(keyStore.find(identifier, deviceIdWithoutExistingKey).join().isPresent());
}
}

0 comments on commit f738bc9

Please sign in to comment.