From 9a4bbacdd6c0c295d93e72874c3cd1c46f4f7c18 Mon Sep 17 00:00:00 2001 From: Joshua Humphries Date: Thu, 18 Oct 2018 23:51:38 -0400 Subject: [PATCH] some pre-factoring and small fixes (#58) * organize into multiple files * make listing methods show fully-qualified names * address small feedback from recent change (trim then check if empty) --- cmd/grpcurl/grpcurl.go | 3 +- desc_source.go | 248 +++++++++++++++++ grpcurl.go | 605 +---------------------------------------- grpcurl_test.go | 16 +- invoke.go | 385 ++++++++++++++++++++++++++ 5 files changed, 644 insertions(+), 613 deletions(-) create mode 100644 desc_source.go create mode 100644 invoke.go diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index 6aca975..ed12176 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -566,10 +566,11 @@ func prettify(docString string) string { // from each line in the doc string j := 0 for _, part := range parts { + part = strings.TrimSpace(part) if part == "" { continue } - parts[j] = strings.TrimSpace(part) + parts[j] = part j++ } diff --git a/desc_source.go b/desc_source.go new file mode 100644 index 0000000..167980f --- /dev/null +++ b/desc_source.go @@ -0,0 +1,248 @@ +package grpcurl + +import ( + "errors" + "fmt" + "io/ioutil" + "sync" + + "github.com/golang/protobuf/proto" + descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/desc/protoparse" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/grpcreflect" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// ErrReflectionNotSupported is returned by DescriptorSource operations that +// rely on interacting with the reflection service when the source does not +// actually expose the reflection service. When this occurs, an alternate source +// (like file descriptor sets) must be used. +var ErrReflectionNotSupported = errors.New("server does not support the reflection API") + +// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet +// proto (like a file generated by protoc) or a remote server that supports the reflection API. +type DescriptorSource interface { + // ListServices returns a list of fully-qualified service names. It will be all services in a set of + // descriptor files or the set of all services exposed by a gRPC server. + ListServices() ([]string, error) + // FindSymbol returns a descriptor for the given fully-qualified symbol name. + FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) + // AllExtensionsForType returns all known extension fields that extend the given message type name. + AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) +} + +// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents +// are encoded FileDescriptorSet protos. +func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) { + files := &descpb.FileDescriptorSet{} + for _, fileName := range fileNames { + b, err := ioutil.ReadFile(fileName) + if err != nil { + return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err) + } + var fs descpb.FileDescriptorSet + err = proto.Unmarshal(b, &fs) + if err != nil { + return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err) + } + files.File = append(files.File, fs.File...) + } + return DescriptorSourceFromFileDescriptorSet(files) +} + +// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files, +// whose contents are Protocol Buffer source files. The given importPaths are used to locate +// any imported files. +func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) { + p := protoparse.Parser{ + ImportPaths: importPaths, + InferImportPaths: len(importPaths) == 0, + } + fds, err := p.ParseFiles(fileNames...) + if err != nil { + return nil, fmt.Errorf("could not parse given files: %v", err) + } + return DescriptorSourceFromFileDescriptors(fds...) +} + +// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet. +func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) { + unresolved := map[string]*descpb.FileDescriptorProto{} + for _, fd := range files.File { + unresolved[fd.GetName()] = fd + } + resolved := map[string]*desc.FileDescriptor{} + for _, fd := range files.File { + _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName()) + if err != nil { + return nil, err + } + } + return &fileSource{files: resolved}, nil +} + +func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) { + if r, ok := resolved[filename]; ok { + return r, nil + } + fd, ok := unresolved[filename] + if !ok { + return nil, fmt.Errorf("no descriptor found for %q", filename) + } + deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency())) + for _, dep := range fd.GetDependency() { + depFd, err := resolveFileDescriptor(unresolved, resolved, dep) + if err != nil { + return nil, err + } + deps = append(deps, depFd) + } + result, err := desc.CreateFileDescriptor(fd, deps...) + if err != nil { + return nil, err + } + resolved[filename] = result + return result, nil +} + +// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the given +// file descriptors +func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) { + fds := map[string]*desc.FileDescriptor{} + for _, fd := range files { + if err := addFile(fd, fds); err != nil { + return nil, err + } + } + return &fileSource{files: fds}, nil +} + +func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error { + name := fd.GetName() + if existing, ok := fds[name]; ok { + // already added this file + if existing != fd { + // doh! duplicate files provided + return fmt.Errorf("given files include multiple copies of %q", name) + } + return nil + } + fds[name] = fd + for _, dep := range fd.GetDependencies() { + if err := addFile(dep, fds); err != nil { + return err + } + } + return nil +} + +type fileSource struct { + files map[string]*desc.FileDescriptor + er *dynamic.ExtensionRegistry + erInit sync.Once +} + +func (fs *fileSource) ListServices() ([]string, error) { + set := map[string]bool{} + for _, fd := range fs.files { + for _, svc := range fd.GetServices() { + set[svc.GetFullyQualifiedName()] = true + } + } + sl := make([]string, 0, len(set)) + for svc := range set { + sl = append(sl, svc) + } + return sl, nil +} + +// GetAllFiles returns all of the underlying file descriptors. This is +// more thorough and more efficient than the fallback strategy used by +// the GetAllFiles package method, for enumerating all files from a +// descriptor source. +func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) { + files := make([]*desc.FileDescriptor, len(fs.files)) + i := 0 + for _, fd := range fs.files { + files[i] = fd + i++ + } + return files, nil +} + +func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { + for _, fd := range fs.files { + if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil { + return dsc, nil + } + } + return nil, notFound("Symbol", fullyQualifiedName) +} + +func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) { + fs.erInit.Do(func() { + fs.er = &dynamic.ExtensionRegistry{} + for _, fd := range fs.files { + fs.er.AddExtensionsFromFile(fd) + } + }) + return fs.er.AllExtensionsForType(typeName), nil +} + +// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client +// to interrogate a server for descriptor information. If the server does not support the reflection +// API then the various DescriptorSource methods will return ErrReflectionNotSupported +func DescriptorSourceFromServer(_ context.Context, refClient *grpcreflect.Client) DescriptorSource { + return serverSource{client: refClient} +} + +type serverSource struct { + client *grpcreflect.Client +} + +func (ss serverSource) ListServices() ([]string, error) { + svcs, err := ss.client.ListServices() + return svcs, reflectionSupport(err) +} + +func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { + file, err := ss.client.FileContainingSymbol(fullyQualifiedName) + if err != nil { + return nil, reflectionSupport(err) + } + d := file.FindSymbol(fullyQualifiedName) + if d == nil { + return nil, notFound("Symbol", fullyQualifiedName) + } + return d, nil +} + +func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) { + var exts []*desc.FieldDescriptor + nums, err := ss.client.AllExtensionNumbersForType(typeName) + if err != nil { + return nil, reflectionSupport(err) + } + for _, fieldNum := range nums { + ext, err := ss.client.ResolveExtension(typeName, fieldNum) + if err != nil { + return nil, reflectionSupport(err) + } + exts = append(exts, ext) + } + return exts, nil +} + +func reflectionSupport(err error) error { + if err == nil { + return nil + } + if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented { + return ErrReflectionNotSupported + } + return err +} diff --git a/grpcurl.go b/grpcurl.go index 4360be4..e221f1f 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -13,262 +13,22 @@ import ( "encoding/base64" "errors" "fmt" - "io" "io/ioutil" "net" "sort" "strings" - "sync" - "sync/atomic" "time" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" - descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/jhump/protoreflect/desc" - "github.com/jhump/protoreflect/desc/protoparse" "github.com/jhump/protoreflect/desc/protoprint" "github.com/jhump/protoreflect/dynamic" - "github.com/jhump/protoreflect/dynamic/grpcdynamic" - "github.com/jhump/protoreflect/grpcreflect" "golang.org/x/net/context" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" ) -// ErrReflectionNotSupported is returned by DescriptorSource operations that -// rely on interacting with the reflection service when the source does not -// actually expose the reflection service. When this occurs, an alternate source -// (like file descriptor sets) must be used. -var ErrReflectionNotSupported = errors.New("server does not support the reflection API") - -// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet -// proto (like a file generated by protoc) or a remote server that supports the reflection API. -type DescriptorSource interface { - // ListServices returns a list of fully-qualified service names. It will be all services in a set of - // descriptor files or the set of all services exposed by a gRPC server. - ListServices() ([]string, error) - // FindSymbol returns a descriptor for the given fully-qualified symbol name. - FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) - // AllExtensionsForType returns all known extension fields that extend the given message type name. - AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) -} - -// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents -// are encoded FileDescriptorSet protos. -func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) { - files := &descpb.FileDescriptorSet{} - for _, fileName := range fileNames { - b, err := ioutil.ReadFile(fileName) - if err != nil { - return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err) - } - var fs descpb.FileDescriptorSet - err = proto.Unmarshal(b, &fs) - if err != nil { - return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err) - } - files.File = append(files.File, fs.File...) - } - return DescriptorSourceFromFileDescriptorSet(files) -} - -// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files, -// whose contents are Protocol Buffer source files. The given importPaths are used to locate -// any imported files. -func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) { - p := protoparse.Parser{ - ImportPaths: importPaths, - InferImportPaths: len(importPaths) == 0, - } - fds, err := p.ParseFiles(fileNames...) - if err != nil { - return nil, fmt.Errorf("could not parse given files: %v", err) - } - return DescriptorSourceFromFileDescriptors(fds...) -} - -// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet. -func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) { - unresolved := map[string]*descpb.FileDescriptorProto{} - for _, fd := range files.File { - unresolved[fd.GetName()] = fd - } - resolved := map[string]*desc.FileDescriptor{} - for _, fd := range files.File { - _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName()) - if err != nil { - return nil, err - } - } - return &fileSource{files: resolved}, nil -} - -func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) { - if r, ok := resolved[filename]; ok { - return r, nil - } - fd, ok := unresolved[filename] - if !ok { - return nil, fmt.Errorf("no descriptor found for %q", filename) - } - deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency())) - for _, dep := range fd.GetDependency() { - depFd, err := resolveFileDescriptor(unresolved, resolved, dep) - if err != nil { - return nil, err - } - deps = append(deps, depFd) - } - result, err := desc.CreateFileDescriptor(fd, deps...) - if err != nil { - return nil, err - } - resolved[filename] = result - return result, nil -} - -// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the given -// file descriptors -func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) { - fds := map[string]*desc.FileDescriptor{} - for _, fd := range files { - if err := addFile(fd, fds); err != nil { - return nil, err - } - } - return &fileSource{files: fds}, nil -} - -func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error { - name := fd.GetName() - if existing, ok := fds[name]; ok { - // already added this file - if existing != fd { - // doh! duplicate files provided - return fmt.Errorf("given files include multiple copies of %q", name) - } - return nil - } - fds[name] = fd - for _, dep := range fd.GetDependencies() { - if err := addFile(dep, fds); err != nil { - return err - } - } - return nil -} - -type fileSource struct { - files map[string]*desc.FileDescriptor - er *dynamic.ExtensionRegistry - erInit sync.Once -} - -func (fs *fileSource) ListServices() ([]string, error) { - set := map[string]bool{} - for _, fd := range fs.files { - for _, svc := range fd.GetServices() { - set[svc.GetFullyQualifiedName()] = true - } - } - sl := make([]string, 0, len(set)) - for svc := range set { - sl = append(sl, svc) - } - return sl, nil -} - -// GetAllFiles returns all of the underlying file descriptors. This is -// more thorough and more efficient than the fallback strategy used by -// the GetAllFiles package method, for enumerating all files from a -// descriptor source. -func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) { - files := make([]*desc.FileDescriptor, len(fs.files)) - i := 0 - for _, fd := range fs.files { - files[i] = fd - i++ - } - return files, nil -} - -func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { - for _, fd := range fs.files { - if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil { - return dsc, nil - } - } - return nil, notFound("Symbol", fullyQualifiedName) -} - -func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) { - fs.erInit.Do(func() { - fs.er = &dynamic.ExtensionRegistry{} - for _, fd := range fs.files { - fs.er.AddExtensionsFromFile(fd) - } - }) - return fs.er.AllExtensionsForType(typeName), nil -} - -// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client -// to interrogate a server for descriptor information. If the server does not support the reflection -// API then the various DescriptorSource methods will return ErrReflectionNotSupported -func DescriptorSourceFromServer(ctx context.Context, refClient *grpcreflect.Client) DescriptorSource { - return serverSource{client: refClient} -} - -type serverSource struct { - client *grpcreflect.Client -} - -func (ss serverSource) ListServices() ([]string, error) { - svcs, err := ss.client.ListServices() - return svcs, reflectionSupport(err) -} - -func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { - file, err := ss.client.FileContainingSymbol(fullyQualifiedName) - if err != nil { - return nil, reflectionSupport(err) - } - d := file.FindSymbol(fullyQualifiedName) - if d == nil { - return nil, notFound("Symbol", fullyQualifiedName) - } - return d, nil -} - -func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) { - var exts []*desc.FieldDescriptor - nums, err := ss.client.AllExtensionNumbersForType(typeName) - if err != nil { - return nil, reflectionSupport(err) - } - for _, fieldNum := range nums { - ext, err := ss.client.ResolveExtension(typeName, fieldNum) - if err != nil { - return nil, reflectionSupport(err) - } - exts = append(exts, ext) - } - return exts, nil -} - -func reflectionSupport(err error) error { - if err == nil { - return nil - } - if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented { - return ErrReflectionNotSupported - } - return err -} - // ListServices uses the given descriptor source to return a sorted list of fully-qualified // service names. func ListServices(source DescriptorSource) ([]string, error) { @@ -361,365 +121,13 @@ func ListMethods(source DescriptorSource, serviceName string) ([]string, error) } else { methods := make([]string, 0, len(sd.GetMethods())) for _, method := range sd.GetMethods() { - methods = append(methods, method.GetName()) + methods = append(methods, method.GetFullyQualifiedName()) } sort.Strings(methods) return methods, nil } } -type notFoundError string - -func notFound(kind, name string) error { - return notFoundError(fmt.Sprintf("%s not found: %s", kind, name)) -} - -func (e notFoundError) Error() string { - return string(e) -} - -func isNotFoundError(err error) bool { - if grpcreflect.IsElementNotFoundError(err) { - return true - } - _, ok := err.(notFoundError) - return ok -} - -// InvocationEventHandler is a bag of callbacks for handling events that occur in the course -// of invoking an RPC. The handler also provides request data that is sent. The callbacks are -// generally called in the order they are listed below. -type InvocationEventHandler interface { - // OnResolveMethod is called with a descriptor of the method that is being invoked. - OnResolveMethod(*desc.MethodDescriptor) - // OnSendHeaders is called with the request metadata that is being sent. - OnSendHeaders(metadata.MD) - // OnReceiveHeaders is called when response headers have been received. - OnReceiveHeaders(metadata.MD) - // OnReceiveResponse is called for each response message received. - OnReceiveResponse(proto.Message) - // OnReceiveTrailers is called when response trailers and final RPC status have been received. - OnReceiveTrailers(*status.Status, metadata.MD) -} - -// RequestMessageSupplier is a function that is called to retrieve request -// messages for a GRPC operation. This type is deprecated and will be removed in -// a future release. -// -// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use -// RequestSupplier with InvokeRPC. -type RequestMessageSupplier func() ([]byte, error) - -// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated -// and will be removed in a future release. It just delegates to the similarly named InvokeRPC -// method, whose signature is only slightly different. -// -// Deprecated: use InvokeRPC instead. -func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, - headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { - - return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error { - // New function is almost identical, but the request supplier function works differently. - // So we adapt the logic here to maintain compatibility. - data, err := requestData() - if err != nil { - return err - } - return jsonpb.Unmarshal(bytes.NewReader(data), m) - }) -} - -// RequestSupplier is a function that is called to populate messages for a gRPC operation. The -// function should populate the given message or return a non-nil error. If the supplier has no -// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way -// modify the given message argument. -type RequestSupplier func(proto.Message) error - -// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source -// is used to determine the type of method and the type of request and response message. The given -// headers are sent as request metadata. Methods on the given event handler are called as the -// invocation proceeds. -// -// The given requestData function supplies the actual data to send. It should return io.EOF when -// there is no more request data. If the method being invoked is a unary or server-streaming RPC -// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of -// the function returns io.EOF), then an empty request message is sent. -// -// If the requestData function and the given event handler coordinate or share any state, they should -// be thread-safe. This is because the requestData function may be called from a different goroutine -// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where -// one goroutine sends request messages and another consumes the response messages). -func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, - headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { - - md := MetadataFromHeaders(headers) - - svc, mth := parseSymbol(methodName) - if svc == "" || mth == "" { - return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName) - } - dsc, err := source.FindSymbol(svc) - if err != nil { - if isNotFoundError(err) { - return fmt.Errorf("target server does not expose service %q", svc) - } - return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err) - } - sd, ok := dsc.(*desc.ServiceDescriptor) - if !ok { - return fmt.Errorf("target server does not expose service %q", svc) - } - mtd := sd.FindMethodByName(mth) - if mtd == nil { - return fmt.Errorf("service %q does not include a method named %q", svc, mth) - } - - handler.OnResolveMethod(mtd) - - // we also download any applicable extensions so we can provide full support for parsing user-provided data - var ext dynamic.ExtensionRegistry - alreadyFetched := map[string]bool{} - if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil { - return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err) - } - if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil { - return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err) - } - - msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) - req := msgFactory.NewMessage(mtd.GetInputType()) - - handler.OnSendHeaders(md) - ctx = metadata.NewOutgoingContext(ctx, md) - - stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - if mtd.IsClientStreaming() && mtd.IsServerStreaming() { - return invokeBidi(ctx, stub, mtd, handler, requestData, req) - } else if mtd.IsClientStreaming() { - return invokeClientStream(ctx, stub, mtd, handler, requestData, req) - } else if mtd.IsServerStreaming() { - return invokeServerStream(ctx, stub, mtd, handler, requestData, req) - } else { - return invokeUnary(ctx, stub, mtd, handler, requestData, req) - } -} - -func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestSupplier, req proto.Message) error { - - err := requestData(req) - if err != nil && err != io.EOF { - return fmt.Errorf("error getting request data: %v", err) - } - if err != io.EOF { - // verify there is no second message, which is a usage error - err := requestData(req) - if err == nil { - return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) - } else if err != io.EOF { - return fmt.Errorf("error getting request data: %v", err) - } - } - - // Now we can actually invoke the RPC! - var respHeaders metadata.MD - var respTrailers metadata.MD - resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders)) - - stat, ok := status.FromError(err) - if !ok { - // Error codes sent from the server will get printed differently below. - // So just bail for other kinds of errors here. - return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) - } - - handler.OnReceiveHeaders(respHeaders) - - if stat.Code() == codes.OK { - handler.OnReceiveResponse(resp) - } - - handler.OnReceiveTrailers(stat, respTrailers) - - return nil -} - -func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestSupplier, req proto.Message) error { - - // invoke the RPC! - str, err := stub.InvokeRpcClientStream(ctx, md) - - // Upload each request message in the stream - var resp proto.Message - for err == nil { - err = requestData(req) - if err == io.EOF { - resp, err = str.CloseAndReceive() - break - } - if err != nil { - return fmt.Errorf("error getting request data: %v", err) - } - - err = str.SendMsg(req) - if err == io.EOF { - // We get EOF on send if the server says "go away" - // We have to use CloseAndReceive to get the actual code - resp, err = str.CloseAndReceive() - break - } - - req.Reset() - } - - // finally, process response data - stat, ok := status.FromError(err) - if !ok { - // Error codes sent from the server will get printed differently below. - // So just bail for other kinds of errors here. - return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) - } - - if respHeaders, err := str.Header(); err == nil { - handler.OnReceiveHeaders(respHeaders) - } - - if stat.Code() == codes.OK { - handler.OnReceiveResponse(resp) - } - - handler.OnReceiveTrailers(stat, str.Trailer()) - - return nil -} - -func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestSupplier, req proto.Message) error { - - err := requestData(req) - if err != nil && err != io.EOF { - return fmt.Errorf("error getting request data: %v", err) - } - if err != io.EOF { - // verify there is no second message, which is a usage error - err := requestData(req) - if err == nil { - return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) - } else if err != io.EOF { - return fmt.Errorf("error getting request data: %v", err) - } - } - - // Now we can actually invoke the RPC! - str, err := stub.InvokeRpcServerStream(ctx, md, req) - - if respHeaders, err := str.Header(); err == nil { - handler.OnReceiveHeaders(respHeaders) - } - - // Download each response message - for err == nil { - var resp proto.Message - resp, err = str.RecvMsg() - if err != nil { - if err == io.EOF { - err = nil - } - break - } - handler.OnReceiveResponse(resp) - } - - stat, ok := status.FromError(err) - if !ok { - // Error codes sent from the server will get printed differently below. - // So just bail for other kinds of errors here. - return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) - } - - handler.OnReceiveTrailers(stat, str.Trailer()) - - return nil -} - -func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestSupplier, req proto.Message) error { - - // invoke the RPC! - str, err := stub.InvokeRpcBidiStream(ctx, md) - - var wg sync.WaitGroup - var sendErr atomic.Value - - defer wg.Wait() - - if err == nil { - wg.Add(1) - go func() { - defer wg.Done() - - // Concurrently upload each request message in the stream - var err error - for err == nil { - err = requestData(req) - - if err == io.EOF { - err = str.CloseSend() - break - } - if err != nil { - err = fmt.Errorf("error getting request data: %v", err) - break - } - - err = str.SendMsg(req) - - req.Reset() - } - - if err != nil { - sendErr.Store(err) - } - }() - } - - if respHeaders, err := str.Header(); err == nil { - handler.OnReceiveHeaders(respHeaders) - } - - // Download each response message - for err == nil { - var resp proto.Message - resp, err = str.RecvMsg() - if err != nil { - if err == io.EOF { - err = nil - } - break - } - handler.OnReceiveResponse(resp) - } - - if se, ok := sendErr.Load().(error); ok && se != io.EOF { - err = se - } - - stat, ok := status.FromError(err) - if !ok { - // Error codes sent from the server will get printed differently below. - // So just bail for other kinds of errors here. - return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) - } - - handler.OnReceiveTrailers(stat, str.Trailer()) - - return nil -} - // MetadataFromHeaders converts a list of header strings (each string in // "Header-Name: Header-Value" form) into metadata. If a string has a header // name without a value (e.g. does not contain a colon), the value is assumed @@ -767,17 +175,6 @@ func decode(val string) (string, error) { return "", firstErr } -func parseSymbol(svcAndMethod string) (string, string) { - pos := strings.LastIndex(svcAndMethod, "/") - if pos < 0 { - pos = strings.LastIndex(svcAndMethod, ".") - if pos < 0 { - return "", "" - } - } - return svcAndMethod[:pos], svcAndMethod[pos+1:] -} - // MetadataToString returns a string representation of the given metadata, for // displaying to users. func MetadataToString(md metadata.MD) string { diff --git a/grpcurl_test.go b/grpcurl_test.go index 932137f..2c6fab7 100644 --- a/grpcurl_test.go +++ b/grpcurl_test.go @@ -201,12 +201,12 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection t.Fatalf("failed to list methods for TestService: %v", err) } expected := []string{ - "EmptyCall", - "FullDuplexCall", - "HalfDuplexCall", - "StreamingInputCall", - "StreamingOutputCall", - "UnaryCall", + "grpc.testing.TestService.EmptyCall", + "grpc.testing.TestService.FullDuplexCall", + "grpc.testing.TestService.HalfDuplexCall", + "grpc.testing.TestService.StreamingInputCall", + "grpc.testing.TestService.StreamingOutputCall", + "grpc.testing.TestService.UnaryCall", } if !reflect.DeepEqual(expected, names) { t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) @@ -218,7 +218,7 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection if err != nil { t.Fatalf("failed to list methods for ServerReflection: %v", err) } - expected = []string{"ServerReflectionInfo"} + expected = []string{"grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"} } else { // without reflection, we see all services defined in the same test.proto file, which is the // TestService as well as UnimplementedService @@ -226,7 +226,7 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection if err != nil { t.Fatalf("failed to list methods for ServerReflection: %v", err) } - expected = []string{"UnimplementedCall"} + expected = []string{"grpc.testing.UnimplementedService.UnimplementedCall"} } if !reflect.DeepEqual(expected, names) { t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) diff --git a/invoke.go b/invoke.go new file mode 100644 index 0000000..39e8514 --- /dev/null +++ b/invoke.go @@ -0,0 +1,385 @@ +package grpcurl + +import ( + "bytes" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/dynamic/grpcdynamic" + "github.com/jhump/protoreflect/grpcreflect" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// InvocationEventHandler is a bag of callbacks for handling events that occur in the course +// of invoking an RPC. The handler also provides request data that is sent. The callbacks are +// generally called in the order they are listed below. +type InvocationEventHandler interface { + // OnResolveMethod is called with a descriptor of the method that is being invoked. + OnResolveMethod(*desc.MethodDescriptor) + // OnSendHeaders is called with the request metadata that is being sent. + OnSendHeaders(metadata.MD) + // OnReceiveHeaders is called when response headers have been received. + OnReceiveHeaders(metadata.MD) + // OnReceiveResponse is called for each response message received. + OnReceiveResponse(proto.Message) + // OnReceiveTrailers is called when response trailers and final RPC status have been received. + OnReceiveTrailers(*status.Status, metadata.MD) +} + +// RequestMessageSupplier is a function that is called to retrieve request +// messages for a GRPC operation. This type is deprecated and will be removed in +// a future release. +// +// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use +// RequestSupplier with InvokeRPC. +type RequestMessageSupplier func() ([]byte, error) + +// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated +// and will be removed in a future release. It just delegates to the similarly named InvokeRPC +// method, whose signature is only slightly different. +// +// Deprecated: use InvokeRPC instead. +func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, + headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { + + return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error { + // New function is almost identical, but the request supplier function works differently. + // So we adapt the logic here to maintain compatibility. + data, err := requestData() + if err != nil { + return err + } + return jsonpb.Unmarshal(bytes.NewReader(data), m) + }) +} + +// RequestSupplier is a function that is called to populate messages for a gRPC operation. The +// function should populate the given message or return a non-nil error. If the supplier has no +// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way +// modify the given message argument. +type RequestSupplier func(proto.Message) error + +// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source +// is used to determine the type of method and the type of request and response message. The given +// headers are sent as request metadata. Methods on the given event handler are called as the +// invocation proceeds. +// +// The given requestData function supplies the actual data to send. It should return io.EOF when +// there is no more request data. If the method being invoked is a unary or server-streaming RPC +// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of +// the function returns io.EOF), then an empty request message is sent. +// +// If the requestData function and the given event handler coordinate or share any state, they should +// be thread-safe. This is because the requestData function may be called from a different goroutine +// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where +// one goroutine sends request messages and another consumes the response messages). +func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, + headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { + + md := MetadataFromHeaders(headers) + + svc, mth := parseSymbol(methodName) + if svc == "" || mth == "" { + return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName) + } + dsc, err := source.FindSymbol(svc) + if err != nil { + if isNotFoundError(err) { + return fmt.Errorf("target server does not expose service %q", svc) + } + return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err) + } + sd, ok := dsc.(*desc.ServiceDescriptor) + if !ok { + return fmt.Errorf("target server does not expose service %q", svc) + } + mtd := sd.FindMethodByName(mth) + if mtd == nil { + return fmt.Errorf("service %q does not include a method named %q", svc, mth) + } + + handler.OnResolveMethod(mtd) + + // we also download any applicable extensions so we can provide full support for parsing user-provided data + var ext dynamic.ExtensionRegistry + alreadyFetched := map[string]bool{} + if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil { + return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err) + } + if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil { + return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err) + } + + msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) + req := msgFactory.NewMessage(mtd.GetInputType()) + + handler.OnSendHeaders(md) + ctx = metadata.NewOutgoingContext(ctx, md) + + stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if mtd.IsClientStreaming() && mtd.IsServerStreaming() { + return invokeBidi(ctx, stub, mtd, handler, requestData, req) + } else if mtd.IsClientStreaming() { + return invokeClientStream(ctx, stub, mtd, handler, requestData, req) + } else if mtd.IsServerStreaming() { + return invokeServerStream(ctx, stub, mtd, handler, requestData, req) + } else { + return invokeUnary(ctx, stub, mtd, handler, requestData, req) + } +} + +func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, + requestData RequestSupplier, req proto.Message) error { + + err := requestData(req) + if err != nil && err != io.EOF { + return fmt.Errorf("error getting request data: %v", err) + } + if err != io.EOF { + // verify there is no second message, which is a usage error + err := requestData(req) + if err == nil { + return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) + } else if err != io.EOF { + return fmt.Errorf("error getting request data: %v", err) + } + } + + // Now we can actually invoke the RPC! + var respHeaders metadata.MD + var respTrailers metadata.MD + resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders)) + + stat, ok := status.FromError(err) + if !ok { + // Error codes sent from the server will get printed differently below. + // So just bail for other kinds of errors here. + return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) + } + + handler.OnReceiveHeaders(respHeaders) + + if stat.Code() == codes.OK { + handler.OnReceiveResponse(resp) + } + + handler.OnReceiveTrailers(stat, respTrailers) + + return nil +} + +func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, + requestData RequestSupplier, req proto.Message) error { + + // invoke the RPC! + str, err := stub.InvokeRpcClientStream(ctx, md) + + // Upload each request message in the stream + var resp proto.Message + for err == nil { + err = requestData(req) + if err == io.EOF { + resp, err = str.CloseAndReceive() + break + } + if err != nil { + return fmt.Errorf("error getting request data: %v", err) + } + + err = str.SendMsg(req) + if err == io.EOF { + // We get EOF on send if the server says "go away" + // We have to use CloseAndReceive to get the actual code + resp, err = str.CloseAndReceive() + break + } + + req.Reset() + } + + // finally, process response data + stat, ok := status.FromError(err) + if !ok { + // Error codes sent from the server will get printed differently below. + // So just bail for other kinds of errors here. + return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) + } + + if respHeaders, err := str.Header(); err == nil { + handler.OnReceiveHeaders(respHeaders) + } + + if stat.Code() == codes.OK { + handler.OnReceiveResponse(resp) + } + + handler.OnReceiveTrailers(stat, str.Trailer()) + + return nil +} + +func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, + requestData RequestSupplier, req proto.Message) error { + + err := requestData(req) + if err != nil && err != io.EOF { + return fmt.Errorf("error getting request data: %v", err) + } + if err != io.EOF { + // verify there is no second message, which is a usage error + err := requestData(req) + if err == nil { + return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) + } else if err != io.EOF { + return fmt.Errorf("error getting request data: %v", err) + } + } + + // Now we can actually invoke the RPC! + str, err := stub.InvokeRpcServerStream(ctx, md, req) + + if respHeaders, err := str.Header(); err == nil { + handler.OnReceiveHeaders(respHeaders) + } + + // Download each response message + for err == nil { + var resp proto.Message + resp, err = str.RecvMsg() + if err != nil { + if err == io.EOF { + err = nil + } + break + } + handler.OnReceiveResponse(resp) + } + + stat, ok := status.FromError(err) + if !ok { + // Error codes sent from the server will get printed differently below. + // So just bail for other kinds of errors here. + return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) + } + + handler.OnReceiveTrailers(stat, str.Trailer()) + + return nil +} + +func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, + requestData RequestSupplier, req proto.Message) error { + + // invoke the RPC! + str, err := stub.InvokeRpcBidiStream(ctx, md) + + var wg sync.WaitGroup + var sendErr atomic.Value + + defer wg.Wait() + + if err == nil { + wg.Add(1) + go func() { + defer wg.Done() + + // Concurrently upload each request message in the stream + var err error + for err == nil { + err = requestData(req) + + if err == io.EOF { + err = str.CloseSend() + break + } + if err != nil { + err = fmt.Errorf("error getting request data: %v", err) + break + } + + err = str.SendMsg(req) + + req.Reset() + } + + if err != nil { + sendErr.Store(err) + } + }() + } + + if respHeaders, err := str.Header(); err == nil { + handler.OnReceiveHeaders(respHeaders) + } + + // Download each response message + for err == nil { + var resp proto.Message + resp, err = str.RecvMsg() + if err != nil { + if err == io.EOF { + err = nil + } + break + } + handler.OnReceiveResponse(resp) + } + + if se, ok := sendErr.Load().(error); ok && se != io.EOF { + err = se + } + + stat, ok := status.FromError(err) + if !ok { + // Error codes sent from the server will get printed differently below. + // So just bail for other kinds of errors here. + return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) + } + + handler.OnReceiveTrailers(stat, str.Trailer()) + + return nil +} + +type notFoundError string + +func notFound(kind, name string) error { + return notFoundError(fmt.Sprintf("%s not found: %s", kind, name)) +} + +func (e notFoundError) Error() string { + return string(e) +} + +func isNotFoundError(err error) bool { + if grpcreflect.IsElementNotFoundError(err) { + return true + } + _, ok := err.(notFoundError) + return ok +} + +func parseSymbol(svcAndMethod string) (string, string) { + pos := strings.LastIndex(svcAndMethod, "/") + if pos < 0 { + pos = strings.LastIndex(svcAndMethod, ".") + if pos < 0 { + return "", "" + } + } + return svcAndMethod[:pos], svcAndMethod[pos+1:] +}