Skip to content

Commit

Permalink
Use destination service ID from the envelope when removing views from…
Browse files Browse the repository at this point in the history
… shared MRM data
  • Loading branch information
eager-signal committed Sep 16, 2024
1 parent 11691c3 commit 374fe08
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public void sendMessage(final Account account, final Device device, final Envelo
if (clientPresent) {
messagesManager.insert(account.getUuid(), device.getId(), message.toBuilder().setEphemeral(true).build());
} else {
messagesManager.removeRecipientViewFromMrmData(account.getUuid(), device.getId(), message);
messagesManager.removeRecipientViewFromMrmData(device.getId(), message);
}
} else {
messagesManager.insert(account.getUuid(), device.getId(), message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
Expand Down Expand Up @@ -296,21 +295,25 @@ public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid
.thenApplyAsync(serialized -> {

final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size());
final List<byte[]> sharedMrmKeysToUpdate = new ArrayList<>();
final Map<ServiceIdentifier, List<byte[]>> serviceIdentifierToMrmKeys = new HashMap<>();

for (final byte[] bytes : serialized) {
try {
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) {
sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray());
serviceIdentifierToMrmKeys.computeIfAbsent(
ServiceIdentifier.valueOf(envelope.getDestinationServiceId()), ignored -> new ArrayList<>())
.add(envelope.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}

removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice);
serviceIdentifierToMrmKeys.forEach(
(serviceId, keysToUpdate) -> removeRecipientViewFromMrmData(keysToUpdate, serviceId, destinationDevice));

return removedMessages;
}, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));

Expand Down Expand Up @@ -472,7 +475,8 @@ private Mono<?> maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessag
/**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid, final byte deviceId) {
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final ServiceIdentifier serviceIdentifier,
final byte deviceId) {

if (sharedMrmKeys.isEmpty()) {
return;
Expand All @@ -483,7 +487,7 @@ void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(
keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId),
keys -> removeRecipientViewFromMrmDataScript.execute(keys, serviceIdentifier, deviceId),
REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY)
.doOnNext(sharedMrmDataKeyRemovedCounter::increment)
.onErrorResume(e -> {
Expand Down Expand Up @@ -575,7 +579,7 @@ public CompletableFuture<Void> clear(final UUID destinationUuid, final byte devi
return Mono.empty();
}

final List<byte[]> mrmKeys = new ArrayList<>(messagesToProcess.size());
final Map<ServiceIdentifier, List<byte[]>> serviceIdentifierToMrmKeys = new HashMap<>();
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) {
try {
Expand All @@ -584,14 +588,17 @@ public CompletableFuture<Void> clear(final UUID destinationUuid, final byte devi
processedMessages.add(message.getServerGuid());

if (message.hasSharedMrmKey()) {
mrmKeys.add(message.getSharedMrmKey().toByteArray());
serviceIdentifierToMrmKeys.computeIfAbsent(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
ignored -> new ArrayList<>())
.add(message.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}

removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId);
serviceIdentifierToMrmKeys.forEach((serviceId, keysToUpdate) ->
removeRecipientViewFromMrmData(keysToUpdate, serviceId, deviceId));

return removeQueueScript.execute(destinationUuid, deviceId, processedMessages);
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
Expand All @@ -30,8 +30,9 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScript {
"lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER);
}

Mono<Long> execute(final Collection<byte[]> keysCollection, final AciServiceIdentifier serviceIdentifier,
Mono<Long> execute(final Collection<byte[]> keysCollection, final ServiceIdentifier serviceIdentifier,
final byte deviceId) {

final List<byte[]> keys = keysCollection instanceof List<byte[]>
? (List<byte[]>) keysCollection
: new ArrayList<>(keysCollection);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.observability.micrometer.Micrometer;
Expand Down Expand Up @@ -214,10 +215,10 @@ public byte[] insertSharedMultiRecipientMessagePayload(
/**
* Removes the recipient's view from shared MRM data if necessary
*/
public void removeRecipientViewFromMrmData(final UUID destinationUuid, final byte destinationDeviceId,
final Envelope message) {
public void removeRecipientViewFromMrmData(final byte destinationDeviceId, final Envelope message) {
if (message.hasSharedMrmKey()) {
messagesCache.removeRecipientViewFromMrmData(List.of(message.getSharedMrmKey().toByteArray()), destinationUuid,
messagesCache.removeRecipientViewFromMrmData(List.of(message.getSharedMrmKey().toByteArray()),
ServiceIdentifier.valueOf(message.getDestinationServiceId()),
destinationDeviceId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ void testSendOnlineMessageClientPresent() throws Exception {
MessageProtos.Envelope.class);

verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture());
verify(messagesManager, never()).removeRecipientViewFromMrmData(any(), anyByte(),
any(MessageProtos.Envelope.class));
verify(messagesManager, never()).removeRecipientViewFromMrmData(anyByte(), any(MessageProtos.Envelope.class));

assertTrue(envelopeArgumentCaptor.getValue().getEphemeral());

Expand All @@ -96,7 +95,7 @@ void testSendOnlineMessageClientNotPresent(final boolean hasSharedMrmKey) throws
}

verify(messagesManager, never()).insert(any(), anyByte(), any());
verify(messagesManager).removeRecipientViewFromMrmData(any(), anyByte(), any(MessageProtos.Envelope.class));
verify(messagesManager).removeRecipientViewFromMrmData(anyByte(), any(MessageProtos.Envelope.class));

verifyNoInteractions(pushNotificationManager);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.lettuce.core.RedisCommandExecutionException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import io.lettuce.core.RedisCommandExecutionException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;

Expand All @@ -32,7 +33,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {

@ParameterizedTest
@MethodSource
void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
void testInsert(final int count, final Map<ServiceIdentifier, List<Byte>> destinations) throws Exception {

final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
Expand All @@ -49,7 +50,7 @@ void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> des
}

public static List<Arguments> testInsert() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
final Map<ServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));

final List<Arguments> testCases = new ArrayList<>();
Expand All @@ -58,7 +59,7 @@ public static List<Arguments> testInsert() {
for (int j = 1000; j <= 30000; j += 1000) {

final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
final Map<ServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux;
Expand All @@ -34,7 +35,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {

@ParameterizedTest
@MethodSource
void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
void testUpdateSingleKey(final Map<ServiceIdentifier, List<Byte>> destinations) throws Exception {

final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
Expand All @@ -48,8 +49,8 @@ void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destination

final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet())
.flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId))))
.flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2()))
.flatMap(serviceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
serviceIdentifierByteTuple.getT1(), serviceIdentifierByteTuple.getT2()))
.reduce(Long::sum)
.block(Duration.ofSeconds(35)));

Expand All @@ -60,18 +61,18 @@ void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destination
assertEquals(0, keyExists);
}

public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
public static List<Map<ServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<ServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));

final List<Map<AciServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
final List<Map<ServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
testCases.add(singleAccount);

// Generate a more, from smallish to very large
for (int j = 1000; j <= 81000; j *= 3) {

final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
final Map<ServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
Expand All @@ -93,7 +94,7 @@ public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey()
void testUpdateManyKeys(int keyCount) throws Exception {

final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount);
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;

for (int i = 0; i < keyCount; i++) {
Expand All @@ -103,7 +104,7 @@ void testUpdateManyKeys(int keyCount) throws Exception {

final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId));
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId));

sharedMrmKeys.add(sharedMrmKey);
}
Expand All @@ -114,7 +115,7 @@ void testUpdateManyKeys(int keyCount) throws Exception {
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, serviceIdentifier, deviceId))
.reduce(Long::sum)
.block(Duration.ofSeconds(5)));

Expand Down
Loading

0 comments on commit 374fe08

Please sign in to comment.