diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index cc5cfa1..70712a6 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -9,6 +9,7 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" "time" @@ -126,6 +127,7 @@ var ( an error to use both -authority and -servername (though this will be permitted if they are both set to the same value, to increase backwards compatibility with earlier releases that allowed both to be set).`)) + reflection = optionalBoolFlag{val: true} ) func init() { @@ -168,6 +170,12 @@ func init() { order given. If no import paths are given, all files (including all imports) must be provided as -proto flags, and grpcurl will attempt to resolve all import statements from the set of file names given.`)) + flags.Var(&reflection, "use-reflection", prettify(` + When true, server reflection will be used to determine the RPC schema. + Defaults to true unless a -proto or -protoset option is provided. If + -use-reflection is used in combination with a -proto or -protoset flag, + the provided descriptor sources will be used in addition to server + reflection to resolve messages and extensions.`)) } type multiString []string @@ -181,6 +189,49 @@ func (s *multiString) Set(value string) error { return nil } +// Uses a file source as a fallback for resolving symbols and extensions, but +// only uses the reflection source for listing services +type compositeSource struct { + reflection grpcurl.DescriptorSource + file grpcurl.DescriptorSource +} + +func (cs compositeSource) ListServices() ([]string, error) { + return cs.reflection.ListServices() +} + +func (cs compositeSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { + d, err := cs.reflection.FindSymbol(fullyQualifiedName) + if err == nil { + return d, nil + } + return cs.file.FindSymbol(fullyQualifiedName) +} + +func (cs compositeSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) { + exts, err := cs.reflection.AllExtensionsForType(typeName) + if err != nil { + // On error fall back to file source + return cs.file.AllExtensionsForType(typeName) + } + // Track the tag numbers from the reflection source + tags := make(map[int32]bool) + for _, ext := range exts { + tags[ext.GetNumber()] = true + } + fileExts, err := cs.file.AllExtensionsForType(typeName) + if err != nil { + return exts, nil + } + for _, ext := range fileExts { + // Prioritize extensions found via reflection + if !tags[ext.GetNumber()] { + exts = append(exts, ext) + } + } + return exts, nil +} + func main() { flags.Usage = usage flags.Parse(os.Args[1:]) @@ -288,6 +339,14 @@ func main() { if len(importPaths) > 0 && len(protoFiles) == 0 { warn("The -import-path argument is not used unless -proto files are used.") } + if !reflection.val && len(protoset) == 0 && len(protoFiles) == 0 { + fail(nil, "No protoset files or proto files specified and -use-reflection set to false.") + } + + // Protoset or protofiles provided and -use-reflection unset + if !reflection.set && (len(protoset) > 0 || len(protoFiles) > 0) { + reflection.val = false + } ctx := context.Background() if *maxTime > 0 { @@ -372,24 +431,33 @@ func main() { var cc *grpc.ClientConn var descSource grpcurl.DescriptorSource var refClient *grpcreflect.Client + var fileSource grpcurl.DescriptorSource if len(protoset) > 0 { var err error - descSource, err = grpcurl.DescriptorSourceFromProtoSets(protoset...) + fileSource, err = grpcurl.DescriptorSourceFromProtoSets(protoset...) if err != nil { fail(err, "Failed to process proto descriptor sets.") } } else if len(protoFiles) > 0 { var err error - descSource, err = grpcurl.DescriptorSourceFromProtoFiles(importPaths, protoFiles...) + fileSource, err = grpcurl.DescriptorSourceFromProtoFiles(importPaths, protoFiles...) if err != nil { fail(err, "Failed to process proto source files.") } - } else { + } + if reflection.val { md := grpcurl.MetadataFromHeaders(append(addlHeaders, reflHeaders...)) refCtx := metadata.NewOutgoingContext(ctx, md) cc = dial() refClient = grpcreflect.NewClient(refCtx, reflectpb.NewServerReflectionClient(cc)) - descSource = grpcurl.DescriptorSourceFromServer(ctx, refClient) + reflSource := grpcurl.DescriptorSourceFromServer(ctx, refClient) + if fileSource != nil { + descSource = compositeSource{reflSource, fileSource} + } else { + descSource = reflSource + } + } else { + descSource = fileSource } // arrange for the RPCs to be cleanly shutdown @@ -672,3 +740,28 @@ func writeProtoset(descSource grpcurl.DescriptorSource, symbols ...string) error defer f.Close() return grpcurl.WriteProtoset(f, descSource, symbols...) } + +type optionalBoolFlag struct { + set, val bool +} + +func (f *optionalBoolFlag) String() string { + if !f.set { + return "unset" + } + return strconv.FormatBool(f.val) +} + +func (f *optionalBoolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.set = true + f.val = v + return nil +} + +func (f *optionalBoolFlag) IsBoolFlag() bool { + return true +}