From ec219b3c15fbd507493664eaba615cf590e141c1 Mon Sep 17 00:00:00 2001 From: Joshua Humphries Date: Wed, 23 May 2018 14:13:24 -0400 Subject: [PATCH] allow the use of proto source files directly, instead of having to compile to protoset files (#32) --- cmd/grpcurl/grpcurl.go | 43 ++++++++++++-- grpcurl.go | 47 +++++++++++++++ grpcurl_test.go | 131 +++++++++++++++++++++++++---------------- 3 files changed, 164 insertions(+), 57 deletions(-) diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index 29e8032..618f4af 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -52,6 +52,8 @@ var ( `File containing client private key, to present to the server. Not valid with -plaintext option. Must also provide -cert option.`) protoset multiString + protoFiles multiString + importPaths multiString addlHeaders multiString rpcHeaders multiString reflHeaders multiString @@ -102,7 +104,26 @@ func init() { 'list' action lists the services found in the given descriptors (vs. those exposed by the remote server), and the 'describe' action describes symbols found in the given descriptors. May specify more than one via - multiple -protoset flags.`) + multiple -protoset flags. It is an error to use both -protoset and + -proto flags.`) + flag.Var(&protoFiles, "proto", + `The name of a proto source file. Source files given will be used to + determine the RPC schema instead of querying for it from the remote + server via the GRPC reflection API. When set: the 'list' action lists + the services found in the given files and their imports (vs. those + exposed by the remote server), and the 'describe' action describes + symbols found in the given files. May specify more than one via + multiple -proto flags. Imports will be resolved using the given + -import-path flags. Multiple proto files can be specified by specifying + multiple -proto flags. It is an error to use both -protoset and -proto + flags.`) + flag.Var(&importPaths, "import-path", + `The path to a directory from which proto sources can be imported, + for use with -proto flags. Multiple import paths can be configured by + specifying multiple -import-path flags. Paths will be searched in the + order given. If no import paths are given, all files (including all + imports) must be provided as -proto flags, and grpcurl will attempt to + resolve all import statements from the set of file names given.`) } type multiString []string @@ -189,11 +210,17 @@ func main() { if invoke && target == "" { fail(nil, "No host:port specified.") } - if len(protoset) == 0 && target == "" { - fail(nil, "No host:port specified and no protoset specified.") + if len(protoset) == 0 && len(protoFiles) == 0 && target == "" { + fail(nil, "No host:port specified, no protoset specified, and no proto sources specified.") } if len(protoset) > 0 && len(reflHeaders) > 0 { - warn("The -reflect-header argument is not used when -protoset files are used ") + warn("The -reflect-header argument is not used when -protoset files are used.") + } + if len(protoset) > 0 && len(protoFiles) > 0 { + fail(nil, "Use either -protoset files or -proto files, but not both.") + } + if len(importPaths) > 0 && len(protoFiles) == 0 { + warn("The -import-path argument is not used unless -proto files are used.") } ctx := context.Background() @@ -260,7 +287,13 @@ func main() { var err error descSource, err = grpcurl.DescriptorSourceFromProtoSets(protoset...) if err != nil { - fail(err, "Failed to process proto descriptor sets") + fail(err, "Failed to process proto descriptor sets.") + } + } else if len(protoFiles) > 0 { + var err error + descSource, err = grpcurl.DescriptorSourceFromProtoFiles(importPaths, protoFiles...) + if err != nil { + fail(err, "Failed to process proto source files.") } } else { md := grpcurl.MetadataFromHeaders(append(addlHeaders, reflHeaders...)) diff --git a/grpcurl.go b/grpcurl.go index 3cd8f8f..474ff7c 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -26,6 +26,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/desc/protoparse" "github.com/jhump/protoreflect/dynamic" "github.com/jhump/protoreflect/dynamic/grpcdynamic" "github.com/jhump/protoreflect/grpcreflect" @@ -74,6 +75,21 @@ func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error return DescriptorSourceFromFileDescriptorSet(files) } +// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files, +// whose contents are Protocol Buffer source files. The given importPaths are used to locate +// any imported files. +func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) { + p := protoparse.Parser{ + ImportPaths: importPaths, + InferImportPaths: len(importPaths) == 0, + } + fds, err := p.ParseFiles(fileNames...) + if err != nil { + return nil, fmt.Errorf("could not parse given files: %v", err) + } + return DescriptorSourceFromFileDescriptors(fds...) +} + // DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet. func DescriptorSourceFromFileDescriptorSet(files *descriptor.FileDescriptorSet) (DescriptorSource, error) { unresolved := map[string]*descriptor.FileDescriptorProto{} @@ -114,6 +130,37 @@ func resolveFileDescriptor(unresolved map[string]*descriptor.FileDescriptorProto return result, nil } +// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the given +// file descriptors +func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) { + fds := map[string]*desc.FileDescriptor{} + for _, fd := range files { + if err := addFile(fd, fds); err != nil { + return nil, err + } + } + return &fileSource{files: fds}, nil +} + +func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error { + name := fd.GetName() + if existing, ok := fds[name]; ok { + // already added this file + if existing != fd { + // doh! duplicate files provided + return fmt.Errorf("given files include multiple copies of %q", name) + } + return nil + } + fds[name] = fd + for _, dep := range fd.GetDependencies() { + if err := addFile(dep, fds); err != nil { + return err + } + } + return nil +} + type fileSource struct { files map[string]*desc.FileDescriptor er *dynamic.ExtensionRegistry diff --git a/grpcurl_test.go b/grpcurl_test.go index 453fcb2..5c300ce 100644 --- a/grpcurl_test.go +++ b/grpcurl_test.go @@ -28,19 +28,32 @@ import ( ) var ( - sourceProtoset DescriptorSource - ccProtoset *grpc.ClientConn + sourceProtoset DescriptorSource + sourceProtoFiles DescriptorSource + ccNoReflect *grpc.ClientConn sourceReflect DescriptorSource ccReflect *grpc.ClientConn + + descSources []descSourceCase ) +type descSourceCase struct { + name string + source DescriptorSource + includeRefl bool +} + func TestMain(m *testing.M) { var err error sourceProtoset, err = DescriptorSourceFromProtoSets("testing/test.protoset") if err != nil { panic(err) } + sourceProtoFiles, err = DescriptorSourceFromProtoFiles(nil, "../../../google.golang.org/grpc/interop/grpc_testing/test.proto") + if err != nil { + panic(err) + } // Create a server that includes the reflection service svrReflect := grpc.NewServer() @@ -83,17 +96,23 @@ func TestMain(m *testing.M) { // And a corresponding client ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ccProtoset, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset), + if ccNoReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset), grpc.WithInsecure(), grpc.WithBlock()); err != nil { panic(err) } - defer ccProtoset.Close() + defer ccNoReflect.Close() + + descSources = []descSourceCase{ + {"protoset", sourceProtoset, false}, + {"proto", sourceProtoFiles, false}, + {"reflect", sourceReflect, true}, + } os.Exit(m.Run()) } func TestServerDoesNotSupportReflection(t *testing.T) { - refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccProtoset)) + refClient := grpcreflect.NewClient(context.Background(), reflectpb.NewServerReflectionClient(ccNoReflect)) defer refClient.Reset() refSource := DescriptorSourceFromServer(context.Background(), refClient) @@ -108,7 +127,7 @@ func TestServerDoesNotSupportReflection(t *testing.T) { t.Errorf("ListMethods should have returned ErrReflectionNotSupported; instead got %v", err) } - err = InvokeRpc(context.Background(), refSource, ccProtoset, "FooService/Method", nil, nil, nil) + err = InvokeRpc(context.Background(), refSource, ccNoReflect, "FooService/Method", nil, nil, nil) // InvokeRpc wraps the error, so we just verify the returned error includes the right message if err == nil || !strings.Contains(err.Error(), ErrReflectionNotSupported.Error()) { t.Errorf("InvokeRpc should have returned ErrReflectionNotSupported; instead got %v", err) @@ -137,12 +156,12 @@ func TestProtosetWithImports(t *testing.T) { } } -func TestListServicesProtoset(t *testing.T) { - doTestListServices(t, sourceProtoset, false) -} - -func TestListServicesReflect(t *testing.T) { - doTestListServices(t, sourceReflect, true) +func TestListServices(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestListServices(t, ds.source, ds.includeRefl) + }) + } } func doTestListServices(t *testing.T, source DescriptorSource, includeReflection bool) { @@ -164,12 +183,12 @@ func doTestListServices(t *testing.T, source DescriptorSource, includeReflection } } -func TestListMethodsProtoset(t *testing.T) { - doTestListMethods(t, sourceProtoset, false) -} - -func TestListMethodsReflect(t *testing.T) { - doTestListMethods(t, sourceReflect, true) +func TestListMethods(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestListMethods(t, ds.source, ds.includeRefl) + }) + } } func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection bool) { @@ -216,12 +235,12 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection } } -func TestDescribeProtoset(t *testing.T) { - doTestDescribe(t, sourceProtoset) -} - -func TestDescribeReflect(t *testing.T) { - doTestDescribe(t, sourceReflect) +func TestDescribe(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestDescribe(t, ds.source) + }) + } } func doTestDescribe(t *testing.T, source DescriptorSource) { @@ -296,12 +315,20 @@ const ( }` ) -func TestUnaryProtoset(t *testing.T) { - doTestUnary(t, ccProtoset, sourceProtoset) +func getCC(includeRefl bool) *grpc.ClientConn { + if includeRefl { + return ccReflect + } else { + return ccNoReflect + } } -func TestUnaryReflect(t *testing.T) { - doTestUnary(t, ccReflect, sourceReflect) +func TestUnary(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestUnary(t, getCC(ds.includeRefl), ds.source) + }) + } } func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { @@ -328,12 +355,12 @@ func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { h.check(t, "grpc.testing.TestService.UnaryCall", codes.NotFound, 1, 0) } -func TestClientStreamProtoset(t *testing.T) { - doTestClientStream(t, ccProtoset, sourceProtoset) -} - -func TestClientStreamReflect(t *testing.T) { - doTestClientStream(t, ccReflect, sourceReflect) +func TestClientStream(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestClientStream(t, getCC(ds.includeRefl), ds.source) + }) + } } func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { @@ -373,12 +400,12 @@ func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSour h.check(t, "grpc.testing.TestService.StreamingInputCall", codes.Internal, 3, 0) } -func TestServerStreamProtoset(t *testing.T) { - doTestServerStream(t, ccProtoset, sourceProtoset) -} - -func TestServerStreamReflect(t *testing.T) { - doTestServerStream(t, ccReflect, sourceReflect) +func TestServerStream(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestServerStream(t, getCC(ds.includeRefl), ds.source) + }) + } } func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { @@ -435,12 +462,12 @@ func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSour h.check(t, "grpc.testing.TestService.StreamingOutputCall", codes.AlreadyExists, 1, 5) } -func TestHalfDuplexStreamProtoset(t *testing.T) { - doTestHalfDuplexStream(t, ccProtoset, sourceProtoset) -} - -func TestHalfDuplexStreamReflect(t *testing.T) { - doTestHalfDuplexStream(t, ccReflect, sourceReflect) +func TestHalfDuplexStream(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestHalfDuplexStream(t, getCC(ds.includeRefl), ds.source) + }) + } } func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { @@ -480,12 +507,12 @@ func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source Descriptor h.check(t, "grpc.testing.TestService.HalfDuplexCall", codes.DataLoss, 3, 3) } -func TestFullDuplexStreamProtoset(t *testing.T) { - doTestFullDuplexStream(t, ccProtoset, sourceProtoset) -} - -func TestFullDuplexStreamReflect(t *testing.T) { - doTestFullDuplexStream(t, ccReflect, sourceReflect) +func TestFullDuplexStream(t *testing.T) { + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + doTestFullDuplexStream(t, getCC(ds.includeRefl), ds.source) + }) + } } func doTestFullDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {