Skip to content

Commit

Permalink
WIP: RulePath support
Browse files Browse the repository at this point in the history
  • Loading branch information
jchadwick-buf committed Oct 22, 2024
1 parent 0e335b4 commit 0f4c41c
Show file tree
Hide file tree
Showing 16 changed files with 1,393 additions and 50 deletions.
17 changes: 1 addition & 16 deletions internal/cmd/protovalidate-conformance-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import (
"os"
"strings"

"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/protovalidate-go"
"github.com/bufbuild/protovalidate-go/internal/errors"
"github.com/bufbuild/protovalidate-go/internal/gen/buf/validate/conformance/harness"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
Expand Down Expand Up @@ -115,7 +113,7 @@ func TestCase(val *protovalidate.Validator, files *protoregistry.Files, testCase
case *protovalidate.ValidationError:
return &harness.TestResult{
Result: &harness.TestResult_ValidationError{
ValidationError: violationsToProto(res.Violations),
ValidationError: res.ToProto(),
},
}
case *protovalidate.RuntimeError:
Expand All @@ -135,19 +133,6 @@ func TestCase(val *protovalidate.Validator, files *protoregistry.Files, testCase
}
}

func violationsToProto(violations []errors.Violation) *validate.Violations {
result := make([]*validate.Violation, len(violations))
for i := range violations {
result[i] = &validate.Violation{
FieldPath: &violations[i].FieldPath,
ConstraintId: &violations[i].ConstraintID,
Message: &violations[i].Message,
ForKey: &violations[i].ForKey,
}
}
return &validate.Violations{Violations: result}
}

func unexpectedErrorResult(format string, args ...any) *harness.TestResult {
return &harness.TestResult{
Result: &harness.TestResult_UnexpectedError{
Expand Down
31 changes: 22 additions & 9 deletions internal/constraints/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ func (c *Cache) Build(
fieldConstraints *validate.FieldConstraints,
extensionTypeResolver protoregistry.ExtensionTypeResolver,
allowUnknownFields bool,
forMap bool,
forItems bool,
) (set expression.ProgramSet, err error) {
constraints, done, err := c.resolveConstraints(
constraints, setOneof, done, err := c.resolveConstraints(
fieldDesc,
fieldConstraints,
forItems,
Expand All @@ -58,6 +59,17 @@ func (c *Cache) Build(
return nil, err
}

rulePath := ""
if forMap && forItems {
rulePath += "map.values."
} else if forMap && !forItems {
rulePath += "map.keys."
} else if forItems {
rulePath += "repeated.items."
}

rulePath += string(setOneof.Name()) + "."

if err = reparseUnrecognized(extensionTypeResolver, constraints); err != nil {
return nil, errors.NewCompilationErrorf("error reparsing message: %w", err)
}
Expand All @@ -71,12 +83,12 @@ func (c *Cache) Build(
}

var asts expression.ASTSet
constraints.Range(func(desc protoreflect.FieldDescriptor, rule protoreflect.Value) bool {
constraints.Range(func(desc protoreflect.FieldDescriptor, ruleValue protoreflect.Value) bool {
fieldEnv, compileErr := env.Extend(
cel.Constant(
"rule",
celext.ProtoFieldToCELType(desc, true, false),
celext.ProtoFieldToCELValue(desc, rule, false),
celext.ProtoFieldToCELValue(desc, ruleValue, false),
),
)
if compileErr != nil {
Expand All @@ -88,7 +100,8 @@ func (c *Cache) Build(
err = compileErr
return false
}
precomputedASTs.SetRuleValue(rule)
rulePath := rulePath + desc.TextName()
precomputedASTs.SetRule(rulePath, ruleValue)
asts = asts.Merge(precomputedASTs)
return true
})
Expand All @@ -109,26 +122,26 @@ func (c *Cache) resolveConstraints(
fieldDesc protoreflect.FieldDescriptor,
fieldConstraints *validate.FieldConstraints,
forItems bool,
) (rules protoreflect.Message, done bool, err error) {
) (rules protoreflect.Message, fieldRule protoreflect.FieldDescriptor, done bool, err error) {
constraints := fieldConstraints.ProtoReflect()
setOneof := constraints.WhichOneof(fieldConstraintsOneofDesc)
if setOneof == nil {
return nil, true, nil
return nil, nil, true, nil
}
expected, ok := c.getExpectedConstraintDescriptor(fieldDesc, forItems)
if ok && setOneof.FullName() != expected.FullName() {
return nil, true, errors.NewCompilationErrorf(
return nil, nil, true, errors.NewCompilationErrorf(
"expected constraint %q, got %q on field %q",
expected.FullName(),
setOneof.FullName(),
fieldDesc.FullName(),
)
}
if !ok || !constraints.Has(setOneof) {
return nil, true, nil
return nil, nil, true, nil
}
rules = constraints.Get(setOneof).Message()
return rules, false, nil
return rules, setOneof, false, nil
}

// prepareEnvironment prepares the environment for compiling standard constraint
Expand Down
2 changes: 1 addition & 1 deletion internal/constraints/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestCache_BuildStandardConstraints(t *testing.T) {
require.NoError(t, err)
c := NewCache()

set, err := c.Build(env, test.desc, test.cons, protoregistry.GlobalTypes, false, test.forItems)
set, err := c.Build(env, test.desc, test.cons, protoregistry.GlobalTypes, false, false, test.forItems)
if test.exErr {
assert.Error(t, err)
} else {
Expand Down
204 changes: 204 additions & 0 deletions internal/errors/fieldpath.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Copyright 2023-2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errors

import (
"errors"
"fmt"
"strconv"
"strings"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)

// getFieldValue returns the field value at a given path, using the provided
// registry to resolve extensions.
func getFieldValue(
registry protoregistry.ExtensionTypeResolver,
message proto.Message,
path string,
) (field protoreflect.Value, descriptor protoreflect.FieldDescriptor, err error) {
var name, subscript string
var atEnd, isExt bool
reflectMessage := message.ProtoReflect()
for !atEnd {
name, subscript, path, atEnd, isExt = parsePathElement(path)
if name == "" {
return protoreflect.Value{}, nil, errors.New("empty field name")
}
var descriptor protoreflect.FieldDescriptor
if isExt {
extension, err := registry.FindExtensionByName(protoreflect.FullName(name))
if err != nil {
return protoreflect.Value{}, nil, fmt.Errorf("resolving extension: %w", err)
}
descriptor = extension.TypeDescriptor()
} else {
descriptor = reflectMessage.Descriptor().Fields().ByTextName(name)
}
if descriptor == nil {
return protoreflect.Value{}, nil, fmt.Errorf("field %s not found", name)
}
field = reflectMessage.Get(descriptor)
if subscript != "" {
descriptor, field, err = traverseSubscript(descriptor, subscript, field, name)
if err != nil {
return protoreflect.Value{}, nil, err
}
} else if descriptor.IsList() || descriptor.IsMap() {
if atEnd {
break
}
return protoreflect.Value{}, nil, fmt.Errorf("missing subscript on field %s", name)
}
if descriptor.Message() != nil {
reflectMessage = field.Message()
}
}
return field, descriptor, nil
}

func traverseSubscript(
descriptor protoreflect.FieldDescriptor,
subscript string,
field protoreflect.Value,
name string,
) (protoreflect.FieldDescriptor, protoreflect.Value, error) {
switch {
case descriptor.IsList():
i, err := strconv.Atoi(subscript)
if err != nil {
return nil, protoreflect.Value{}, fmt.Errorf("invalid list index: %s", subscript)
}
if !field.IsValid() || i >= field.List().Len() {
return nil, protoreflect.Value{}, fmt.Errorf("index %d out of bounds of field %s", i, name)
}
field = field.List().Get(i)
case descriptor.IsMap():
key, err := parseMapKey(descriptor, subscript)
if err != nil {
return nil, protoreflect.Value{}, err
}
field = field.Map().Get(key)
if !field.IsValid() {
return nil, protoreflect.Value{}, fmt.Errorf("key %s not present on field %s", subscript, name)
}
descriptor = descriptor.MapValue()
default:
return nil, protoreflect.Value{}, fmt.Errorf("unexpected subscript on field %s", name)
}
return descriptor, field, nil
}

func parseMapKey(mapDescriptor protoreflect.FieldDescriptor, subscript string) (protoreflect.MapKey, error) {
switch mapDescriptor.MapKey().Kind() {
case protoreflect.BoolKind:
if boolValue, err := strconv.ParseBool(subscript); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfBool(boolValue)), nil
}
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if intValue, err := strconv.ParseInt(subscript, 10, 32); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfInt32(int32(intValue))), nil
}
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
if intValue, err := strconv.ParseInt(subscript, 10, 64); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfInt64(intValue)), nil
}
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
if intValue, err := strconv.ParseUint(subscript, 10, 32); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfUint32(uint32(intValue))), nil
}
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
if intValue, err := strconv.ParseUint(subscript, 10, 64); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfUint64(intValue)), nil
}
case protoreflect.StringKind:
if stringValue, err := strconv.Unquote(subscript); err == nil {
return protoreflect.MapKey(protoreflect.ValueOfString(stringValue)), nil
}
case protoreflect.EnumKind, protoreflect.FloatKind, protoreflect.DoubleKind,
protoreflect.BytesKind, protoreflect.MessageKind, protoreflect.GroupKind:
fallthrough
default:
// This should not occur, but it might if the rules are relaxed in the
// future.
return protoreflect.MapKey{}, fmt.Errorf("unsupported map key type: %s", mapDescriptor.MapKey().Kind())
}
return protoreflect.MapKey{}, fmt.Errorf("invalid map key: %s", subscript)
}

// parsePathElement parses a single
func parsePathElement(path string) (name, subscript, rest string, atEnd bool, isExt bool) {
// Scan extension name.
if len(path) > 0 && path[0] == '[' {
if i := strings.IndexByte(path, ']'); i >= 0 {
isExt = true
name, path = path[1:i], path[i+1:]
}
}
// Scan field name.
if !isExt {
if i := strings.IndexAny(path, ".["); i >= 0 {
name, path = path[:i], path[i:]
} else {
name, path = path, ""
}
}
// No subscript: At end of path.
if len(path) == 0 {
return name, "", path, true, isExt
}
// No subscript: At end of path element.
if path[0] == '.' {
return name, "", path[1:], false, isExt
}
// Malformed subscript
if len(path) == 1 || path[1] == '.' {
name, path = name+path[:1], path[1:]
return name, "", path, true, isExt
}
switch path[1] {
case ']':
// Empty subscript
name, path = name+path[:2], path[2:]
case '`', '"', '\'':
// String subscript: must scan string.
var err error
subscript, err = strconv.QuotedPrefix(path[1:])
if err == nil {
path = path[len(subscript)+2:]
}
default:
// Other subscript; can skip to next ]
if i := strings.IndexByte(path, ']'); i >= 0 {
subscript, path = path[1:i], path[i+1:]
} else {
// Unterminated subscript
return name + path, "", "", true, isExt
}
}
// No subscript: At end of path.
if len(path) == 0 {
return name, subscript, path, true, isExt
}
// No subscript: At end of path element.
if path[0] == '.' {
return name, subscript, path[1:], false, isExt
}
// Malformed element
return name, subscript, path, false, isExt
}
Loading

0 comments on commit 0f4c41c

Please sign in to comment.