From 041aa8639aaa800d101dca22bf640766f529eb72 Mon Sep 17 00:00:00 2001 From: Katherine Date: Thu, 16 Nov 2023 12:36:43 -0500 Subject: [PATCH] Enforce story ratelimit --- .../controllers/MessageController.java | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 951f8de3b..ff485db80 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -132,7 +132,6 @@ public class MessageController { private final ReportSpamTokenProvider reportSpamTokenProvider; private final ClientReleaseManager clientReleaseManager; private final DynamicConfigurationManager dynamicConfigurationManager; - private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); @@ -279,7 +278,7 @@ public Response sendMessage(@Auth Optional source, } if (isStory) { - checkStoryRateLimit(destination.get(), userAgent); + rateLimiters.getStoriesLimiter().validate(destination.get().getUuid()); } final Set excludedDeviceIds; @@ -378,7 +377,7 @@ public Response sendMultiRecipientMessage( @QueryParam("ts") long timestamp, @QueryParam("urgent") @DefaultValue("true") final boolean isUrgent, @QueryParam("story") boolean isStory, - @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { + @NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { final Map accountsByServiceIdentifier = new HashMap<>(); @@ -412,17 +411,20 @@ public Response sendMultiRecipientMessage( Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); - accountsByServiceIdentifier.forEach((serviceIdentifier, account) -> { + + for (Map.Entry entry : accountsByServiceIdentifier.entrySet()) { + final ServiceIdentifier serviceIdentifier = entry.getKey(); + final Account account = entry.getValue(); if (isStory) { - checkStoryRateLimit(account, userAgent); + rateLimiters.getStoriesLimiter().validate(account.getUuid()); } Set deviceIds = accountToDeviceIdAndRegistrationIdMap - .getOrDefault(account, Collections.emptySet()) - .stream() - .map(Pair::first) - .collect(Collectors.toSet()); + .getOrDefault(account, Collections.emptySet()) + .stream() + .map(Pair::first) + .collect(Collectors.toSet()); try { DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); @@ -439,7 +441,8 @@ public Response sendMultiRecipientMessage( } catch (StaleDevicesException e) { accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices()))); } - }); + } + if (!accountMismatchedDevices.isEmpty()) { return Response .status(409) @@ -735,14 +738,6 @@ private void sendCommonPayloadMessage(Account destinationAccount, } } - private void checkStoryRateLimit(Account destination, String userAgent) { - try { - rateLimiters.getStoriesLimiter().validate(destination.getUuid()); - } catch (final RateLimitExceededException e) { - Metrics.counter(RATE_LIMITED_STORIES_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).increment(); - } - } - private void checkMessageRateLimit(AuthenticatedAccount source, Account destination, String userAgent) throws RateLimitExceededException { final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());