Skip to content

Commit

Permalink
Resolve correct CEL types consistently for all field types (#80)
Browse files Browse the repository at this point in the history
The way we were resolving the CEL types for type-checking expressions
was inconsistent between custom and standard constraints. Resulting in
compilation errors (particularly around repeated fields). Logic shared
between the two (mostly lookups) was moved into the `expression`
internal package and used uniformly for both environments.

This also improves on a previously discovered bug around the Any WKT
where custom expressions against such a field would fail with a runtime
error if its underlying type was not known to CEL (CEL treats Any's as
the underlying type, instead of the Any message itself). The standard
constraints on Any do not have this limitation. We are populating the
root CEL environment with `protoregistry.GlobalFiles` for now, but will
likely make this configurable in the long-run.

Context:
bufbuild/protovalidate#92 (comment)
(h/t @matthewpi)
  • Loading branch information
rodaine authored Nov 17, 2023
1 parent 0fc6610 commit 5202cdc
Show file tree
Hide file tree
Showing 11 changed files with 424 additions and 170 deletions.
6 changes: 6 additions & 0 deletions celext/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/ext"
"google.golang.org/protobuf/reflect/protoregistry"
)

// DefaultEnv produces a cel.Env with the necessary cel.EnvOption and
Expand All @@ -37,6 +38,11 @@ import (
// of the local timezone.
func DefaultEnv(useUTC bool) (*cel.Env, error) {
return cel.NewEnv(
// we bind in the global type registry optimistically to ensure expressions
// operating against Any WKTs can resolve their underlying type if it's
// known to the application. They will otherwise fail with a runtime error
// if the type is unknown.
cel.TypeDescs(protoregistry.GlobalFiles),
cel.Lib(lib{
useUTC: useUTC,
}),
Expand Down
30 changes: 1 addition & 29 deletions internal/constraints/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *Cache) prepareEnvironment(
) (*cel.Env, error) {
env, err := env.Extend(
cel.Types(rules.Interface()),
cel.Variable("this", c.getCELType(fieldDesc, forItems)),
cel.Variable("this", expression.ProtoFieldToCELType(fieldDesc, true, forItems)),
cel.Variable("rules",
cel.ObjectType(string(rules.Descriptor().FullName()))),
)
Expand Down Expand Up @@ -168,31 +168,3 @@ func (c *Cache) getExpectedConstraintDescriptor(
return expected, ok
}
}

// getCELType resolves the CEL value type for the provided FieldDescriptor. If
// forItems is true, the type for the repeated list items is returned instead of
// the list type itself.
func (c *Cache) getCELType(fieldDesc protoreflect.FieldDescriptor, forItems bool) *cel.Type {
if !forItems {
switch {
case fieldDesc.IsMap():
return cel.MapType(cel.DynType, cel.DynType)
case fieldDesc.IsList():
return cel.ListType(cel.DynType)
}
}

if fieldDesc.Kind() == protoreflect.MessageKind {
switch fqn := fieldDesc.Message().FullName(); fqn {
case "google.protobuf.Any":
return cel.AnyType
case "google.protobuf.Duration":
return cel.DurationType
case "google.protobuf.Timestamp":
return cel.TimestampType
default:
return cel.ObjectType(string(fqn))
}
}
return ProtoKindToCELType(fieldDesc.Kind())
}
55 changes: 0 additions & 55 deletions internal/constraints/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/protovalidate-go/celext"
"github.com/bufbuild/protovalidate-go/internal/gen/buf/validate/conformance/cases"
"github.com/google/cel-go/cel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -197,57 +196,3 @@ func TestCache_GetExpectedConstraintDescriptor(t *testing.T) {
})
}
}

func TestCache_GetCELType(t *testing.T) {
t.Parallel()

tests := []struct {
desc protoreflect.FieldDescriptor
forItems bool
ex *cel.Type
}{
{
desc: getFieldDesc(t, &cases.MapNone{}, "val"),
ex: cel.MapType(cel.DynType, cel.DynType),
},
{
desc: getFieldDesc(t, &cases.RepeatedNone{}, "val"),
ex: cel.ListType(cel.DynType),
},
{
desc: getFieldDesc(t, &cases.RepeatedNone{}, "val"),
forItems: true,
ex: cel.IntType,
},
{
desc: getFieldDesc(t, &cases.AnyNone{}, "val"),
ex: cel.AnyType,
},
{
desc: getFieldDesc(t, &cases.DurationNone{}, "val"),
ex: cel.DurationType,
},
{
desc: getFieldDesc(t, &cases.TimestampNone{}, "val"),
ex: cel.TimestampType,
},
{
desc: getFieldDesc(t, &cases.MessageNone{}, "val"),
ex: cel.ObjectType(string(((&cases.MessageNone{}).GetVal()).ProtoReflect().Descriptor().FullName())),
},
{
desc: getFieldDesc(t, &cases.Int32None{}, "val"),
ex: cel.IntType,
},
}

c := NewCache()
for _, tc := range tests {
test := tc
t.Run(string(test.desc.FullName()), func(t *testing.T) {
t.Parallel()
typ := c.getCELType(test.desc, test.forItems)
assert.Equal(t, test.ex.String(), typ.String())
})
}
}
38 changes: 0 additions & 38 deletions internal/constraints/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package constraints

import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/google/cel-go/cel"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -91,40 +90,3 @@ func ExpectedWrapperConstraints(fqn protoreflect.FullName) (desc protoreflect.Fi
return nil, false
}
}

// ProtoKindToCELType maps a protoreflect.Kind to a compatible cel.Type.
func ProtoKindToCELType(kind protoreflect.Kind) *cel.Type {
switch kind {
case
protoreflect.FloatKind,
protoreflect.DoubleKind:
return cel.DoubleType
case
protoreflect.Int32Kind,
protoreflect.Int64Kind,
protoreflect.Sint32Kind,
protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind,
protoreflect.Sfixed64Kind,
protoreflect.EnumKind:
return cel.IntType
case
protoreflect.Uint32Kind,
protoreflect.Uint64Kind,
protoreflect.Fixed32Kind,
protoreflect.Fixed64Kind:
return cel.UintType
case protoreflect.BoolKind:
return cel.BoolType
case protoreflect.StringKind:
return cel.StringType
case protoreflect.BytesKind:
return cel.BytesType
case
protoreflect.MessageKind,
protoreflect.GroupKind:
return cel.DynType
default:
return cel.DynType
}
}
3 changes: 2 additions & 1 deletion internal/constraints/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package constraints
import (
"testing"

"github.com/bufbuild/protovalidate-go/internal/expression"
"github.com/google/cel-go/cel"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -83,7 +84,7 @@ func TestProtoKindToCELType(t *testing.T) {
kind, typ := k, ty
t.Run(kind.String(), func(t *testing.T) {
t.Parallel()
assert.Equal(t, typ, ProtoKindToCELType(kind))
assert.Equal(t, typ, expression.ProtoKindToCELType(kind))
})
}
}
18 changes: 6 additions & 12 deletions internal/evaluator/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,18 +277,12 @@ func (bldr *Builder) processFieldExpressions(
if len(exprs) == 0 {
return nil
}
var opts []cel.EnvOption
if fieldDesc.Kind() == protoreflect.MessageKind {
opts = []cel.EnvOption{
cel.Types(dynamicpb.NewMessage(fieldDesc.ContainingMessage())),
cel.Types(dynamicpb.NewMessage(fieldDesc.Message())),
cel.Variable("this", cel.ObjectType(string(fieldDesc.Message().FullName()))),
}
} else {
opts = []cel.EnvOption{
cel.Variable("this", constraints.ProtoKindToCELType(fieldDesc.Kind())),
}
}

celTyp := expression.ProtoFieldToCELType(fieldDesc, false, false)
opts := append(
expression.RequiredCELEnvOptions(fieldDesc),
cel.Variable("this", celTyp),
)
compiledExpressions, err := expression.Compile(exprs, bldr.env, opts...)
if err != nil {
return err
Expand Down
113 changes: 113 additions & 0 deletions internal/expression/lookups.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"github.com/google/cel-go/cel"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
)

// ProtoKindToCELType maps a protoreflect.Kind to a compatible cel.Type.
func ProtoKindToCELType(kind protoreflect.Kind) *cel.Type {
switch kind {
case
protoreflect.FloatKind,
protoreflect.DoubleKind:
return cel.DoubleType
case
protoreflect.Int32Kind,
protoreflect.Int64Kind,
protoreflect.Sint32Kind,
protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind,
protoreflect.Sfixed64Kind,
protoreflect.EnumKind:
return cel.IntType
case
protoreflect.Uint32Kind,
protoreflect.Uint64Kind,
protoreflect.Fixed32Kind,
protoreflect.Fixed64Kind:
return cel.UintType
case protoreflect.BoolKind:
return cel.BoolType
case protoreflect.StringKind:
return cel.StringType
case protoreflect.BytesKind:
return cel.BytesType
case
protoreflect.MessageKind,
protoreflect.GroupKind:
return cel.DynType
default:
return cel.DynType
}
}

// ProtoFieldToCELType resolves the CEL value type for the provided
// FieldDescriptor. If generic is true, the specific subtypes of map and
// repeated fields will be replaced with cel.DynType. If forItems is true, the
// type for the repeated list items is returned instead of the list type itself.
func ProtoFieldToCELType(fieldDesc protoreflect.FieldDescriptor, generic, forItems bool) *cel.Type {
if !forItems {
switch {
case fieldDesc.IsMap():
if generic {
return cel.MapType(cel.DynType, cel.DynType)
}
keyType := ProtoFieldToCELType(fieldDesc.MapKey(), false, true)
valType := ProtoFieldToCELType(fieldDesc.MapValue(), false, true)
return cel.MapType(keyType, valType)
case fieldDesc.IsList():
if generic {
return cel.ListType(cel.DynType)
}
itemType := ProtoFieldToCELType(fieldDesc, false, true)
return cel.ListType(itemType)
}
}

if fieldDesc.Kind() == protoreflect.MessageKind {
switch fqn := fieldDesc.Message().FullName(); fqn {
case "google.protobuf.Any":
return cel.AnyType
case "google.protobuf.Duration":
return cel.DurationType
case "google.protobuf.Timestamp":
return cel.TimestampType
default:
return cel.ObjectType(string(fqn))
}
}
return ProtoKindToCELType(fieldDesc.Kind())
}

// RequiredCELEnvOptions returns the options required to have expressions which
// rely on the provided descriptor.
func RequiredCELEnvOptions(fieldDesc protoreflect.FieldDescriptor) []cel.EnvOption {
if fieldDesc.IsMap() {
return append(
RequiredCELEnvOptions(fieldDesc.MapKey()),
RequiredCELEnvOptions(fieldDesc.MapValue())...,
)
}
if fieldDesc.Kind() == protoreflect.MessageKind {
return []cel.EnvOption{
cel.Types(dynamicpb.NewMessage(fieldDesc.Message())),
}
}
return nil
}
Loading

0 comments on commit 5202cdc

Please sign in to comment.