Skip to content

Commit

Permalink
multisend cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jkt-signal authored Nov 30, 2023
1 parent 22e6584 commit c03249b
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 344 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;

import java.security.MessageDigest;
import java.time.Duration;
import java.util.ArrayList;
Expand All @@ -37,6 +43,8 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -87,6 +95,7 @@
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
Expand All @@ -111,13 +120,21 @@
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/messages")
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
public class MessageController {

private record MessageRecipient(
ServiceIdentifier serviceIdentifier,
Account account,
Map<Byte, Recipient> perDeviceData) {
}

private static final Logger logger = LoggerFactory.getLogger(MessageController.class);

private final RateLimiters rateLimiters;
Expand All @@ -138,9 +155,9 @@ public class MessageController {
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
private static final String RATE_LIMITED_STORIES_COUNTER_NAME = name(MessageController.class, "rateLimitedStory");

private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage");

private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String SENDER_TYPE_TAG_NAME = "senderType";
Expand Down Expand Up @@ -343,26 +360,25 @@ public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,


/**
* Build mapping of accounts to devices/registration IDs.
* Build mapping of service IDs to resolved accounts and device/registration IDs
*/
private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage,
Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {

return Arrays.stream(multiRecipientMessage.recipients())
// for normal messages, all recipients UUIDs are in the map,
// but story messages might specify inactive UUIDs, which we
// have previously filtered
.filter(r -> accountsByServiceIdentifier.containsKey(r.uuid()))
.collect(Collectors.toMap(
recipient -> accountsByServiceIdentifier.get(recipient.uuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
private Map<ServiceIdentifier, MessageRecipient> buildRecipientMap(
MultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromArray(multiRecipientMessage.recipients())
.groupBy(Recipient::uuid)
.flatMap(
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key()))
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.flatMap(
account ->
gf.collectMap(Recipient::deviceId)
.map(perRecipientData ->
new MessageRecipient(
gf.key(),
account,
perRecipientData))))
.collectMap(MessageRecipient::serviceIdentifier)
.block();
}

@Timed
Expand All @@ -371,79 +387,87 @@ private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
@Produces(MediaType.APPLICATION_JSON)
@FilterSpam
@Operation(
summary = "Send multi-recipient sealed-sender message",
description = """
Deliver a common-payload message to multiple recipients.
An unidentifed-access key for all recipients must be provided, unless the message is a story.
""")
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect")
@ApiResponse(
responseCode="404",
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for some recipients",
content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))

public Response sendMultiRecipientMessage(
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message")
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,

@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,

@Parameter(description="If true, deliver the message only to recipients that are online when it is sent")
@QueryParam("online") boolean online,

@Parameter(description="The sender's timestamp for the envelope")
@QueryParam("ts") long timestamp,

@Parameter(description="If true, this message should cause push notifications to be sent to recipients")
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,

@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
@QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload")
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {

final Map<ServiceIdentifier, Account> accountsByServiceIdentifier = new HashMap<>();

for (final Recipient recipient : multiRecipientMessage.recipients()) {
if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) {
final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid());

if (maybeAccount.isPresent()) {
accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get());
} else {
if (!isStory) {
throw new NotFoundException();
}
}
}
}
final Map<ServiceIdentifier, MessageRecipient> recipients = buildRecipientMap(multiRecipientMessage, isStory);

// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
checkAccessKeys(accessKeys, recipients.values());
}

final Map<Account, Set<Pair<Byte, Integer>>> accountToDeviceIdAndRegistrationIdMap =
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);

// We might filter out all the recipients of a story (if none have enabled stories).
// We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now.
if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build();
}

Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();

for (Map.Entry<ServiceIdentifier, Account> entry : accountsByServiceIdentifier.entrySet()) {
final ServiceIdentifier serviceIdentifier = entry.getKey();
final Account account = entry.getValue();

if (isStory) {
rateLimiters.getStoriesLimiter().validate(account.getUuid());
if (isStory) {
if (recipients.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
}

Set<Byte> deviceIds = accountToDeviceIdAndRegistrationIdMap
.getOrDefault(account, Collections.emptySet())
.stream()
.map(Pair::first)
.collect(Collectors.toSet());

try {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());

// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity
DestinationDeviceValidator.validateRegistrationIds(
account,
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
false);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices())));
for (MessageRecipient recipient : recipients.values()) {
rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid());
}
}

Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account();

try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet());

DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.perDeviceData().values(),
Recipient::deviceId,
Recipient::registrationId,
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
Expand All @@ -468,25 +492,28 @@ public Response sendMultiRecipientMessage(
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));

CompletableFuture.allOf(
Arrays.stream(multiRecipientMessage.recipients())
// If we're sending a story, some recipients might not map to existing accounts
.filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
.map(
recipient -> CompletableFuture.runAsync(
() -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());

// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
},
multiRecipientMessageExecutor))
recipients.values().stream()
.flatMap(recipientData ->
recipientData.perDeviceData().values().stream().map(
recipient -> CompletableFuture.runAsync(
() -> {
final Account destinationAccount = recipientData.account();
// we asserted this must exist in validateCompleteDeviceList
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
try {
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
// this should never happen, because we already asserted the device is present and enabled
Metrics.counter(
UNEXPECTED_MISSING_USER_COUNTER_NAME,
Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment();
uuids404.add(recipientData.serviceIdentifier());
}
},
multiRecipientMessageExecutor)))
.toArray(CompletableFuture[]::new))
.get();
} catch (InterruptedException e) {
Expand All @@ -502,43 +529,31 @@ public Response sendMultiRecipientMessage(
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
}

private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<Account> destinationAccounts) {
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MessageRecipient> destinations) {
// We should not have null access keys when checking access; bail out early.
if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
byte[] combinedUnknownAccessKeys = destinationAccounts.stream()
.map(account -> {
if (account.isUnrestrictedUnidentifiedAccess()) {
return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY;
} else {
return account.getUnidentifiedAccessKey();
}
})
.map(accessKey -> {
if (accessKey.isEmpty()) {
throwUnauthorized.set(true);
return empty;
}
return accessKey.get();
})
.reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> {
if (bytes.length != bytes2.length) {
throwUnauthorized.set(true);
return bytes;
}
for (int i = 0; i < bytes.length; i++) {
bytes[i] ^= bytes2[i];
}
return bytes;
});
if (throwUnauthorized.get()
|| !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
destinations.stream()
.map(MessageRecipient::account)
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account -> account.getUnidentifiedAccessKey().orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(
(bytes, bytes2) -> {
if (bytes.length != bytes2.length) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
for (int i = 0; i < bytes.length; i++) {
bytes[i] ^= bytes2[i];
}
return bytes;
})
.ifPresent(
combinedUnidentifiedAccessKeys -> {
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
});
}

@Timed
Expand Down Expand Up @@ -716,6 +731,7 @@ private void sendIndividualMessage(

private void sendCommonPayloadMessage(Account destinationAccount,
Device destinationDevice,
ServiceIdentifier serviceIdentifier,
long timestamp,
boolean online,
boolean story,
Expand All @@ -739,7 +755,7 @@ private void sendCommonPayloadMessage(Account destinationAccount,
.setContent(ByteString.copyFrom(payload))
.setStory(story)
.setUrgent(urgent)
.setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString());
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString());

messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} catch (NotPushRegisteredException e) {
Expand Down
Loading

0 comments on commit c03249b

Please sign in to comment.