Skip to content

Commit

Permalink
report exceptions in fanned-out sends of multi-recipient messages
Browse files Browse the repository at this point in the history
  • Loading branch information
jkt-signal authored Nov 20, 2023
1 parent db7f18a commit cb1fc73
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -466,25 +467,37 @@ public Response sendMultiRecipientMessage(
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));

multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.recipients())
.map(recipient -> (Callable<Void>) () -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());

// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
return null;
})
.collect(Collectors.toList()));
CompletableFuture.allOf(
Arrays.stream(multiRecipientMessage.recipients())
// If we're sending a story, some recipients might not map to existing accounts
.filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
.map(
recipient -> CompletableFuture.runAsync(
() -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());

// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
},
multiRecipientMessageExecutor))
.toArray(CompletableFuture[]::new))
.get();
} catch (InterruptedException e) {
logger.error("interrupted while delivering multi-recipient messages", e);
return Response.serverError().entity("interrupted during delivery").build();
} catch (CancellationException e) {
logger.error("cancelled while delivering multi-recipient messages", e);
return Response.serverError().entity("delivery cancelled").build();
} catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
return Response.serverError().entity("failure during delivery").build();
}
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
Expand Down Expand Up @@ -166,7 +167,7 @@ class MessageControllerTest {
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);

Expand Down Expand Up @@ -252,8 +253,7 @@ void teardown() {
rateLimiter,
cardinalityEstimator,
pushNotificationManager,
reportMessageManager,
multiRecipientMessageExecutor
reportMessageManager
);
}

Expand Down Expand Up @@ -990,19 +990,6 @@ void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean is
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);

when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});

// start building the request
Invocation.Builder bldr = resources
.getJerseyTest()
Expand Down Expand Up @@ -1110,6 +1097,32 @@ private static Stream<Arguments> testMultiRecipientMessage() {
);
}

@Test
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());

final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
new byte[48]),
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));

Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1700000000000L)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
MultiRecipientMessageProvider.MEDIA_TYPE));

checkGoodMultiRecipientResponse(response, 1);
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
Expand Down Expand Up @@ -1316,19 +1329,6 @@ private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
throws NotPushRegisteredException, InterruptedException {

when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});

final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
Expand Down Expand Up @@ -1371,14 +1371,12 @@ private static Stream<Arguments> sendMultiRecipientMessage404() {
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
verify(multiRecipientMessageExecutor, never()).invokeAll(any());
}

private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount);
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.uuids404().isEmpty());
}
Expand Down

0 comments on commit cb1fc73

Please sign in to comment.