Skip to content

Commit

Permalink
Allow any file to make use of custom descriptor.proto (#97)
Browse files Browse the repository at this point in the history
Before this PR, if an override version of the descriptor protos/options
were in use, the file making use of them had to import
"google/protobuf/descriptor.proto". But this is counter-intuitive since
files typically do _not_ import that file, unless they have explicit
references to types therein or unless they define custom options. So
this PR allows override descriptors to be used, even for files that have
no such import statement.
  • Loading branch information
jhump authored Feb 23, 2023
1 parent 6df82ab commit 7c5114e
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 33 deletions.
74 changes: 68 additions & 6 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ type executor struct {
cancel context.CancelFunc
sym *linker.Symbols

descriptorProtoCheck sync.Once
descriptorProtoIsCustom bool

mu sync.Mutex
results map[string]*result
}
Expand Down Expand Up @@ -316,6 +319,18 @@ func (e errFailedToResolve) Unwrap() error {
return e.err
}

// hasOverrideDescriptorProto reports whether the compiler's resolver supplies
// a descriptor.proto whose descriptor differs from the entry held in
// standardImports. The resolver is consulted only on the first call (guarded
// by descriptorProtoCheck); later calls return the cached answer.
func (e *executor) hasOverrideDescriptorProto() bool {
	probe := func() {
		// A panic raised while probing is deliberately swallowed: in that
		// case we simply assume there is no custom descriptor.proto.
		defer func() {
			_ = recover()
		}()

		found, err := e.c.Resolver.FindFileByPath(descriptorProtoPath)
		if err != nil {
			// The resolver could not provide the file, so it is not overridden.
			e.descriptorProtoIsCustom = false
			return
		}
		e.descriptorProtoIsCustom = found.Desc != standardImports[descriptorProtoPath]
	}
	e.descriptorProtoCheck.Do(probe)
	return e.descriptorProtoIsCustom
}

func (e *executor) doCompile(ctx context.Context, file string, r *result) {
t := task{e: e, h: e.h.SubHandler(), r: r}
if err := e.s.Acquire(ctx, 1); err != nil {
Expand All @@ -326,7 +341,7 @@ func (e *executor) doCompile(ctx context.Context, file string, r *result) {

sr, err := e.c.Resolver.FindFileByPath(file)
if err != nil {
r.fail(errFailedToResolve{err, file})
r.fail(errFailedToResolve{err: err, path: file})
return
}

Expand Down Expand Up @@ -371,6 +386,8 @@ func (t *task) release() {
}
}

// descriptorProtoPath is the import path of descriptor.proto. It is the name
// passed to the resolver when checking for a custom descriptor.proto, and the
// name of the dependency that is implicitly added to a file's imports when
// such an override is in use.
const descriptorProtoPath = "google/protobuf/descriptor.proto"

func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.File, error) {
if r.Desc != nil {
if r.Desc.Path() != name {
Expand All @@ -385,12 +402,38 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}

var deps []linker.File
if len(parseRes.FileDescriptorProto().Dependency) > 0 {
t.r.setBlockedOn(parseRes.FileDescriptorProto().Dependency)
fileDescriptorProto := parseRes.FileDescriptorProto()
var wantsDescriptorProto bool
imports := fileDescriptorProto.Dependency

if t.e.hasOverrideDescriptorProto() {
// we only consider implicitly including descriptor.proto if it's overridden
if name != descriptorProtoPath {
var includesDescriptorProto bool
for _, dep := range fileDescriptorProto.Dependency {
if dep == descriptorProtoPath {
includesDescriptorProto = true
break
}
}
if !includesDescriptorProto {
wantsDescriptorProto = true
// make a defensive copy so we don't inadvertently mutate
// slice's backing array when adding this implicit dep
importsCopy := make([]string, len(imports)+1)
copy(importsCopy, imports)
importsCopy[len(imports)] = descriptorProtoPath
imports = importsCopy
}
}
}

if len(imports) > 0 {
t.r.setBlockedOn(imports)

results := make([]*result, len(parseRes.FileDescriptorProto().Dependency))
results := make([]*result, len(fileDescriptorProto.Dependency))
checked := map[string]struct{}{}
for i, dep := range parseRes.FileDescriptorProto().Dependency {
for i, dep := range fileDescriptorProto.Dependency {
pos := findImportPos(parseRes, dep)
if name == dep {
// doh! file imports itself
Expand All @@ -405,7 +448,15 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}
results[i] = res
}
deps = make([]linker.File, len(results))
capacity := len(results)
if wantsDescriptorProto {
capacity++
}
deps = make([]linker.File, len(results), capacity)
var descriptorProtoRes *result
if wantsDescriptorProto {
descriptorProtoRes = t.e.compile(ctx, descriptorProtoPath)
}

// release our semaphore so dependencies can be processed w/out risk of deadlock
t.e.s.Release(1)
Expand All @@ -430,6 +481,17 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
return nil, ctx.Err()
}
}
if descriptorProtoRes != nil {
select {
case <-descriptorProtoRes.ready:
// descriptor.proto wasn't explicitly imported, so we can ignore a failure
if descriptorProtoRes.err == nil {
deps = append(deps, descriptorProtoRes.res)
}
case <-ctx.Done():
return nil, ctx.Err()
}
}

// all deps resolved
t.r.setBlockedOn(nil)
Expand Down
19 changes: 16 additions & 3 deletions compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,17 @@ func TestParseFilesWithDependencies(t *testing.T) {
// Create a dependency-aware parser that should never be called.
compiler := Compiler{
Resolver: ResolverFunc(func(f string) (SearchResult, error) {
if f == "test.proto" {
switch f {
case "test.proto":
return SearchResult{Source: strings.NewReader(`syntax = "proto3";`)}, nil
case descriptorProtoPath:
// used to see if resolver provides custom descriptor.proto
return SearchResult{}, os.ErrNotExist
default:
// no other name should be passed to resolver
t.Errorf("resolver was called for unexpected filename %q", f)
return SearchResult{}, os.ErrNotExist
}
t.Errorf("resolved was called for unexpected filename %q", f)
return SearchResult{}, os.ErrNotExist
}),
}
_, err := compiler.Compile(ctx, "test.proto")
Expand Down Expand Up @@ -261,3 +267,10 @@ func TestPanicHandling(t *testing.T) {
require.True(t, ok)
t.Logf("%v\n\n%v", panicErr, panicErr.Stack)
}

// TestDescriptorProtoPath is a sanity check that the descriptorProtoPath
// constant matches the path of the file that declares FileDescriptorProto in
// the descriptorpb package.
func TestDescriptorProtoPath(t *testing.T) {
	t.Parallel()
	// A nil message pointer is sufficient here: only its descriptor is inspected.
	var msg *descriptorpb.FileDescriptorProto
	declaringFile := msg.ProtoReflect().Descriptor().ParentFile()
	require.Equal(t, descriptorProtoPath, declaringFile.Path())
}
18 changes: 13 additions & 5 deletions linker/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

// File is like a super-powered protoreflect.FileDescriptor. It includes helpful
// methods for looking up elements in the descriptor and can be used to create a
// resolver for all in the file's transitive closure of dependencies. (See
// resolver for all of the file's transitive closure of dependencies. (See
// ResolverFromFile.)
type File interface {
protoreflect.FileDescriptor
Expand Down Expand Up @@ -88,6 +88,8 @@ func newFile(f protoreflect.FileDescriptor, deps Files) (File, error) {
// NewFileRecursive recursively converts a protoreflect.FileDescriptor to a File.
// If f has any dependencies/imports, they are converted, too, including any and
// all transitive dependencies.
//
// If f already implements File, it is returned unchanged.
func NewFileRecursive(f protoreflect.FileDescriptor) (File, error) {
if asFile, ok := f.(File); ok {
return asFile, nil
Expand Down Expand Up @@ -185,10 +187,16 @@ type Resolver interface {
protoregistry.ExtensionTypeResolver
}

// ResolverFromFile returns a Resolver that uses the given file plus its full
// set of transitive dependencies as the source of descriptors. If a given query
// cannot be answered with these files, the query will fail with a
// protoregistry.NotFound error.
// ResolverFromFile returns a Resolver that uses the given file plus all of its
// imports as the source of descriptors. If a given query cannot be answered with
// these files, the query will fail with a protoregistry.NotFound error. This
// does not recursively search the entire transitive closure; it only searches
// the given file and its immediate dependencies. This is useful for resolving
// elements visible to the file.
//
// If the given file is the result of a call to Link, then all dependencies
// provided in the call to Link are searched (which could actually include more
// than just the file's direct imports).
//
// Note that this function does not compute any additional indexes for efficient
// search, so queries generally take linear time, O(n) where n is the number of
Expand Down
6 changes: 6 additions & 0 deletions linker/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ type Result interface {
// message that is available in this file. If no such element is available
// or if the named element is not a message, nil is returned.
ResolveMessageType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveOptionsType returns a message descriptor for the given options
// type. This is like ResolveMessageType but searches this file and all of
// the dependencies it was linked with (see ResolverFromFile), without
// regard for visibility. If no such element is available or if the named
// element is not a message, nil is returned.
ResolveOptionsType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveExtension returns an extension descriptor for the given named
// extension that is available in this file. If no such element is available
// or if the named element is not an extension, nil is returned.
Expand Down
16 changes: 16 additions & 0 deletions linker/linker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,22 @@ func TestLinkerValidation(t *testing.T) {
},
expectedErr: `foo.proto:14:23: option (bar): value -2147483649 is out of range for an enum`,
},
"success_custom_field_option": {
input: map[string]string{
"google/protobuf/descriptor.proto": `
syntax = "proto2";
package google.protobuf;
message FieldOptions {
optional string some_new_option = 11;
}`,
"bar.proto": `
syntax = "proto3";
package foo.bar.baz;
message Foo {
string bar = 1 [some_new_option="abc"];
}`,
},
},
}

for name, tc := range testCases {
Expand Down
21 changes: 19 additions & 2 deletions linker/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ func (r *result) ResolveMessageType(name protoreflect.FullName) protoreflect.Mes
return nil
}

// ResolveOptionsType looks up name through ResolverFromFile(r) and returns the
// matching message descriptor. It returns nil when nothing is found or when
// the element found is not a message. When a message is found and it has a
// parent file, that file's path is passed to r.markUsed.
func (r *result) ResolveOptionsType(name protoreflect.FullName) protoreflect.MessageDescriptor {
	// The lookup error is intentionally dropped: on failure found is nil, the
	// type assertion below fails, and nil is returned.
	found, _ := ResolverFromFile(r).FindDescriptorByName(name)
	msg, isMessage := found.(protoreflect.MessageDescriptor)
	if !isMessage {
		return nil
	}
	if parent := msg.ParentFile(); parent != nil {
		r.markUsed(parent.Path())
	}
	return msg
}

func (r *result) ResolveEnumType(name protoreflect.FullName) protoreflect.EnumDescriptor {
d := r.resolveElement(name)
if ed, ok := d.(protoreflect.EnumDescriptor); ok {
Expand Down Expand Up @@ -343,7 +352,7 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
}
} else {
// make sure tag is not a duplicate
if err := s.AddExtension(dsc.ParentFile().Package(), dsc.FullName(), tag, file.NodeInfo(node.FieldTag()).Start(), handler); err != nil {
if err := s.AddExtension(packageFor(dsc), dsc.FullName(), tag, file.NodeInfo(node.FieldTag()).Start(), handler); err != nil {
return err
}
}
Expand Down Expand Up @@ -402,7 +411,7 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
f.msgType = dsc
case protoreflect.EnumDescriptor:
proto3 := r.Syntax() == protoreflect.Proto3
enumIsProto3 := dsc.ParentFile().Syntax() == protoreflect.Proto3
enumIsProto3 := dsc.Syntax() == protoreflect.Proto3
if fld.GetExtendee() == "" && proto3 && !enumIsProto3 {
// fields in a proto3 message cannot refer to proto2 enums
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: cannot use proto2 enum %s in a proto3 message", scope, fld.GetTypeName())
Expand All @@ -417,6 +426,14 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
return nil
}

// packageFor returns the package name for the given descriptor. It uses the
// package of the descriptor's parent file when one is available; otherwise it
// falls back to the parent of the descriptor's fully-qualified name.
func packageFor(dsc protoreflect.Descriptor) protoreflect.FullName {
	file := dsc.ParentFile()
	if file == nil {
		// No file to ask for the package, so make a best-effort guess from
		// the descriptor's own name.
		return dsc.FullName().Parent()
	}
	return file.Package()
}

func isValidMap(mapField protoreflect.FieldDescriptor, mapEntry protoreflect.MessageDescriptor) bool {
return !mapField.IsExtension() &&
mapEntry.Parent() == mapField.ContainingMessage() &&
Expand Down
34 changes: 22 additions & 12 deletions linker/symbols.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/bufbuild/protocompile/walk"
)

// unknownFilePath is the placeholder filename used when building a source
// position for a descriptor that has no parent file.
const unknownFilePath = "<unknown file>"

// Symbols is a symbol table that maps names for all program elements to their
// location in source. It also tracks extension tag numbers. This can be used
// to enforce uniqueness for symbol names and tag numbers across many files and
Expand Down Expand Up @@ -121,7 +123,7 @@ func (s *Symbols) importFileWithExtensions(pkg *packageSymbols, fd protoreflect.
}
pos := sourcePositionForNumber(fld)
extendee := fld.ContainingMessage()
if err := s.AddExtension(extendee.ParentFile().Package(), extendee.FullName(), fld.Number(), pos, handler); err != nil {
if err := s.AddExtension(packageFor(extendee), extendee.FullName(), fld.Number(), pos, handler); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -294,9 +296,13 @@ func sourcePositionForPackage(fd protoreflect.FileDescriptor) ast.SourcePos {
}

func sourcePositionFor(d protoreflect.Descriptor) ast.SourcePos {
file := d.ParentFile()
if file == nil {
return ast.UnknownPos(unknownFilePath)
}
path, ok := computePath(d)
if !ok {
return ast.UnknownPos(d.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
namePath := path
switch d.(type) {
Expand All @@ -318,36 +324,40 @@ func sourcePositionFor(d protoreflect.Descriptor) ast.SourcePos {
// NB: shouldn't really happen, but just in case fall back to path to
// descriptor, sans name field
}
loc := d.ParentFile().SourceLocations().ByPath(namePath)
loc := file.SourceLocations().ByPath(namePath)
if isZeroLoc(loc) {
loc = d.ParentFile().SourceLocations().ByPath(path)
loc = file.SourceLocations().ByPath(path)
if isZeroLoc(loc) {
return ast.UnknownPos(d.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
}
return ast.SourcePos{
Filename: d.ParentFile().Path(),
Filename: file.Path(),
Line: loc.StartLine,
Col: loc.StartColumn,
}
}

func sourcePositionForNumber(fd protoreflect.FieldDescriptor) ast.SourcePos {
file := fd.ParentFile()
if file == nil {
return ast.UnknownPos(unknownFilePath)
}
path, ok := computePath(fd)
if !ok {
return ast.UnknownPos(fd.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
numberPath := path
numberPath = append(numberPath, internal.FieldNumberTag)
loc := fd.ParentFile().SourceLocations().ByPath(numberPath)
loc := file.SourceLocations().ByPath(numberPath)
if isZeroLoc(loc) {
loc = fd.ParentFile().SourceLocations().ByPath(path)
loc = file.SourceLocations().ByPath(path)
if isZeroLoc(loc) {
return ast.UnknownPos(fd.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
}
return ast.SourcePos{
Filename: fd.ParentFile().Path(),
Filename: file.Path(),
Line: loc.StartLine,
Col: loc.StartColumn,
}
Expand Down Expand Up @@ -401,7 +411,7 @@ func (s *Symbols) importResultWithExtensions(pkg *packageSymbols, r *result, han
node := r.FieldNode(fd.FieldDescriptorProto())
pos := file.NodeInfo(node.FieldTag()).Start()
extendee := fd.ContainingMessage()
if err := s.AddExtension(extendee.ParentFile().Package(), extendee.FullName(), fd.Number(), pos, handler); err != nil {
if err := s.AddExtension(packageFor(extendee), extendee.FullName(), fd.Number(), pos, handler); err != nil {
return err
}

Expand Down
Loading

0 comments on commit 7c5114e

Please sign in to comment.