Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve correct CEL types consistently for all field types #80

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions celext/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/ext"
"google.golang.org/protobuf/reflect/protoregistry"
)

// DefaultEnv produces a cel.Env with the necessary cel.EnvOption and
Expand All @@ -37,6 +38,11 @@ import (
// of the local timezone.
func DefaultEnv(useUTC bool) (*cel.Env, error) {
return cel.NewEnv(
// we bind in the global type registry optimistically to ensure expressions
// operating against Any WKTs can resolve their underlying type if it's
// known to the application. They will otherwise fail with a runtime error
// if the type is unknown.
cel.TypeDescs(protoregistry.GlobalFiles),
cel.Lib(lib{
useUTC: useUTC,
}),
Expand Down
30 changes: 1 addition & 29 deletions internal/constraints/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *Cache) prepareEnvironment(
) (*cel.Env, error) {
env, err := env.Extend(
cel.Types(rules.Interface()),
cel.Variable("this", c.getCELType(fieldDesc, forItems)),
cel.Variable("this", expression.ProtoFieldToCELType(fieldDesc, true, forItems)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own knowledge, why is generic set to true here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are compiling the standard constraint expressions the first time we see them, and we use a concrete field it's attached to as the basis for type information. repeated and map fields are generic over their elements, so when we build (say) the unique expression for RepeatedRules, we cannot assume that the element type of the list is always going to be that of the field we're using since another repeated field of a different type might also specify unique.

cel.Variable("rules",
cel.ObjectType(string(rules.Descriptor().FullName()))),
)
Expand Down Expand Up @@ -168,31 +168,3 @@ func (c *Cache) getExpectedConstraintDescriptor(
return expected, ok
}
}

// getCELType resolves the CEL value type for the provided FieldDescriptor. If
// forItems is true, the type for the repeated list items is returned instead of
// the list type itself.
func (c *Cache) getCELType(fieldDesc protoreflect.FieldDescriptor, forItems bool) *cel.Type {
if !forItems {
switch {
case fieldDesc.IsMap():
return cel.MapType(cel.DynType, cel.DynType)
case fieldDesc.IsList():
return cel.ListType(cel.DynType)
}
}

if fieldDesc.Kind() == protoreflect.MessageKind {
switch fqn := fieldDesc.Message().FullName(); fqn {
case "google.protobuf.Any":
return cel.AnyType
case "google.protobuf.Duration":
return cel.DurationType
case "google.protobuf.Timestamp":
return cel.TimestampType
default:
return cel.ObjectType(string(fqn))
}
}
return ProtoKindToCELType(fieldDesc.Kind())
}
55 changes: 0 additions & 55 deletions internal/constraints/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/bufbuild/protovalidate-go/celext"
"github.com/bufbuild/protovalidate-go/internal/gen/buf/validate/conformance/cases"
"github.com/google/cel-go/cel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -197,57 +196,3 @@ func TestCache_GetExpectedConstraintDescriptor(t *testing.T) {
})
}
}

func TestCache_GetCELType(t *testing.T) {
t.Parallel()

tests := []struct {
desc protoreflect.FieldDescriptor
forItems bool
ex *cel.Type
}{
{
desc: getFieldDesc(t, &cases.MapNone{}, "val"),
ex: cel.MapType(cel.DynType, cel.DynType),
},
{
desc: getFieldDesc(t, &cases.RepeatedNone{}, "val"),
ex: cel.ListType(cel.DynType),
},
{
desc: getFieldDesc(t, &cases.RepeatedNone{}, "val"),
forItems: true,
ex: cel.IntType,
},
{
desc: getFieldDesc(t, &cases.AnyNone{}, "val"),
ex: cel.AnyType,
},
{
desc: getFieldDesc(t, &cases.DurationNone{}, "val"),
ex: cel.DurationType,
},
{
desc: getFieldDesc(t, &cases.TimestampNone{}, "val"),
ex: cel.TimestampType,
},
{
desc: getFieldDesc(t, &cases.MessageNone{}, "val"),
ex: cel.ObjectType(string(((&cases.MessageNone{}).GetVal()).ProtoReflect().Descriptor().FullName())),
},
{
desc: getFieldDesc(t, &cases.Int32None{}, "val"),
ex: cel.IntType,
},
}

c := NewCache()
for _, tc := range tests {
test := tc
t.Run(string(test.desc.FullName()), func(t *testing.T) {
t.Parallel()
typ := c.getCELType(test.desc, test.forItems)
assert.Equal(t, test.ex.String(), typ.String())
})
}
}
38 changes: 0 additions & 38 deletions internal/constraints/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package constraints

import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/google/cel-go/cel"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -91,40 +90,3 @@ func ExpectedWrapperConstraints(fqn protoreflect.FullName) (desc protoreflect.Fi
return nil, false
}
}

// ProtoKindToCELType maps a protoreflect.Kind to a compatible cel.Type.
func ProtoKindToCELType(kind protoreflect.Kind) *cel.Type {
switch kind {
case
protoreflect.FloatKind,
protoreflect.DoubleKind:
return cel.DoubleType
case
protoreflect.Int32Kind,
protoreflect.Int64Kind,
protoreflect.Sint32Kind,
protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind,
protoreflect.Sfixed64Kind,
protoreflect.EnumKind:
return cel.IntType
case
protoreflect.Uint32Kind,
protoreflect.Uint64Kind,
protoreflect.Fixed32Kind,
protoreflect.Fixed64Kind:
return cel.UintType
case protoreflect.BoolKind:
return cel.BoolType
case protoreflect.StringKind:
return cel.StringType
case protoreflect.BytesKind:
return cel.BytesType
case
protoreflect.MessageKind,
protoreflect.GroupKind:
return cel.DynType
default:
return cel.DynType
}
}
3 changes: 2 additions & 1 deletion internal/constraints/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package constraints
import (
"testing"

"github.com/bufbuild/protovalidate-go/internal/expression"
"github.com/google/cel-go/cel"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -83,7 +84,7 @@ func TestProtoKindToCELType(t *testing.T) {
kind, typ := k, ty
t.Run(kind.String(), func(t *testing.T) {
t.Parallel()
assert.Equal(t, typ, ProtoKindToCELType(kind))
assert.Equal(t, typ, expression.ProtoKindToCELType(kind))
})
}
}
18 changes: 6 additions & 12 deletions internal/evaluator/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,18 +277,12 @@ func (bldr *Builder) processFieldExpressions(
if len(exprs) == 0 {
return nil
}
var opts []cel.EnvOption
if fieldDesc.Kind() == protoreflect.MessageKind {
opts = []cel.EnvOption{
cel.Types(dynamicpb.NewMessage(fieldDesc.ContainingMessage())),
cel.Types(dynamicpb.NewMessage(fieldDesc.Message())),
cel.Variable("this", cel.ObjectType(string(fieldDesc.Message().FullName()))),
}
} else {
opts = []cel.EnvOption{
cel.Variable("this", constraints.ProtoKindToCELType(fieldDesc.Kind())),
}
}

celTyp := expression.ProtoFieldToCELType(fieldDesc, false, false)
opts := append(
expression.RequiredCELEnvOptions(fieldDesc),
cel.Variable("this", celTyp),
)
compiledExpressions, err := expression.Compile(exprs, bldr.env, opts...)
if err != nil {
return err
Expand Down
113 changes: 113 additions & 0 deletions internal/expression/lookups.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"github.com/google/cel-go/cel"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
)

// ProtoKindToCELType maps a protoreflect.Kind to a compatible cel.Type.
func ProtoKindToCELType(kind protoreflect.Kind) *cel.Type {
switch kind {
case
protoreflect.FloatKind,
protoreflect.DoubleKind:
return cel.DoubleType
case
protoreflect.Int32Kind,
protoreflect.Int64Kind,
protoreflect.Sint32Kind,
protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind,
protoreflect.Sfixed64Kind,
protoreflect.EnumKind:
return cel.IntType
case
protoreflect.Uint32Kind,
protoreflect.Uint64Kind,
protoreflect.Fixed32Kind,
protoreflect.Fixed64Kind:
return cel.UintType
case protoreflect.BoolKind:
return cel.BoolType
case protoreflect.StringKind:
return cel.StringType
case protoreflect.BytesKind:
return cel.BytesType
case
protoreflect.MessageKind,
protoreflect.GroupKind:
return cel.DynType
default:
return cel.DynType
}
}

// ProtoFieldToCELType resolves the CEL value type for the provided
// FieldDescriptor. If generic is true, the specific subtypes of map and
// repeated fields will be replaced with cel.DynType. If forItems is true, the
// type for the repeated list items is returned instead of the list type itself.
func ProtoFieldToCELType(fieldDesc protoreflect.FieldDescriptor, generic, forItems bool) *cel.Type {
if !forItems {
switch {
case fieldDesc.IsMap():
if generic {
return cel.MapType(cel.DynType, cel.DynType)
}
keyType := ProtoFieldToCELType(fieldDesc.MapKey(), false, true)
valType := ProtoFieldToCELType(fieldDesc.MapValue(), false, true)
return cel.MapType(keyType, valType)
case fieldDesc.IsList():
if generic {
return cel.ListType(cel.DynType)
}
itemType := ProtoFieldToCELType(fieldDesc, false, true)
return cel.ListType(itemType)
}
}

if fieldDesc.Kind() == protoreflect.MessageKind {
switch fqn := fieldDesc.Message().FullName(); fqn {
case "google.protobuf.Any":
return cel.AnyType
case "google.protobuf.Duration":
return cel.DurationType
case "google.protobuf.Timestamp":
return cel.TimestampType
default:
return cel.ObjectType(string(fqn))
}
}
return ProtoKindToCELType(fieldDesc.Kind())
}

// RequiredCELEnvOptions returns the options required to have expressions which
// rely on the provided descriptor.
func RequiredCELEnvOptions(fieldDesc protoreflect.FieldDescriptor) []cel.EnvOption {
if fieldDesc.IsMap() {
return append(
RequiredCELEnvOptions(fieldDesc.MapKey()),
RequiredCELEnvOptions(fieldDesc.MapValue())...,
)
}
if fieldDesc.Kind() == protoreflect.MessageKind {
return []cel.EnvOption{
cel.Types(dynamicpb.NewMessage(fieldDesc.Message())),
}
}
return nil
}
Loading