diff --git a/compiler.go b/compiler.go index 0bb7f182..4ad0b0b1 100644 --- a/compiler.go +++ b/compiler.go @@ -435,6 +435,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker. } } + var descriptorProtoRes *result if len(imports) > 0 { t.r.setBlockedOn(imports) @@ -455,12 +456,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker. } results[i] = res } - capacity := len(results) - if wantsDescriptorProto { - capacity++ - } - deps = make([]linker.File, len(results), capacity) - var descriptorProtoRes *result + deps = make([]linker.File, len(results)) if wantsDescriptorProto { descriptorProtoRes = t.e.compile(ctx, descriptorProtoPath) } @@ -488,17 +484,6 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker. return nil, ctx.Err() } } - if descriptorProtoRes != nil { - select { - case <-descriptorProtoRes.ready: - // descriptor.proto wasn't explicitly imported, so we can ignore a failure - if descriptorProtoRes.err == nil { - deps = append(deps, descriptorProtoRes.res) - } - case <-ctx.Done(): - return nil, ctx.Err() - } - } // all deps resolved t.r.setBlockedOn(nil) @@ -509,7 +494,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker. t.released = false } - return t.link(parseRes, deps) + return t.link(ctx, parseRes, deps, descriptorProtoRes) } func (e *executor) checkForDependencyCycle(res *result, sequence []string, pos ast.SourcePos, checked map[string]struct{}) error { @@ -568,12 +553,26 @@ func findImportPos(res parser.Result, dep string) ast.SourcePos { return ast.UnknownPos(res.FileNode().Name()) } -func (t *task) link(parseRes parser.Result, deps linker.Files) (linker.File, error) { +func (t *task) link(ctx context.Context, parseRes parser.Result, deps linker.Files, descriptorProtoRes *result) (linker.File, error) { file, err := linker.Link(parseRes, deps, t.e.sym, t.h) if err != nil { return nil, err } - optsIndex, err := options.InterpretOptions(file, t.h) + + var interpretOpts []options.InterpreterOption + if descriptorProtoRes != nil { + select { + case <-descriptorProtoRes.ready: + // descriptor.proto wasn't explicitly imported, so we can ignore a failure + if descriptorProtoRes.err == nil { + interpretOpts = []options.InterpreterOption{options.WithOverrideDescriptorProto(descriptorProtoRes.res)} + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + optsIndex, err := options.InterpretOptions(file, t.h, interpretOpts...) if err != nil { return nil, err } diff --git a/internal/benchmarks/go.mod b/internal/benchmarks/go.mod index 6233fa4d..f9814189 100644 --- a/internal/benchmarks/go.mod +++ b/internal/benchmarks/go.mod @@ -5,7 +5,7 @@ go 1.19 require ( github.com/bufbuild/protocompile v0.0.0-20221004230924-06a336f5b6be github.com/igrmk/treemap/v2 v2.0.1 - github.com/jhump/protoreflect v1.13.0 + github.com/jhump/protoreflect v1.14.1 github.com/stretchr/testify v1.8.0 google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8 ) diff --git a/internal/benchmarks/go.sum b/internal/benchmarks/go.sum index 2a1b0589..3df978ee 100644 --- a/internal/benchmarks/go.sum +++ b/internal/benchmarks/go.sum @@ -39,8 +39,8 @@ github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSl github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= -github.com/jhump/protoreflect v1.13.0 h1:zrrZqa7JAc2YGgPSzZZkmUXJ5G6NRPdxOg/9t7ISImA= -github.com/jhump/protoreflect v1.13.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= +github.com/jhump/protoreflect v1.14.1 h1:N88q7JkxTHWFEqReuTsYH1dPIwXxA0ITNQp7avLY10s= +github.com/jhump/protoreflect v1.14.1/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/linker/descriptors.go b/linker/descriptors.go index 20968936..fe19bf3c 100644 --- a/linker/descriptors.go +++ b/linker/descriptors.go @@ -206,7 +206,7 @@ func asSourceLocations(srcInfoProtos []*descriptorpb.SourceCodeInfo_Location) [] func pathStr(p protoreflect.SourcePath) string { var buf bytes.Buffer for _, v := range p { - fmt.Fprintf(&buf, "%x:", v) + _, _ = fmt.Fprintf(&buf, "%x:", v) } return buf.String() } @@ -1869,10 +1869,6 @@ func (r *result) FindDescriptorByName(name protoreflect.FullName) protoreflect.D return r.descriptors[fqn] } -func (r *result) importsAsFiles() Files { - return r.deps -} - func (r *result) hasSource() bool { n := r.FileNode() _, ok := n.(*ast.FileNode) diff --git a/linker/files.go b/linker/files.go index 3acdcade..06c13491 100644 --- a/linker/files.go +++ b/linker/files.go @@ -28,7 +28,7 @@ import ( // File is like a super-powered protoreflect.FileDescriptor. It includes helpful // methods for looking up elements in the descriptor and can be used to create a -// resolver for all of the file's transitive closure of dependencies. (See +// resolver for all the file's transitive closure of dependencies. (See // ResolverFromFile.) type File interface { protoreflect.FileDescriptor @@ -42,10 +42,6 @@ type File interface { // that extends the given message name. If no such extension is defined in this // file, nil is returned. FindExtensionByNumber(message protoreflect.FullName, tag protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor - // Imports returns this file's imports. These are only the files directly - // imported by the file. Indirect transitive dependencies will not be in - // the returned slice. - importsAsFiles() Files } // NewFile converts a protoreflect.FileDescriptor to a File. The given deps must @@ -147,10 +143,6 @@ func (f *file) FindExtensionByNumber(msg protoreflect.FullName, tag protoreflect return findExtension(f, msg, tag) } -func (f *file) importsAsFiles() Files { - return f.deps -} - var _ File = (*file)(nil) // Files represents a set of protobuf files. It is a slice of File values, but @@ -187,58 +179,53 @@ type Resolver interface { protoregistry.ExtensionTypeResolver } -// ResolverFromFile returns a Resolver that uses the given file plus all of its -// imports as the source of descriptors. If a given query cannot be answered with -// these files, the query will fail with a protoregistry.NotFound error. This -// does not recursively search the entire transitive closure; it only searches -// the given file and its immediate dependencies. This is useful for resolving -// elements visible to the file. -// -// If the given file is the result of a call to Link, then all dependencies -// provided in the call to Link are searched (which could actually include more -// than just the file's direct imports). +// ResolverFromFile returns a Resolver that can resolve any element that is +// visible to the given file. It will search the given file, its imports, and +// any transitive public imports. // // Note that this function does not compute any additional indexes for efficient // search, so queries generally take linear time, O(n) where n is the number of -// files in the transitive closure of the given file. Queries for an extension +// files whose elements are visible to the given file. Queries for an extension // by number are linear with the number of messages and extensions defined across -// all the files. +// those files. func ResolverFromFile(f File) Resolver { - return fileResolver{ - f: f, - deps: f.importsAsFiles().AsResolver(), - } + return fileResolver{f: f} } type fileResolver struct { - f File - deps Resolver + f File } func (r fileResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) { - if r.f.Path() == path { - return r.f, nil - } - return r.deps.FindFileByPath(path) + return resolveInFile(r.f, false, nil, func(f File) (protoreflect.FileDescriptor, error) { + if f.Path() == path { + return f, nil + } + return nil, protoregistry.NotFound + }) } func (r fileResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) { - d := r.f.FindDescriptorByName(name) - if d != nil { - return d, nil - } - return r.deps.FindDescriptorByName(name) + return resolveInFile(r.f, false, nil, func(f File) (protoreflect.Descriptor, error) { + if d := f.FindDescriptorByName(name); d != nil { + return d, nil + } + return nil, protoregistry.NotFound + }) } func (r fileResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) { - d := r.f.FindDescriptorByName(message) - if d != nil { - if md, ok := d.(protoreflect.MessageDescriptor); ok { + return resolveInFile(r.f, false, nil, func(f File) (protoreflect.MessageType, error) { + d := f.FindDescriptorByName(message) + if d != nil { + md, ok := d.(protoreflect.MessageDescriptor) + if !ok { + return nil, fmt.Errorf("%q is %s, not a message", message, descriptorTypeWithArticle(d)) + } return dynamicpb.NewMessageType(md), nil } return nil, protoregistry.NotFound - } - return r.deps.FindMessageByName(message) + }) } func (r fileResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) { @@ -248,35 +235,34 @@ func (r fileResolver) FindMessageByURL(url string) (protoreflect.MessageType, er func messageNameFromURL(url string) string { lastSlash := strings.LastIndexByte(url, '/') - var fullName string - if lastSlash >= 0 { - fullName = url[lastSlash+1:] - } else { - fullName = url - } - return fullName + return url[lastSlash+1:] } func (r fileResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { - d := r.f.FindDescriptorByName(field) - if d != nil { - if extd, ok := d.(protoreflect.ExtensionTypeDescriptor); ok { - return extd.Type(), nil - } - if fld, ok := d.(protoreflect.FieldDescriptor); ok && fld.IsExtension() { + return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) { + d := f.FindDescriptorByName(field) + if d != nil { + fld, ok := d.(protoreflect.FieldDescriptor) + if !ok || !fld.IsExtension() { + return nil, fmt.Errorf("%q is %s, not an extension", field, descriptorTypeWithArticle(d)) + } + if extd, ok := fld.(protoreflect.ExtensionTypeDescriptor); ok { + return extd.Type(), nil + } return dynamicpb.NewExtensionType(fld), nil } return nil, protoregistry.NotFound - } - return r.deps.FindExtensionByName(field) + }) } func (r fileResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { - ext := findExtension(r.f, message, field) - if ext != nil { - return ext.Type(), nil - } - return r.deps.FindExtensionByNumber(message, field) + return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) { + ext := findExtension(f, message, field) + if ext != nil { + return ext.Type(), nil + } + return nil, protoregistry.NotFound + }) } type filesResolver []File diff --git a/linker/linker.go b/linker/linker.go index d8d1c9a6..345c9bcf 100644 --- a/linker/linker.go +++ b/linker/linker.go @@ -106,24 +106,7 @@ func Link(parsed parser.Result, dependencies Files, symbols *Symbols, handler *r type Result interface { File parser.Result - // ResolveEnumType returns an enum descriptor for the given named enum that - // is available in this file. If no such element is available or if the - // named element is not an enum, nil is returned. - ResolveEnumType(protoreflect.FullName) protoreflect.EnumDescriptor - // ResolveMessageType returns a message descriptor for the given named - // message that is available in this file. If no such element is available - // or if the named element is not a message, nil is returned. - ResolveMessageType(protoreflect.FullName) protoreflect.MessageDescriptor - // ResolveOptionsType returns a message descriptor for the given options - // type. This is like ResolveMessageType but searches the result's entire - // set of transitive dependencies without regard for visibility. If no - // such element is available or if the named element is not a message, nil - // is returned. - ResolveOptionsType(protoreflect.FullName) protoreflect.MessageDescriptor - // ResolveExtension returns an extension descriptor for the given named - // extension that is available in this file. If no such element is available - // or if the named element is not an extension, nil is returned. - ResolveExtension(protoreflect.FullName) protoreflect.ExtensionTypeDescriptor + // ResolveMessageLiteralExtensionName returns the fully qualified name for // an identifier for extension field names in message literals. ResolveMessageLiteralExtensionName(ast.IdentValueNode) string diff --git a/linker/resolve.go b/linker/resolve.go index 485aba05..f79d1a53 100644 --- a/linker/resolve.go +++ b/linker/resolve.go @@ -15,13 +15,14 @@ package linker import ( + "errors" "fmt" "strings" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/dynamicpb" "github.com/bufbuild/protocompile/ast" "github.com/bufbuild/protocompile/internal" @@ -29,58 +30,65 @@ import ( "github.com/bufbuild/protocompile/walk" ) -func (r *result) ResolveMessageType(name protoreflect.FullName) protoreflect.MessageDescriptor { - d := r.resolveElement(name) - if md, ok := d.(protoreflect.MessageDescriptor); ok { - return md - } - return nil -} - -func (r *result) ResolveOptionsType(name protoreflect.FullName) protoreflect.MessageDescriptor { - d, _ := ResolverFromFile(r).FindDescriptorByName(name) - md, _ := d.(protoreflect.MessageDescriptor) - if md != nil && md.ParentFile() != nil { - r.markUsed(md.ParentFile().Path()) - } - return md +func (r *result) ResolveMessageLiteralExtensionName(node ast.IdentValueNode) string { + return r.optionQualifiedNames[node] } -func (r *result) ResolveEnumType(name protoreflect.FullName) protoreflect.EnumDescriptor { - d := r.resolveElement(name) - if ed, ok := d.(protoreflect.EnumDescriptor); ok { - return ed +func (r *result) resolveElement(name protoreflect.FullName) protoreflect.Descriptor { + if len(name) > 0 && name[0] == '.' { + name = name[1:] } - return nil + res, _ := resolveInFile(r, false, nil, func(f File) (protoreflect.Descriptor, error) { + d := resolveElementInFile(name, f) + if d != nil { + return d, nil + } + return nil, protoregistry.NotFound + }) + return res } -func (r *result) ResolveExtension(name protoreflect.FullName) protoreflect.ExtensionTypeDescriptor { - d := r.resolveElement(name) - if ed, ok := d.(protoreflect.ExtensionDescriptor); ok { - if !ed.IsExtension() { - return nil - } - if td, ok := ed.(protoreflect.ExtensionTypeDescriptor); ok { - return td +func resolveInFile[T any](f File, publicImportsOnly bool, checked []string, fn func(File) (T, error)) (T, error) { + var zero T + path := f.Path() + for _, str := range checked { + if str == path { + // already checked + return zero, protoregistry.NotFound } - return dynamicpb.NewExtensionType(ed).TypeDescriptor() } - return nil -} - -func (r *result) ResolveMessageLiteralExtensionName(node ast.IdentValueNode) string { - return r.optionQualifiedNames[node] -} + checked = append(checked, path) -func (r *result) resolveElement(name protoreflect.FullName) protoreflect.Descriptor { - if len(name) > 0 && name[0] == '.' { - name = name[1:] + res, err := fn(f) + if err == nil { + // found it + return res, nil } - importedFd, res := resolveElement(r, name, false, nil) - if importedFd != nil { - r.markUsed(importedFd.Path()) + if !errors.Is(err, protoregistry.NotFound) { + return zero, err } - return res + + imports := f.Imports() + for i, l := 0, imports.Len(); i < l; i++ { + imp := imports.Get(i) + if publicImportsOnly && !imp.IsPublic { + continue + } + res, err := resolveInFile(f.FindImportByPath(imp.Path()), true, checked, fn) + if errors.Is(err, protoregistry.NotFound) { + continue + } + if err != nil { + return zero, err + } + if !imp.IsPublic { + if r, ok := f.(*result); ok { + r.markUsed(imp.Path()) + } + } + return res, nil + } + return zero, err } func (r *result) markUsed(importPath string) { @@ -117,38 +125,6 @@ func (r *result) CheckForUnusedImports(handler *reporter.Handler) { } } -func resolveElement(f File, fqn protoreflect.FullName, publicImportsOnly bool, checked []string) (imported File, d protoreflect.Descriptor) { - path := f.Path() - for _, str := range checked { - if str == path { - // already checked - return nil, nil - } - } - checked = append(checked, path) - - r := resolveElementInFile(fqn, f) - if r != nil { - // not imported, but present in f - return nil, r - } - - // When publicImportsOnly = false, we are searching only directly imported symbols. But - // we also need to search transitive public imports due to semantics of public imports. - for i := 0; i < f.Imports().Len(); i++ { - dep := f.Imports().Get(i) - if dep.IsPublic || !publicImportsOnly { - depFile := f.FindImportByPath(dep.Path()) - _, d := resolveElement(depFile, fqn, true, checked) - if d != nil { - return depFile, d - } - } - } - - return nil, nil -} - func descriptorTypeWithArticle(d protoreflect.Descriptor) string { switch d := d.(type) { case protoreflect.MessageDescriptor: diff --git a/options/options.go b/options/options.go index c26f136c..a01d059a 100644 --- a/options/options.go +++ b/options/options.go @@ -27,6 +27,7 @@ package options import ( "bytes" + "errors" "fmt" "math" "sort" @@ -35,6 +36,7 @@ import ( "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" @@ -52,20 +54,17 @@ import ( type Index map[*ast.OptionNode][]int32 type interpreter struct { - file file - resolver linker.Resolver - container optionsContainer - lenient bool - reporter *reporter.Handler - index Index + file file + resolver linker.Resolver + container optionsContainer + overrideDescriptorProto linker.File + lenient bool + reporter *reporter.Handler + index Index } type file interface { parser.Result - ResolveEnumType(protoreflect.FullName) protoreflect.EnumDescriptor - ResolveMessageType(protoreflect.FullName) protoreflect.MessageDescriptor - ResolveOptionsType(protoreflect.FullName) protoreflect.MessageDescriptor - ResolveExtension(protoreflect.FullName) protoreflect.ExtensionTypeDescriptor ResolveMessageLiteralExtensionName(ast.IdentValueNode) string } @@ -73,24 +72,23 @@ type noResolveFile struct { parser.Result } -func (n noResolveFile) ResolveEnumType(name protoreflect.FullName) protoreflect.EnumDescriptor { - return nil -} - -func (n noResolveFile) ResolveMessageType(name protoreflect.FullName) protoreflect.MessageDescriptor { - return nil +func (n noResolveFile) ResolveMessageLiteralExtensionName(ast.IdentValueNode) string { + return "" } -func (n noResolveFile) ResolveOptionsType(name protoreflect.FullName) protoreflect.MessageDescriptor { - return nil -} +// InterpreterOption is an option that can be passed to InterpretOptions and +// its variants. +type InterpreterOption func(*interpreter) -func (n noResolveFile) ResolveExtension(name protoreflect.FullName) protoreflect.ExtensionTypeDescriptor { - return nil -} - -func (n noResolveFile) ResolveMessageLiteralExtensionName(ast.IdentValueNode) string { - return "" +// WithOverrideDescriptorProto returns an option that indicates that the given file +// should be consulted when looking up a definition for an option type. The given +// file should usually have the path "google/protobuf/descriptor.proto". The given +// file will only be consulted if the option type is otherwise not visible to the +// file whose options are being interpreted. +func WithOverrideDescriptorProto(f linker.File) InterpreterOption { + return func(interp *interpreter) { + interp.overrideDescriptorProto = f + } } // InterpretOptions interprets options in the given linked result, returning @@ -100,8 +98,8 @@ func (n noResolveFile) ResolveMessageLiteralExtensionName(ast.IdentValueNode) st // // The given handler is used to report errors and warnings. If any errors are // reported, this function returns a non-nil error. -func InterpretOptions(linked linker.Result, handler *reporter.Handler) (Index, error) { - return interpretOptions(false, linked, handler) +func InterpretOptions(linked linker.Result, handler *reporter.Handler, opts ...InterpreterOption) (Index, error) { + return interpretOptions(false, linked, linker.ResolverFromFile(linked), handler, opts) } // InterpretOptionsLenient interprets options in a lenient/best-effort way in @@ -113,8 +111,8 @@ func InterpretOptions(linked linker.Result, handler *reporter.Handler) (Index, e // In lenient more, errors resolving option names and type errors are ignored. // Any options that are uninterpretable (due to such errors) will remain in the // "uninterpreted_option" fields. -func InterpretOptionsLenient(linked linker.Result) (Index, error) { - return interpretOptions(true, linked, reporter.NewHandler(nil)) +func InterpretOptionsLenient(linked linker.Result, opts ...InterpreterOption) (Index, error) { + return interpretOptions(true, linked, linker.ResolverFromFile(linked), reporter.NewHandler(nil), opts) } // InterpretUnlinkedOptions does a best-effort attempt to interpret options in @@ -128,20 +126,21 @@ func InterpretOptionsLenient(linked linker.Result) (Index, error) { // interpreted. Other errors resolving option names or type errors will be // effectively ignored. Any options that are uninterpretable (due to such // errors) will remain in the "uninterpreted_option" fields. -func InterpretUnlinkedOptions(parsed parser.Result) (Index, error) { - return interpretOptions(true, noResolveFile{parsed}, reporter.NewHandler(nil)) +func InterpretUnlinkedOptions(parsed parser.Result, opts ...InterpreterOption) (Index, error) { + return interpretOptions(true, noResolveFile{parsed}, nil, reporter.NewHandler(nil), opts) } -func interpretOptions(lenient bool, file file, handler *reporter.Handler) (Index, error) { +func interpretOptions(lenient bool, file file, res linker.Resolver, handler *reporter.Handler, interpOpts []InterpreterOption) (Index, error) { interp := interpreter{ file: file, + resolver: res, lenient: lenient, reporter: handler, index: Index{}, } interp.container, _ = file.(optionsContainer) - if f, ok := file.(linker.File); ok { - interp.resolver = linker.ResolverFromFile(f) + for _, opt := range interpOpts { + opt(&interp) } fd := file.FileDescriptorProto() @@ -202,6 +201,54 @@ func interpretOptions(lenient bool, file file, handler *reporter.Handler) (Index return interp.index, nil } +func resolveDescriptor[T protoreflect.Descriptor](res linker.Resolver, name string) T { + var zero T + if res == nil { + return zero + } + if len(name) > 0 && name[0] == '.' { + name = name[1:] + } + desc, _ := res.FindDescriptorByName(protoreflect.FullName(name)) + typedDesc, ok := desc.(T) + if ok { + return typedDesc + } + return zero +} + +func (interp *interpreter) resolveExtensionType(name string) (protoreflect.ExtensionTypeDescriptor, error) { + if interp.resolver == nil { + return nil, protoregistry.NotFound + } + if len(name) > 0 && name[0] == '.' { + name = name[1:] + } + ext, err := interp.resolver.FindExtensionByName(protoreflect.FullName(name)) + if err != nil { + return nil, err + } + return ext.TypeDescriptor(), nil +} + +func (interp *interpreter) resolveOptionsType(name string) protoreflect.MessageDescriptor { + md := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, name) + if md != nil { + return md + } + if interp.overrideDescriptorProto == nil { + return nil + } + if len(name) > 0 && name[0] == '.' { + name = name[1:] + } + desc := interp.overrideDescriptorProto.FindDescriptorByName(protoreflect.FullName(name)) + if md, ok := desc.(protoreflect.MessageDescriptor); ok { + return md + } + return nil +} + func (interp *interpreter) nodeInfo(n ast.Node) ast.NodeInfo { return interp.file.FileNode().NodeInfo(n) } @@ -348,7 +395,7 @@ func (interp *interpreter) processDefaultOption(scope string, fqn string, fld *d } var v interface{} if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM { - ed := interp.file.ResolveEnumType(protoreflect.FullName(fld.GetTypeName())) + ed := resolveDescriptor[protoreflect.EnumDescriptor](interp.resolver, fld.GetTypeName()) _, name, err := interp.enumFieldValue(mc, ed, val, false) if err != nil { return -1, interp.reporter.HandleError(err) @@ -689,7 +736,7 @@ func (interp *interpreter) interpretOptions(fqn string, element, opts proto.Mess optsFqn := string(optsDesc.FullName()) var msg protoreflect.Message // see if the parse included an override copy for these options - if md := interp.file.ResolveOptionsType(protoreflect.FullName(optsFqn)); md != nil { + if md := interp.resolveOptionsType(optsFqn); md != nil { dm := dynamicpb.NewMessage(md) if err := cloneInto(dm, opts, nil); err != nil { node := interp.file.Node(element) @@ -963,11 +1010,14 @@ func (interp *interpreter) interpretField(mc *internal.MessageContext, msg proto if extName[0] == '.' { extName = extName[1:] /* skip leading dot */ } - fld = interp.file.ResolveExtension(protoreflect.FullName(extName)) - if fld == nil { + var err error + fld, err = interp.resolveExtensionType(extName) + if errors.Is(err, protoregistry.NotFound) { return nil, interp.reporter.HandleErrorf(interp.nodeInfo(node).Start(), "%vunrecognized extension %s of %s", mc, extName, msg.Descriptor().FullName()) + } else if err != nil { + return nil, interp.reporter.HandleErrorWithPos(interp.nodeInfo(node).Start(), err) } if fld.ContainingMessage().FullName() != msg.Descriptor().FullName() { return nil, interp.reporter.HandleErrorf(interp.nodeInfo(node).Start(), @@ -1515,7 +1565,7 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel if !ok { return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Val).Start(), "%vtype references for google.protobuf.Any must have message literal value", mc) } - anyMd := interp.file.ResolveMessageType(protoreflect.FullName(msgName)) + anyMd := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, string(msgName)) if anyMd == nil { return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix).Start(), "%vcould not resolve type reference %s", mc, fullURL) } @@ -1544,19 +1594,20 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel fdm.Set(valueDescriptor, protoreflect.ValueOfBytes(b)) } else { var ffld protoreflect.FieldDescriptor + var err error if fieldNode.Name.IsExtension() { n := interp.file.ResolveMessageLiteralExtensionName(fieldNode.Name.Name) if n == "" { // this should not be possible! n = string(fieldNode.Name.Name.AsIdentifier()) } - ffld = interp.file.ResolveExtension(protoreflect.FullName(n)) - if ffld == nil { + ffld, err = interp.resolveExtensionType(n) + if errors.Is(err, protoregistry.NotFound) { // may need to qualify with package name // (this should not be necessary!) pkg := mc.File.FileDescriptorProto().GetPackage() if pkg != "" { - ffld = interp.file.ResolveExtension(protoreflect.FullName(pkg + "." + n)) + ffld, err = interp.resolveExtensionType(pkg + "." + n) } } } else { @@ -1569,19 +1620,24 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name).Start(), "%vfield %s not found (did you mean the group named %s?)", mc, fieldNode.Name.Value(), ffld.Message().Name()) } if ffld == nil { + err = protoregistry.NotFound // could be a group name for i := 0; i < fmd.Fields().Len(); i++ { fd := fmd.Fields().Get(i) if fd.Kind() == protoreflect.GroupKind && fd.Message().Name() == protoreflect.Name(fieldNode.Name.Value()) { // found it! ffld = fd + err = nil break } } } } - if ffld == nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name).Start(), "%vfield %s not found", mc, string(fieldNode.Name.Name.AsIdentifier())) + if errors.Is(err, protoregistry.NotFound) { + return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name).Start(), + "%vfield %s not found", mc, string(fieldNode.Name.Name.AsIdentifier())) + } else if err != nil { + return interpretedFieldValue{}, reporter.Error(interp.nodeInfo(fieldNode.Name).Start(), err) } if fieldNode.Sep == nil && ffld.Message() == nil { // If there is no separator, the field type should be a message. diff --git a/options/options_test.go b/options/options_test.go index fd17a5e5..bf11dbba 100644 --- a/options/options_test.go +++ b/options/options_test.go @@ -20,12 +20,14 @@ import ( "errors" "fmt" "os" + "sort" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" "github.com/bufbuild/protocompile" @@ -39,15 +41,82 @@ import ( type ident string type aggregate string +func TestCustomOptionsAreKnown(t *testing.T) { + t.Parallel() + for _, withOverride := range []bool{false, true} { + withOverride := withOverride + name := "no overrides" + if withOverride { + name = "with override descriptor.proto" + } + t.Run(name, func(t *testing.T) { + t.Parallel() + sources := map[string]string{ + "test.proto": ` + syntax = "proto3"; + import "other.proto"; + option (string_option) = "abc"; + `, + "other.proto": ` + syntax = "proto3"; + import public "options.proto"; + `, + "options.proto": ` + syntax = "proto3"; + import "google/protobuf/descriptor.proto"; + extend google.protobuf.FileOptions { + string string_option = 10101; + } + `, + } + resolver := protocompile.Resolver(&protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(sources), + }) + if withOverride { + sources["google/protobuf/descriptor.proto"] = ` + syntax = "proto2"; + package google.protobuf; + message FileOptions { + optional string foo = 1; + optional bool bar = 2; + optional int32 baz = 3; + extensions 1000 to max; + } + ` + } else { + resolver = protocompile.WithStandardImports(resolver) + } + compiler := &protocompile.Compiler{ + Resolver: resolver, + } + files, err := compiler.Compile(context.Background(), "test.proto") + require.NoError(t, err) + require.Equal(t, 1, len(files)) + var knownOptionNames []string + fileOptions := files[0].Options().ProtoReflect() + assert.Empty(t, fileOptions.GetUnknown()) + fileOptions.Range(func(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool { + if fd.IsExtension() { + knownOptionNames = append(knownOptionNames, string(fd.FullName())) + } + return true + }) + sort.Strings(knownOptionNames) + assert.Equal(t, []string{"string_option"}, knownOptionNames) + }) + } +} + func TestOptionsInUnlinkedFiles(t *testing.T) { t.Parallel() testCases := []struct { + name string contents string uninterpreted map[string]interface{} checkInterpreted func(*testing.T, *descriptorpb.FileDescriptorProto) }{ { - // file options + name: "file options", contents: `option go_package = "foo.bar"; option (must.link) = "FOO";`, uninterpreted: map[string]interface{}{ "test.proto:(must.link)": "FOO", @@ -57,7 +126,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // message options + name: "message options", contents: `message Test { option (must.link) = 1.234; option deprecated = true; }`, uninterpreted: map[string]interface{}{ "Test:(must.link)": 1.234, @@ -67,7 +136,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // field options and pseudo-options + name: "field options", contents: `message Test { optional string uid = 1 [(must.link) = 10101, (must.link) = 20202, default = "fubar", json_name = "UID", deprecated = true]; }`, uninterpreted: map[string]interface{}{ "Test.uid:(must.link)": 10101, @@ -80,7 +149,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // field where default is uninterpretable + name: "field options, default uninterpretable", contents: `enum TestEnum{ ZERO = 0; ONE = 1; } message Test { optional TestEnum uid = 1 [(must.link) = {foo: bar}, default = ONE, json_name = "UID", deprecated = true]; }`, uninterpreted: map[string]interface{}{ "Test.uid:(must.link)": aggregate("foo : bar"), @@ -92,7 +161,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // one-of options + name: "oneof options", contents: `message Test { oneof x { option (must.link) = true; option deprecated = true; string uid = 1; uint64 nnn = 2; } }`, uninterpreted: map[string]interface{}{ "Test.x:(must.link)": ident("true"), @@ -100,7 +169,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // extension range options + name: "extension range options", contents: `message Test { extensions 100 to 200 [(must.link) = "foo", deprecated = true]; }`, uninterpreted: map[string]interface{}{ "Test.100-200:(must.link)": "foo", @@ -108,7 +177,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // enum options + name: "enum options", contents: `enum Test { option allow_alias = true; option deprecated = true; option (must.link) = 123.456; ZERO = 0; ZILCH = 0; }`, uninterpreted: map[string]interface{}{ "Test:(must.link)": 123.456, @@ -119,7 +188,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // enum value options + name: "enum value options", contents: `enum Test { ZERO = 0 [deprecated = true, (must.link) = -222]; }`, uninterpreted: map[string]interface{}{ "Test.ZERO:(must.link)": -222, @@ -129,7 +198,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // service options + name: "service options", contents: `service Test { option deprecated = true; option (must.link) = {foo:1, foo:2, bar:3}; }`, uninterpreted: map[string]interface{}{ "Test:(must.link)": aggregate("foo : 1 , foo : 2 , bar : 3"), @@ -139,7 +208,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, }, { - // method options + name: "method options", contents: `import "google/protobuf/empty.proto"; service Test { rpc Foo (google.protobuf.Empty) returns (google.protobuf.Empty) { option deprecated = true; option (must.link) = FOO; } }`, uninterpreted: map[string]interface{}{ "Test.Foo:(must.link)": ident("FOO"), @@ -150,26 +219,30 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { }, } - for i, tc := range testCases { - h := reporter.NewHandler(nil) - ast, err := parser.Parse("test.proto", strings.NewReader(tc.contents), h) - if !assert.Nil(t, err, "case #%d failed to parse", i) { - continue - } - res, err := parser.ResultFromAST(ast, true, h) - if !assert.Nil(t, err, "case #%d failed to produce descriptor proto", i) { - continue - } - _, err = options.InterpretUnlinkedOptions(res) - if !assert.Nil(t, err, "case #%d failed to interpret options", i) { - continue - } - actual := map[string]interface{}{} - buildUninterpretedMapForFile(res.FileDescriptorProto(), actual) - assert.Equal(t, tc.uninterpreted, actual, "case #%d resulted in wrong uninterpreted options", i) - if tc.checkInterpreted != nil { - tc.checkInterpreted(t, res.FileDescriptorProto()) - } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := reporter.NewHandler(nil) + ast, err := parser.Parse("test.proto", strings.NewReader(tc.contents), h) + if !assert.Nil(t, err, "failed to parse") { + return + } + res, err := parser.ResultFromAST(ast, true, h) + if !assert.Nil(t, err, "failed to produce descriptor proto") { + return + } + _, err = options.InterpretUnlinkedOptions(res) + if !assert.Nil(t, err, "failed to interpret options") { + return + } + actual := map[string]interface{}{} + buildUninterpretedMapForFile(res.FileDescriptorProto(), actual) + assert.Equal(t, tc.uninterpreted, actual, "resulted in wrong uninterpreted options") + if tc.checkInterpreted != nil { + tc.checkInterpreted(t, res.FileDescriptorProto()) + } + }) } }