Skip to content

Commit

Permalink
Refactor key-fetching to be reactive
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Dec 13, 2023
1 parent 4ce060a commit 609c901
Showing 1 changed file with 26 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -40,7 +42,6 @@
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
Expand All @@ -59,6 +60,8 @@
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
Expand Down Expand Up @@ -237,43 +240,28 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId)))))
.increment();

final List<Device> devices = parseDeviceId(deviceId, target);
final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());

final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> {
final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture =
keys.takeEC(targetIdentifier.uuid(), device.getId());

final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture =
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId());

final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());

return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
.thenAccept(ignored -> {
final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null);
final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null);

if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};

synchronized (responseItems) {
responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
pqPreKey));
}
}
});
})
.toList();

CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();
final List<PreKeyResponseItem> responseItems = Flux.fromIterable(parseDeviceId(deviceId, target))
.flatMap(device -> Mono.zip(
Mono.just(device),
Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())),
Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())),
Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.<Optional<KEMSignedPreKey>>completedFuture(Optional.empty()))
)).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent())
.map(deviceAndKeys -> {
final Device device = deviceAndKeys.getT1();
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
return new PreKeyResponseItem(device.getId(), registrationId,
deviceAndKeys.getT2().orElse(null),
deviceAndKeys.getT3().orElse(null),
deviceAndKeys.getT4().orElse(null));
}).collectList()
.timeout(Duration.ofSeconds(30))
.blockOptional()
.orElse(Collections.emptyList());

final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());

Expand Down

0 comments on commit 609c901

Please sign in to comment.