Skip to content

Commit

Permalink
gRPC validations
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-signal committed Nov 3, 2023
1 parent 115431a commit db63ff6
Show file tree
Hide file tree
Showing 14 changed files with 1,289 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.grpc;

import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;

import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.StatusException;
import java.util.Map;
import org.whispersystems.textsecuregcm.grpc.validators.E164FieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.EnumSpecifiedFieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.ExactlySizeFieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.FieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.NonEmptyFieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.RangeFieldValidator;
import org.whispersystems.textsecuregcm.grpc.validators.SizeFieldValidator;

public class ValidatingInterceptor implements ServerInterceptor {

private final Map<String, FieldValidator> fieldValidators = Map.of(
"org.signal.chat.require.nonEmpty", new NonEmptyFieldValidator(),
"org.signal.chat.require.specified", new EnumSpecifiedFieldValidator(),
"org.signal.chat.require.e164", new E164FieldValidator(),
"org.signal.chat.require.exactlySize", new ExactlySizeFieldValidator(),
"org.signal.chat.require.range", new RangeFieldValidator(),
"org.signal.chat.require.size", new SizeFieldValidator()
);

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {

// The way `UnaryServerCallHandler` (which is what we're wrapping here) is implemented
// is when `onMessage()` is called, the processing of the message doesn't immediately start
// and instead is delayed until `onHalfClose()` (which is the point when client says
// that no more messages will be sent). Then, in `onHalfClose()` it either tries to process
// the message if it's there, or reports an error if the message is not there.
// This means that the logic is not designed for the case of the call being closed by the interceptor.
// The only workaround is to not delegate calls to it in the case when we're closing the call
// because of the validation error.
private boolean forwardCalls = true;

@Override
public void onMessage(final ReqT message) {
try {
validateMessage(message);
super.onMessage(message);
} catch (final StatusException e) {
call.close(e.getStatus(), new Metadata());
forwardCalls = false;
}
}

@Override
public void onHalfClose() {
if (forwardCalls) {
super.onHalfClose();
}
}
};
}

private void validateMessage(final Object message) throws StatusException {
if (message instanceof GeneratedMessageV3 msg) {
try {
for (final Descriptors.FieldDescriptor fd: msg.getDescriptorForType().getFields()) {
for (final Map.Entry<Descriptors.FieldDescriptor, Object> entry: fd.getOptions().getAllFields().entrySet()) {
final Descriptors.FieldDescriptor extensionFieldDescriptor = entry.getKey();
final String extensionName = extensionFieldDescriptor.getFullName();
final FieldValidator validator = fieldValidators.get(extensionName);
// not all extensions are validators, so `validator` value here could legitimately be `null`
if (validator != null) {
validator.validate(entry.getValue(), fd, msg);
}
}
}
} catch (final StatusException e) {
throw e;
} catch (final Exception e) {
throw internalError(e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.grpc.validators;

import static java.util.Objects.requireNonNull;
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.internalError;
import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.invalidArgument;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import io.grpc.Status;
import io.grpc.StatusException;
import java.util.Set;

public abstract class BaseFieldValidator<T> implements FieldValidator {

private final String extensionName;

private final Set<Descriptors.FieldDescriptor.Type> supportedTypes;

private final MissingOptionalAction missingOptionalAction;

private final boolean applicableToRepeated;

protected enum MissingOptionalAction {
FAIL,
SUCCEED,
VALIDATE_DEFAULT_VALUE
}


protected BaseFieldValidator(
final String extensionName,
final Set<Descriptors.FieldDescriptor.Type> supportedTypes,
final MissingOptionalAction missingOptionalAction,
final boolean applicableToRepeated) {
this.extensionName = requireNonNull(extensionName);
this.supportedTypes = requireNonNull(supportedTypes);
this.missingOptionalAction = missingOptionalAction;
this.applicableToRepeated = applicableToRepeated;
}

@Override
public void validate(
final Object extensionValue,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
try {
final T extensionValueTyped = resolveExtensionValue(extensionValue);

// for the fields with an `optional` modifier, checking if the field was set
// and if not, checking if extension allows missing optional field
if (fd.hasPresence() && !msg.hasField(fd)) {
switch (missingOptionalAction) {
case FAIL -> {
throw invalidArgument("extension requires a value to be set");
}
case SUCCEED -> {
return;
}
case VALIDATE_DEFAULT_VALUE -> {
// just continuing
}
}
}

// for the `repeated` fields, checking if it's supported by the extension
if (fd.isRepeated()) {
if (applicableToRepeated) {
validateRepeatedField(extensionValueTyped, fd, msg);
return;
}
throw internalError("can't apply extension to a `repeated` field");
}

// checking field type against the set of supported types
final Descriptors.FieldDescriptor.Type type = fd.getType();
if (!supportedTypes.contains(type)) {
throw internalError("can't apply extension to a field of type [%s]".formatted(type));
}
switch (type) {
case INT64, UINT64, INT32, FIXED64, FIXED32, UINT32, SFIXED32, SFIXED64, SINT32, SINT64 ->
validateIntegerNumber(extensionValueTyped, ((Number) msg.getField(fd)).longValue(), type);
case STRING ->
validateStringValue(extensionValueTyped, (String) msg.getField(fd));
case BYTES ->
validateBytesValue(extensionValueTyped, (ByteString) msg.getField(fd));
case ENUM ->
validateEnumValue(extensionValueTyped, (Descriptors.EnumValueDescriptor) msg.getField(fd));
case FLOAT, DOUBLE, BOOL, MESSAGE, GROUP -> {
// at this moment, there are no validations specific to these types of fields
}
}
} catch (StatusException e) {
throw new StatusException(e.getStatus().withDescription(
"field [%s], extension [%s]: %s".formatted(fd.getName(), extensionName, e.getStatus().getDescription())
), e.getTrailers());
} catch (RuntimeException e) {
throw Status.INTERNAL
.withDescription("field [%s], extension [%s]: %s".formatted(fd.getName(), extensionName, e.getMessage()))
.withCause(e)
.asException();
}
}

protected abstract T resolveExtensionValue(final Object extensionValue) throws StatusException;

protected void validateRepeatedField(
final T extensionValue,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
throw internalError("`validateRepeatedField` method needs to be implemented");
}

protected void validateIntegerNumber(
final T extensionValue,
final long fieldValue, final Descriptors.FieldDescriptor.Type type) throws StatusException {
throw internalError("`validateIntegerNumber` method needs to be implemented");
}

protected void validateStringValue(
final T extensionValue,
final String fieldValue) throws StatusException {
throw internalError("`validateStringValue` method needs to be implemented");
}

protected void validateBytesValue(
final T extensionValue,
final ByteString fieldValue) throws StatusException {
throw internalError("`validateBytesValue` method needs to be implemented");
}

protected void validateEnumValue(
final T extensionValue,
final Descriptors.EnumValueDescriptor enumValueDescriptor) throws StatusException {
throw internalError("`validateEnumValue` method needs to be implemented");
}

protected static boolean requireFlagExtension(final Object extensionValue) throws StatusException {
if (extensionValue instanceof Boolean flagIsOn && flagIsOn) {
return true;
}
throw internalError("only value `true` is allowed");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.grpc.validators;

import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.invalidArgument;

import com.google.protobuf.Descriptors;
import io.grpc.StatusException;
import java.util.Set;
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
import org.whispersystems.textsecuregcm.util.Util;

public class E164FieldValidator extends BaseFieldValidator<Boolean> {

public E164FieldValidator() {
super("e164", Set.of(Descriptors.FieldDescriptor.Type.STRING), MissingOptionalAction.SUCCEED, false);
}

@Override
protected Boolean resolveExtensionValue(final Object extensionValue) throws StatusException {
return requireFlagExtension(extensionValue);
}

@Override
protected void validateStringValue(
final Boolean extensionValue,
final String fieldValue) throws StatusException {
try {
Util.requireNormalizedNumber(fieldValue);
} catch (final ImpossiblePhoneNumberException | NonNormalizedPhoneNumberException e) {
throw invalidArgument("value is not in E164 format");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.grpc.validators;

import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.invalidArgument;

import com.google.protobuf.Descriptors;
import io.grpc.StatusException;
import java.util.Set;

public class EnumSpecifiedFieldValidator extends BaseFieldValidator<Boolean> {

public EnumSpecifiedFieldValidator() {
super("specified", Set.of(Descriptors.FieldDescriptor.Type.ENUM), MissingOptionalAction.FAIL, false);
}

@Override
protected Boolean resolveExtensionValue(final Object extensionValue) throws StatusException {
return requireFlagExtension(extensionValue);
}

@Override
protected void validateEnumValue(
final Boolean extensionValue,
final Descriptors.EnumValueDescriptor enumValueDescriptor) throws StatusException {
if (enumValueDescriptor.getIndex() <= 0) {
throw invalidArgument("enum field must be specified");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.grpc.validators;

import static org.whispersystems.textsecuregcm.grpc.validators.ValidatorUtils.invalidArgument;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.GeneratedMessageV3;
import io.grpc.StatusException;
import java.util.List;
import java.util.Set;

public class ExactlySizeFieldValidator extends BaseFieldValidator<Set<Integer>> {

public ExactlySizeFieldValidator() {
super("exactlySize", Set.of(
Descriptors.FieldDescriptor.Type.STRING,
Descriptors.FieldDescriptor.Type.BYTES
), MissingOptionalAction.VALIDATE_DEFAULT_VALUE, true);
}

@Override
protected Set<Integer> resolveExtensionValue(final Object extensionValue) throws StatusException {
//noinspection unchecked
return Set.copyOf((List<Integer>) extensionValue);
}

@Override
protected void validateBytesValue(
final Set<Integer> permittedSizes,
final ByteString fieldValue) throws StatusException {
if (permittedSizes.contains(fieldValue.size())) {
return;
}
throw invalidArgument("byte arrray length is [%d] but expected to be one of %s".formatted(fieldValue.size(), permittedSizes));
}

@Override
protected void validateStringValue(
final Set<Integer> permittedSizes,
final String fieldValue) throws StatusException {
if (permittedSizes.contains(fieldValue.length())) {
return;
}
throw invalidArgument("string length is [%d] but expected to be one of %s".formatted(fieldValue.length(), permittedSizes));
}

@Override
protected void validateRepeatedField(
final Set<Integer> permittedSizes,
final Descriptors.FieldDescriptor fd,
final GeneratedMessageV3 msg) throws StatusException {
final int size = msg.getRepeatedFieldCount(fd);
if (permittedSizes.contains(size)) {
return;
}
throw invalidArgument("list size is [%d] but expected to be one of %s".formatted(size, permittedSizes));
}
}
Loading

0 comments on commit db63ff6

Please sign in to comment.