some pre-factoring and small fixes (#58)
* organize into multiple files * make listing methods show fully-qualified names * address small feedback from recent change (trim then check if empty)
This commit is contained in:
parent
69ea782936
commit
9a4bbacdd6
|
|
@ -566,10 +566,11 @@ func prettify(docString string) string {
|
||||||
// from each line in the doc string
|
// from each line in the doc string
|
||||||
j := 0
|
j := 0
|
||||||
for _, part := range parts {
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
if part == "" {
|
if part == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
parts[j] = strings.TrimSpace(part)
|
parts[j] = part
|
||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,248 @@
|
||||||
|
package grpcurl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
||||||
|
"github.com/jhump/protoreflect/desc"
|
||||||
|
"github.com/jhump/protoreflect/desc/protoparse"
|
||||||
|
"github.com/jhump/protoreflect/dynamic"
|
||||||
|
"github.com/jhump/protoreflect/grpcreflect"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrReflectionNotSupported is returned by DescriptorSource operations that
|
||||||
|
// rely on interacting with the reflection service when the source does not
|
||||||
|
// actually expose the reflection service. When this occurs, an alternate source
|
||||||
|
// (like file descriptor sets) must be used.
|
||||||
|
var ErrReflectionNotSupported = errors.New("server does not support the reflection API")
|
||||||
|
|
||||||
|
// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet
|
||||||
|
// proto (like a file generated by protoc) or a remote server that supports the reflection API.
|
||||||
|
type DescriptorSource interface {
|
||||||
|
// ListServices returns a list of fully-qualified service names. It will be all services in a set of
|
||||||
|
// descriptor files or the set of all services exposed by a gRPC server.
|
||||||
|
ListServices() ([]string, error)
|
||||||
|
// FindSymbol returns a descriptor for the given fully-qualified symbol name.
|
||||||
|
FindSymbol(fullyQualifiedName string) (desc.Descriptor, error)
|
||||||
|
// AllExtensionsForType returns all known extension fields that extend the given message type name.
|
||||||
|
AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents
|
||||||
|
// are encoded FileDescriptorSet protos.
|
||||||
|
func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) {
|
||||||
|
files := &descpb.FileDescriptorSet{}
|
||||||
|
for _, fileName := range fileNames {
|
||||||
|
b, err := ioutil.ReadFile(fileName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err)
|
||||||
|
}
|
||||||
|
var fs descpb.FileDescriptorSet
|
||||||
|
err = proto.Unmarshal(b, &fs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err)
|
||||||
|
}
|
||||||
|
files.File = append(files.File, fs.File...)
|
||||||
|
}
|
||||||
|
return DescriptorSourceFromFileDescriptorSet(files)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files,
|
||||||
|
// whose contents are Protocol Buffer source files. The given importPaths are used to locate
|
||||||
|
// any imported files.
|
||||||
|
func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) {
|
||||||
|
p := protoparse.Parser{
|
||||||
|
ImportPaths: importPaths,
|
||||||
|
InferImportPaths: len(importPaths) == 0,
|
||||||
|
}
|
||||||
|
fds, err := p.ParseFiles(fileNames...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not parse given files: %v", err)
|
||||||
|
}
|
||||||
|
return DescriptorSourceFromFileDescriptors(fds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet.
|
||||||
|
func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) {
|
||||||
|
unresolved := map[string]*descpb.FileDescriptorProto{}
|
||||||
|
for _, fd := range files.File {
|
||||||
|
unresolved[fd.GetName()] = fd
|
||||||
|
}
|
||||||
|
resolved := map[string]*desc.FileDescriptor{}
|
||||||
|
for _, fd := range files.File {
|
||||||
|
_, err := resolveFileDescriptor(unresolved, resolved, fd.GetName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &fileSource{files: resolved}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
|
||||||
|
if r, ok := resolved[filename]; ok {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
fd, ok := unresolved[filename]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no descriptor found for %q", filename)
|
||||||
|
}
|
||||||
|
deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
|
||||||
|
for _, dep := range fd.GetDependency() {
|
||||||
|
depFd, err := resolveFileDescriptor(unresolved, resolved, dep)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
deps = append(deps, depFd)
|
||||||
|
}
|
||||||
|
result, err := desc.CreateFileDescriptor(fd, deps...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resolved[filename] = result
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the given
|
||||||
|
// file descriptors
|
||||||
|
func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) {
|
||||||
|
fds := map[string]*desc.FileDescriptor{}
|
||||||
|
for _, fd := range files {
|
||||||
|
if err := addFile(fd, fds); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &fileSource{files: fds}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error {
|
||||||
|
name := fd.GetName()
|
||||||
|
if existing, ok := fds[name]; ok {
|
||||||
|
// already added this file
|
||||||
|
if existing != fd {
|
||||||
|
// doh! duplicate files provided
|
||||||
|
return fmt.Errorf("given files include multiple copies of %q", name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fds[name] = fd
|
||||||
|
for _, dep := range fd.GetDependencies() {
|
||||||
|
if err := addFile(dep, fds); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileSource struct {
|
||||||
|
files map[string]*desc.FileDescriptor
|
||||||
|
er *dynamic.ExtensionRegistry
|
||||||
|
erInit sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *fileSource) ListServices() ([]string, error) {
|
||||||
|
set := map[string]bool{}
|
||||||
|
for _, fd := range fs.files {
|
||||||
|
for _, svc := range fd.GetServices() {
|
||||||
|
set[svc.GetFullyQualifiedName()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sl := make([]string, 0, len(set))
|
||||||
|
for svc := range set {
|
||||||
|
sl = append(sl, svc)
|
||||||
|
}
|
||||||
|
return sl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllFiles returns all of the underlying file descriptors. This is
|
||||||
|
// more thorough and more efficient than the fallback strategy used by
|
||||||
|
// the GetAllFiles package method, for enumerating all files from a
|
||||||
|
// descriptor source.
|
||||||
|
func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) {
|
||||||
|
files := make([]*desc.FileDescriptor, len(fs.files))
|
||||||
|
i := 0
|
||||||
|
for _, fd := range fs.files {
|
||||||
|
files[i] = fd
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
|
||||||
|
for _, fd := range fs.files {
|
||||||
|
if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil {
|
||||||
|
return dsc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, notFound("Symbol", fullyQualifiedName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
|
||||||
|
fs.erInit.Do(func() {
|
||||||
|
fs.er = &dynamic.ExtensionRegistry{}
|
||||||
|
for _, fd := range fs.files {
|
||||||
|
fs.er.AddExtensionsFromFile(fd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return fs.er.AllExtensionsForType(typeName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client
|
||||||
|
// to interrogate a server for descriptor information. If the server does not support the reflection
|
||||||
|
// API then the various DescriptorSource methods will return ErrReflectionNotSupported
|
||||||
|
func DescriptorSourceFromServer(_ context.Context, refClient *grpcreflect.Client) DescriptorSource {
|
||||||
|
return serverSource{client: refClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverSource struct {
|
||||||
|
client *grpcreflect.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss serverSource) ListServices() ([]string, error) {
|
||||||
|
svcs, err := ss.client.ListServices()
|
||||||
|
return svcs, reflectionSupport(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
|
||||||
|
file, err := ss.client.FileContainingSymbol(fullyQualifiedName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, reflectionSupport(err)
|
||||||
|
}
|
||||||
|
d := file.FindSymbol(fullyQualifiedName)
|
||||||
|
if d == nil {
|
||||||
|
return nil, notFound("Symbol", fullyQualifiedName)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
|
||||||
|
var exts []*desc.FieldDescriptor
|
||||||
|
nums, err := ss.client.AllExtensionNumbersForType(typeName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, reflectionSupport(err)
|
||||||
|
}
|
||||||
|
for _, fieldNum := range nums {
|
||||||
|
ext, err := ss.client.ResolveExtension(typeName, fieldNum)
|
||||||
|
if err != nil {
|
||||||
|
return nil, reflectionSupport(err)
|
||||||
|
}
|
||||||
|
exts = append(exts, ext)
|
||||||
|
}
|
||||||
|
return exts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reflectionSupport(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented {
|
||||||
|
return ErrReflectionNotSupported
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
605
grpcurl.go
605
grpcurl.go
|
|
@ -13,262 +13,22 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/jsonpb"
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
|
||||||
"github.com/jhump/protoreflect/desc"
|
"github.com/jhump/protoreflect/desc"
|
||||||
"github.com/jhump/protoreflect/desc/protoparse"
|
|
||||||
"github.com/jhump/protoreflect/desc/protoprint"
|
"github.com/jhump/protoreflect/desc/protoprint"
|
||||||
"github.com/jhump/protoreflect/dynamic"
|
"github.com/jhump/protoreflect/dynamic"
|
||||||
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
|
|
||||||
"github.com/jhump/protoreflect/grpcreflect"
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrReflectionNotSupported is returned by DescriptorSource operations that
|
|
||||||
// rely on interacting with the reflection service when the source does not
|
|
||||||
// actually expose the reflection service. When this occurs, an alternate source
|
|
||||||
// (like file descriptor sets) must be used.
|
|
||||||
var ErrReflectionNotSupported = errors.New("server does not support the reflection API")
|
|
||||||
|
|
||||||
// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet
|
|
||||||
// proto (like a file generated by protoc) or a remote server that supports the reflection API.
|
|
||||||
type DescriptorSource interface {
|
|
||||||
// ListServices returns a list of fully-qualified service names. It will be all services in a set of
|
|
||||||
// descriptor files or the set of all services exposed by a gRPC server.
|
|
||||||
ListServices() ([]string, error)
|
|
||||||
// FindSymbol returns a descriptor for the given fully-qualified symbol name.
|
|
||||||
FindSymbol(fullyQualifiedName string) (desc.Descriptor, error)
|
|
||||||
// AllExtensionsForType returns all known extension fields that extend the given message type name.
|
|
||||||
AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents
|
|
||||||
// are encoded FileDescriptorSet protos.
|
|
||||||
func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) {
|
|
||||||
files := &descpb.FileDescriptorSet{}
|
|
||||||
for _, fileName := range fileNames {
|
|
||||||
b, err := ioutil.ReadFile(fileName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err)
|
|
||||||
}
|
|
||||||
var fs descpb.FileDescriptorSet
|
|
||||||
err = proto.Unmarshal(b, &fs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err)
|
|
||||||
}
|
|
||||||
files.File = append(files.File, fs.File...)
|
|
||||||
}
|
|
||||||
return DescriptorSourceFromFileDescriptorSet(files)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files,
|
|
||||||
// whose contents are Protocol Buffer source files. The given importPaths are used to locate
|
|
||||||
// any imported files.
|
|
||||||
func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) {
|
|
||||||
p := protoparse.Parser{
|
|
||||||
ImportPaths: importPaths,
|
|
||||||
InferImportPaths: len(importPaths) == 0,
|
|
||||||
}
|
|
||||||
fds, err := p.ParseFiles(fileNames...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not parse given files: %v", err)
|
|
||||||
}
|
|
||||||
return DescriptorSourceFromFileDescriptors(fds...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet.
|
|
||||||
func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) {
|
|
||||||
unresolved := map[string]*descpb.FileDescriptorProto{}
|
|
||||||
for _, fd := range files.File {
|
|
||||||
unresolved[fd.GetName()] = fd
|
|
||||||
}
|
|
||||||
resolved := map[string]*desc.FileDescriptor{}
|
|
||||||
for _, fd := range files.File {
|
|
||||||
_, err := resolveFileDescriptor(unresolved, resolved, fd.GetName())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &fileSource{files: resolved}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
|
|
||||||
if r, ok := resolved[filename]; ok {
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
fd, ok := unresolved[filename]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no descriptor found for %q", filename)
|
|
||||||
}
|
|
||||||
deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
|
|
||||||
for _, dep := range fd.GetDependency() {
|
|
||||||
depFd, err := resolveFileDescriptor(unresolved, resolved, dep)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
deps = append(deps, depFd)
|
|
||||||
}
|
|
||||||
result, err := desc.CreateFileDescriptor(fd, deps...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resolved[filename] = result
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the given
|
|
||||||
// file descriptors
|
|
||||||
func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) {
|
|
||||||
fds := map[string]*desc.FileDescriptor{}
|
|
||||||
for _, fd := range files {
|
|
||||||
if err := addFile(fd, fds); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &fileSource{files: fds}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error {
|
|
||||||
name := fd.GetName()
|
|
||||||
if existing, ok := fds[name]; ok {
|
|
||||||
// already added this file
|
|
||||||
if existing != fd {
|
|
||||||
// doh! duplicate files provided
|
|
||||||
return fmt.Errorf("given files include multiple copies of %q", name)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fds[name] = fd
|
|
||||||
for _, dep := range fd.GetDependencies() {
|
|
||||||
if err := addFile(dep, fds); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type fileSource struct {
|
|
||||||
files map[string]*desc.FileDescriptor
|
|
||||||
er *dynamic.ExtensionRegistry
|
|
||||||
erInit sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *fileSource) ListServices() ([]string, error) {
|
|
||||||
set := map[string]bool{}
|
|
||||||
for _, fd := range fs.files {
|
|
||||||
for _, svc := range fd.GetServices() {
|
|
||||||
set[svc.GetFullyQualifiedName()] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sl := make([]string, 0, len(set))
|
|
||||||
for svc := range set {
|
|
||||||
sl = append(sl, svc)
|
|
||||||
}
|
|
||||||
return sl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllFiles returns all of the underlying file descriptors. This is
|
|
||||||
// more thorough and more efficient than the fallback strategy used by
|
|
||||||
// the GetAllFiles package method, for enumerating all files from a
|
|
||||||
// descriptor source.
|
|
||||||
func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) {
|
|
||||||
files := make([]*desc.FileDescriptor, len(fs.files))
|
|
||||||
i := 0
|
|
||||||
for _, fd := range fs.files {
|
|
||||||
files[i] = fd
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return files, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
|
|
||||||
for _, fd := range fs.files {
|
|
||||||
if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil {
|
|
||||||
return dsc, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, notFound("Symbol", fullyQualifiedName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
|
|
||||||
fs.erInit.Do(func() {
|
|
||||||
fs.er = &dynamic.ExtensionRegistry{}
|
|
||||||
for _, fd := range fs.files {
|
|
||||||
fs.er.AddExtensionsFromFile(fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return fs.er.AllExtensionsForType(typeName), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client
|
|
||||||
// to interrogate a server for descriptor information. If the server does not support the reflection
|
|
||||||
// API then the various DescriptorSource methods will return ErrReflectionNotSupported
|
|
||||||
func DescriptorSourceFromServer(ctx context.Context, refClient *grpcreflect.Client) DescriptorSource {
|
|
||||||
return serverSource{client: refClient}
|
|
||||||
}
|
|
||||||
|
|
||||||
type serverSource struct {
|
|
||||||
client *grpcreflect.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss serverSource) ListServices() ([]string, error) {
|
|
||||||
svcs, err := ss.client.ListServices()
|
|
||||||
return svcs, reflectionSupport(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
|
|
||||||
file, err := ss.client.FileContainingSymbol(fullyQualifiedName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, reflectionSupport(err)
|
|
||||||
}
|
|
||||||
d := file.FindSymbol(fullyQualifiedName)
|
|
||||||
if d == nil {
|
|
||||||
return nil, notFound("Symbol", fullyQualifiedName)
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
|
|
||||||
var exts []*desc.FieldDescriptor
|
|
||||||
nums, err := ss.client.AllExtensionNumbersForType(typeName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, reflectionSupport(err)
|
|
||||||
}
|
|
||||||
for _, fieldNum := range nums {
|
|
||||||
ext, err := ss.client.ResolveExtension(typeName, fieldNum)
|
|
||||||
if err != nil {
|
|
||||||
return nil, reflectionSupport(err)
|
|
||||||
}
|
|
||||||
exts = append(exts, ext)
|
|
||||||
}
|
|
||||||
return exts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func reflectionSupport(err error) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented {
|
|
||||||
return ErrReflectionNotSupported
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListServices uses the given descriptor source to return a sorted list of fully-qualified
|
// ListServices uses the given descriptor source to return a sorted list of fully-qualified
|
||||||
// service names.
|
// service names.
|
||||||
func ListServices(source DescriptorSource) ([]string, error) {
|
func ListServices(source DescriptorSource) ([]string, error) {
|
||||||
|
|
@ -361,365 +121,13 @@ func ListMethods(source DescriptorSource, serviceName string) ([]string, error)
|
||||||
} else {
|
} else {
|
||||||
methods := make([]string, 0, len(sd.GetMethods()))
|
methods := make([]string, 0, len(sd.GetMethods()))
|
||||||
for _, method := range sd.GetMethods() {
|
for _, method := range sd.GetMethods() {
|
||||||
methods = append(methods, method.GetName())
|
methods = append(methods, method.GetFullyQualifiedName())
|
||||||
}
|
}
|
||||||
sort.Strings(methods)
|
sort.Strings(methods)
|
||||||
return methods, nil
|
return methods, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type notFoundError string
|
|
||||||
|
|
||||||
func notFound(kind, name string) error {
|
|
||||||
return notFoundError(fmt.Sprintf("%s not found: %s", kind, name))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e notFoundError) Error() string {
|
|
||||||
return string(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNotFoundError(err error) bool {
|
|
||||||
if grpcreflect.IsElementNotFoundError(err) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
_, ok := err.(notFoundError)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvocationEventHandler is a bag of callbacks for handling events that occur in the course
|
|
||||||
// of invoking an RPC. The handler also provides request data that is sent. The callbacks are
|
|
||||||
// generally called in the order they are listed below.
|
|
||||||
type InvocationEventHandler interface {
|
|
||||||
// OnResolveMethod is called with a descriptor of the method that is being invoked.
|
|
||||||
OnResolveMethod(*desc.MethodDescriptor)
|
|
||||||
// OnSendHeaders is called with the request metadata that is being sent.
|
|
||||||
OnSendHeaders(metadata.MD)
|
|
||||||
// OnReceiveHeaders is called when response headers have been received.
|
|
||||||
OnReceiveHeaders(metadata.MD)
|
|
||||||
// OnReceiveResponse is called for each response message received.
|
|
||||||
OnReceiveResponse(proto.Message)
|
|
||||||
// OnReceiveTrailers is called when response trailers and final RPC status have been received.
|
|
||||||
OnReceiveTrailers(*status.Status, metadata.MD)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestMessageSupplier is a function that is called to retrieve request
|
|
||||||
// messages for a GRPC operation. This type is deprecated and will be removed in
|
|
||||||
// a future release.
|
|
||||||
//
|
|
||||||
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
|
|
||||||
// RequestSupplier with InvokeRPC.
|
|
||||||
type RequestMessageSupplier func() ([]byte, error)
|
|
||||||
|
|
||||||
// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated
|
|
||||||
// and will be removed in a future release. It just delegates to the similarly named InvokeRPC
|
|
||||||
// method, whose signature is only slightly different.
|
|
||||||
//
|
|
||||||
// Deprecated: use InvokeRPC instead.
|
|
||||||
func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string,
|
|
||||||
headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error {
|
|
||||||
|
|
||||||
return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error {
|
|
||||||
// New function is almost identical, but the request supplier function works differently.
|
|
||||||
// So we adapt the logic here to maintain compatibility.
|
|
||||||
data, err := requestData()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return jsonpb.Unmarshal(bytes.NewReader(data), m)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestSupplier is a function that is called to populate messages for a gRPC operation. The
|
|
||||||
// function should populate the given message or return a non-nil error. If the supplier has no
|
|
||||||
// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way
|
|
||||||
// modify the given message argument.
|
|
||||||
type RequestSupplier func(proto.Message) error
|
|
||||||
|
|
||||||
// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source
|
|
||||||
// is used to determine the type of method and the type of request and response message. The given
|
|
||||||
// headers are sent as request metadata. Methods on the given event handler are called as the
|
|
||||||
// invocation proceeds.
|
|
||||||
//
|
|
||||||
// The given requestData function supplies the actual data to send. It should return io.EOF when
|
|
||||||
// there is no more request data. If the method being invoked is a unary or server-streaming RPC
|
|
||||||
// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of
|
|
||||||
// the function returns io.EOF), then an empty request message is sent.
|
|
||||||
//
|
|
||||||
// If the requestData function and the given event handler coordinate or share any state, they should
|
|
||||||
// be thread-safe. This is because the requestData function may be called from a different goroutine
|
|
||||||
// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where
|
|
||||||
// one goroutine sends request messages and another consumes the response messages).
|
|
||||||
func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string,
|
|
||||||
headers []string, handler InvocationEventHandler, requestData RequestSupplier) error {
|
|
||||||
|
|
||||||
md := MetadataFromHeaders(headers)
|
|
||||||
|
|
||||||
svc, mth := parseSymbol(methodName)
|
|
||||||
if svc == "" || mth == "" {
|
|
||||||
return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName)
|
|
||||||
}
|
|
||||||
dsc, err := source.FindSymbol(svc)
|
|
||||||
if err != nil {
|
|
||||||
if isNotFoundError(err) {
|
|
||||||
return fmt.Errorf("target server does not expose service %q", svc)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err)
|
|
||||||
}
|
|
||||||
sd, ok := dsc.(*desc.ServiceDescriptor)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("target server does not expose service %q", svc)
|
|
||||||
}
|
|
||||||
mtd := sd.FindMethodByName(mth)
|
|
||||||
if mtd == nil {
|
|
||||||
return fmt.Errorf("service %q does not include a method named %q", svc, mth)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnResolveMethod(mtd)
|
|
||||||
|
|
||||||
// we also download any applicable extensions so we can provide full support for parsing user-provided data
|
|
||||||
var ext dynamic.ExtensionRegistry
|
|
||||||
alreadyFetched := map[string]bool{}
|
|
||||||
if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil {
|
|
||||||
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil {
|
|
||||||
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
|
|
||||||
req := msgFactory.NewMessage(mtd.GetInputType())
|
|
||||||
|
|
||||||
handler.OnSendHeaders(md)
|
|
||||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
|
||||||
|
|
||||||
stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory)
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
|
|
||||||
return invokeBidi(ctx, stub, mtd, handler, requestData, req)
|
|
||||||
} else if mtd.IsClientStreaming() {
|
|
||||||
return invokeClientStream(ctx, stub, mtd, handler, requestData, req)
|
|
||||||
} else if mtd.IsServerStreaming() {
|
|
||||||
return invokeServerStream(ctx, stub, mtd, handler, requestData, req)
|
|
||||||
} else {
|
|
||||||
return invokeUnary(ctx, stub, mtd, handler, requestData, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
|
||||||
requestData RequestSupplier, req proto.Message) error {
|
|
||||||
|
|
||||||
err := requestData(req)
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
return fmt.Errorf("error getting request data: %v", err)
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
// verify there is no second message, which is a usage error
|
|
||||||
err := requestData(req)
|
|
||||||
if err == nil {
|
|
||||||
return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
|
||||||
} else if err != io.EOF {
|
|
||||||
return fmt.Errorf("error getting request data: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now we can actually invoke the RPC!
|
|
||||||
var respHeaders metadata.MD
|
|
||||||
var respTrailers metadata.MD
|
|
||||||
resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders))
|
|
||||||
|
|
||||||
stat, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
// Error codes sent from the server will get printed differently below.
|
|
||||||
// So just bail for other kinds of errors here.
|
|
||||||
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnReceiveHeaders(respHeaders)
|
|
||||||
|
|
||||||
if stat.Code() == codes.OK {
|
|
||||||
handler.OnReceiveResponse(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnReceiveTrailers(stat, respTrailers)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
|
||||||
requestData RequestSupplier, req proto.Message) error {
|
|
||||||
|
|
||||||
// invoke the RPC!
|
|
||||||
str, err := stub.InvokeRpcClientStream(ctx, md)
|
|
||||||
|
|
||||||
// Upload each request message in the stream
|
|
||||||
var resp proto.Message
|
|
||||||
for err == nil {
|
|
||||||
err = requestData(req)
|
|
||||||
if err == io.EOF {
|
|
||||||
resp, err = str.CloseAndReceive()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error getting request data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = str.SendMsg(req)
|
|
||||||
if err == io.EOF {
|
|
||||||
// We get EOF on send if the server says "go away"
|
|
||||||
// We have to use CloseAndReceive to get the actual code
|
|
||||||
resp, err = str.CloseAndReceive()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
// finally, process response data
|
|
||||||
stat, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
// Error codes sent from the server will get printed differently below.
|
|
||||||
// So just bail for other kinds of errors here.
|
|
||||||
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if respHeaders, err := str.Header(); err == nil {
|
|
||||||
handler.OnReceiveHeaders(respHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
if stat.Code() == codes.OK {
|
|
||||||
handler.OnReceiveResponse(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnReceiveTrailers(stat, str.Trailer())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
|
||||||
requestData RequestSupplier, req proto.Message) error {
|
|
||||||
|
|
||||||
err := requestData(req)
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
return fmt.Errorf("error getting request data: %v", err)
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
// verify there is no second message, which is a usage error
|
|
||||||
err := requestData(req)
|
|
||||||
if err == nil {
|
|
||||||
return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
|
||||||
} else if err != io.EOF {
|
|
||||||
return fmt.Errorf("error getting request data: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now we can actually invoke the RPC!
|
|
||||||
str, err := stub.InvokeRpcServerStream(ctx, md, req)
|
|
||||||
|
|
||||||
if respHeaders, err := str.Header(); err == nil {
|
|
||||||
handler.OnReceiveHeaders(respHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Download each response message
|
|
||||||
for err == nil {
|
|
||||||
var resp proto.Message
|
|
||||||
resp, err = str.RecvMsg()
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
handler.OnReceiveResponse(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
stat, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
// Error codes sent from the server will get printed differently below.
|
|
||||||
// So just bail for other kinds of errors here.
|
|
||||||
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnReceiveTrailers(stat, str.Trailer())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
|
||||||
requestData RequestSupplier, req proto.Message) error {
|
|
||||||
|
|
||||||
// invoke the RPC!
|
|
||||||
str, err := stub.InvokeRpcBidiStream(ctx, md)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var sendErr atomic.Value
|
|
||||||
|
|
||||||
defer wg.Wait()
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
// Concurrently upload each request message in the stream
|
|
||||||
var err error
|
|
||||||
for err == nil {
|
|
||||||
err = requestData(req)
|
|
||||||
|
|
||||||
if err == io.EOF {
|
|
||||||
err = str.CloseSend()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("error getting request data: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
err = str.SendMsg(req)
|
|
||||||
|
|
||||||
req.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
sendErr.Store(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
if respHeaders, err := str.Header(); err == nil {
|
|
||||||
handler.OnReceiveHeaders(respHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Download each response message
|
|
||||||
for err == nil {
|
|
||||||
var resp proto.Message
|
|
||||||
resp, err = str.RecvMsg()
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
handler.OnReceiveResponse(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if se, ok := sendErr.Load().(error); ok && se != io.EOF {
|
|
||||||
err = se
|
|
||||||
}
|
|
||||||
|
|
||||||
stat, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
// Error codes sent from the server will get printed differently below.
|
|
||||||
// So just bail for other kinds of errors here.
|
|
||||||
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.OnReceiveTrailers(stat, str.Trailer())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetadataFromHeaders converts a list of header strings (each string in
|
// MetadataFromHeaders converts a list of header strings (each string in
|
||||||
// "Header-Name: Header-Value" form) into metadata. If a string has a header
|
// "Header-Name: Header-Value" form) into metadata. If a string has a header
|
||||||
// name without a value (e.g. does not contain a colon), the value is assumed
|
// name without a value (e.g. does not contain a colon), the value is assumed
|
||||||
|
|
@ -767,17 +175,6 @@ func decode(val string) (string, error) {
|
||||||
return "", firstErr
|
return "", firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSymbol(svcAndMethod string) (string, string) {
|
|
||||||
pos := strings.LastIndex(svcAndMethod, "/")
|
|
||||||
if pos < 0 {
|
|
||||||
pos = strings.LastIndex(svcAndMethod, ".")
|
|
||||||
if pos < 0 {
|
|
||||||
return "", ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return svcAndMethod[:pos], svcAndMethod[pos+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetadataToString returns a string representation of the given metadata, for
|
// MetadataToString returns a string representation of the given metadata, for
|
||||||
// displaying to users.
|
// displaying to users.
|
||||||
func MetadataToString(md metadata.MD) string {
|
func MetadataToString(md metadata.MD) string {
|
||||||
|
|
|
||||||
|
|
@ -201,12 +201,12 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection
|
||||||
t.Fatalf("failed to list methods for TestService: %v", err)
|
t.Fatalf("failed to list methods for TestService: %v", err)
|
||||||
}
|
}
|
||||||
expected := []string{
|
expected := []string{
|
||||||
"EmptyCall",
|
"grpc.testing.TestService.EmptyCall",
|
||||||
"FullDuplexCall",
|
"grpc.testing.TestService.FullDuplexCall",
|
||||||
"HalfDuplexCall",
|
"grpc.testing.TestService.HalfDuplexCall",
|
||||||
"StreamingInputCall",
|
"grpc.testing.TestService.StreamingInputCall",
|
||||||
"StreamingOutputCall",
|
"grpc.testing.TestService.StreamingOutputCall",
|
||||||
"UnaryCall",
|
"grpc.testing.TestService.UnaryCall",
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(expected, names) {
|
if !reflect.DeepEqual(expected, names) {
|
||||||
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
||||||
|
|
@ -218,7 +218,7 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
||||||
}
|
}
|
||||||
expected = []string{"ServerReflectionInfo"}
|
expected = []string{"grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"}
|
||||||
} else {
|
} else {
|
||||||
// without reflection, we see all services defined in the same test.proto file, which is the
|
// without reflection, we see all services defined in the same test.proto file, which is the
|
||||||
// TestService as well as UnimplementedService
|
// TestService as well as UnimplementedService
|
||||||
|
|
@ -226,7 +226,7 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
||||||
}
|
}
|
||||||
expected = []string{"UnimplementedCall"}
|
expected = []string{"grpc.testing.UnimplementedService.UnimplementedCall"}
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(expected, names) {
|
if !reflect.DeepEqual(expected, names) {
|
||||||
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,385 @@
|
||||||
|
package grpcurl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/jsonpb"
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/jhump/protoreflect/desc"
|
||||||
|
"github.com/jhump/protoreflect/dynamic"
|
||||||
|
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
|
||||||
|
"github.com/jhump/protoreflect/grpcreflect"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InvocationEventHandler is a bag of callbacks for handling events that occur in the course
|
||||||
|
// of invoking an RPC. The handler also provides request data that is sent. The callbacks are
|
||||||
|
// generally called in the order they are listed below.
|
||||||
|
type InvocationEventHandler interface {
|
||||||
|
// OnResolveMethod is called with a descriptor of the method that is being invoked.
|
||||||
|
OnResolveMethod(*desc.MethodDescriptor)
|
||||||
|
// OnSendHeaders is called with the request metadata that is being sent.
|
||||||
|
OnSendHeaders(metadata.MD)
|
||||||
|
// OnReceiveHeaders is called when response headers have been received.
|
||||||
|
OnReceiveHeaders(metadata.MD)
|
||||||
|
// OnReceiveResponse is called for each response message received.
|
||||||
|
OnReceiveResponse(proto.Message)
|
||||||
|
// OnReceiveTrailers is called when response trailers and final RPC status have been received.
|
||||||
|
OnReceiveTrailers(*status.Status, metadata.MD)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestMessageSupplier is a function that is called to retrieve request
|
||||||
|
// messages for a GRPC operation. This type is deprecated and will be removed in
|
||||||
|
// a future release.
|
||||||
|
//
|
||||||
|
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
|
||||||
|
// RequestSupplier with InvokeRPC.
|
||||||
|
type RequestMessageSupplier func() ([]byte, error)
|
||||||
|
|
||||||
|
// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated
|
||||||
|
// and will be removed in a future release. It just delegates to the similarly named InvokeRPC
|
||||||
|
// method, whose signature is only slightly different.
|
||||||
|
//
|
||||||
|
// Deprecated: use InvokeRPC instead.
|
||||||
|
func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string,
|
||||||
|
headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error {
|
||||||
|
|
||||||
|
return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error {
|
||||||
|
// New function is almost identical, but the request supplier function works differently.
|
||||||
|
// So we adapt the logic here to maintain compatibility.
|
||||||
|
data, err := requestData()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return jsonpb.Unmarshal(bytes.NewReader(data), m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestSupplier is a function that is called to populate messages for a gRPC operation. The
|
||||||
|
// function should populate the given message or return a non-nil error. If the supplier has no
|
||||||
|
// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way
|
||||||
|
// modify the given message argument.
|
||||||
|
type RequestSupplier func(proto.Message) error
|
||||||
|
|
||||||
|
// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source
|
||||||
|
// is used to determine the type of method and the type of request and response message. The given
|
||||||
|
// headers are sent as request metadata. Methods on the given event handler are called as the
|
||||||
|
// invocation proceeds.
|
||||||
|
//
|
||||||
|
// The given requestData function supplies the actual data to send. It should return io.EOF when
|
||||||
|
// there is no more request data. If the method being invoked is a unary or server-streaming RPC
|
||||||
|
// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of
|
||||||
|
// the function returns io.EOF), then an empty request message is sent.
|
||||||
|
//
|
||||||
|
// If the requestData function and the given event handler coordinate or share any state, they should
|
||||||
|
// be thread-safe. This is because the requestData function may be called from a different goroutine
|
||||||
|
// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where
|
||||||
|
// one goroutine sends request messages and another consumes the response messages).
|
||||||
|
func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string,
|
||||||
|
headers []string, handler InvocationEventHandler, requestData RequestSupplier) error {
|
||||||
|
|
||||||
|
md := MetadataFromHeaders(headers)
|
||||||
|
|
||||||
|
svc, mth := parseSymbol(methodName)
|
||||||
|
if svc == "" || mth == "" {
|
||||||
|
return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName)
|
||||||
|
}
|
||||||
|
dsc, err := source.FindSymbol(svc)
|
||||||
|
if err != nil {
|
||||||
|
if isNotFoundError(err) {
|
||||||
|
return fmt.Errorf("target server does not expose service %q", svc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err)
|
||||||
|
}
|
||||||
|
sd, ok := dsc.(*desc.ServiceDescriptor)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("target server does not expose service %q", svc)
|
||||||
|
}
|
||||||
|
mtd := sd.FindMethodByName(mth)
|
||||||
|
if mtd == nil {
|
||||||
|
return fmt.Errorf("service %q does not include a method named %q", svc, mth)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnResolveMethod(mtd)
|
||||||
|
|
||||||
|
// we also download any applicable extensions so we can provide full support for parsing user-provided data
|
||||||
|
var ext dynamic.ExtensionRegistry
|
||||||
|
alreadyFetched := map[string]bool{}
|
||||||
|
if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil {
|
||||||
|
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil {
|
||||||
|
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
|
||||||
|
req := msgFactory.NewMessage(mtd.GetInputType())
|
||||||
|
|
||||||
|
handler.OnSendHeaders(md)
|
||||||
|
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||||
|
|
||||||
|
stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
|
||||||
|
return invokeBidi(ctx, stub, mtd, handler, requestData, req)
|
||||||
|
} else if mtd.IsClientStreaming() {
|
||||||
|
return invokeClientStream(ctx, stub, mtd, handler, requestData, req)
|
||||||
|
} else if mtd.IsServerStreaming() {
|
||||||
|
return invokeServerStream(ctx, stub, mtd, handler, requestData, req)
|
||||||
|
} else {
|
||||||
|
return invokeUnary(ctx, stub, mtd, handler, requestData, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||||||
|
requestData RequestSupplier, req proto.Message) error {
|
||||||
|
|
||||||
|
err := requestData(req)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return fmt.Errorf("error getting request data: %v", err)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
// verify there is no second message, which is a usage error
|
||||||
|
err := requestData(req)
|
||||||
|
if err == nil {
|
||||||
|
return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
||||||
|
} else if err != io.EOF {
|
||||||
|
return fmt.Errorf("error getting request data: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we can actually invoke the RPC!
|
||||||
|
var respHeaders metadata.MD
|
||||||
|
var respTrailers metadata.MD
|
||||||
|
resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders))
|
||||||
|
|
||||||
|
stat, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
// Error codes sent from the server will get printed differently below.
|
||||||
|
// So just bail for other kinds of errors here.
|
||||||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnReceiveHeaders(respHeaders)
|
||||||
|
|
||||||
|
if stat.Code() == codes.OK {
|
||||||
|
handler.OnReceiveResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnReceiveTrailers(stat, respTrailers)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||||||
|
requestData RequestSupplier, req proto.Message) error {
|
||||||
|
|
||||||
|
// invoke the RPC!
|
||||||
|
str, err := stub.InvokeRpcClientStream(ctx, md)
|
||||||
|
|
||||||
|
// Upload each request message in the stream
|
||||||
|
var resp proto.Message
|
||||||
|
for err == nil {
|
||||||
|
err = requestData(req)
|
||||||
|
if err == io.EOF {
|
||||||
|
resp, err = str.CloseAndReceive()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting request data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = str.SendMsg(req)
|
||||||
|
if err == io.EOF {
|
||||||
|
// We get EOF on send if the server says "go away"
|
||||||
|
// We have to use CloseAndReceive to get the actual code
|
||||||
|
resp, err = str.CloseAndReceive()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally, process response data
|
||||||
|
stat, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
// Error codes sent from the server will get printed differently below.
|
||||||
|
// So just bail for other kinds of errors here.
|
||||||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if respHeaders, err := str.Header(); err == nil {
|
||||||
|
handler.OnReceiveHeaders(respHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat.Code() == codes.OK {
|
||||||
|
handler.OnReceiveResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||||||
|
requestData RequestSupplier, req proto.Message) error {
|
||||||
|
|
||||||
|
err := requestData(req)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return fmt.Errorf("error getting request data: %v", err)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
// verify there is no second message, which is a usage error
|
||||||
|
err := requestData(req)
|
||||||
|
if err == nil {
|
||||||
|
return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
||||||
|
} else if err != io.EOF {
|
||||||
|
return fmt.Errorf("error getting request data: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we can actually invoke the RPC!
|
||||||
|
str, err := stub.InvokeRpcServerStream(ctx, md, req)
|
||||||
|
|
||||||
|
if respHeaders, err := str.Header(); err == nil {
|
||||||
|
handler.OnReceiveHeaders(respHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download each response message
|
||||||
|
for err == nil {
|
||||||
|
var resp proto.Message
|
||||||
|
resp, err = str.RecvMsg()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
handler.OnReceiveResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
// Error codes sent from the server will get printed differently below.
|
||||||
|
// So just bail for other kinds of errors here.
|
||||||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||||||
|
requestData RequestSupplier, req proto.Message) error {
|
||||||
|
|
||||||
|
// invoke the RPC!
|
||||||
|
str, err := stub.InvokeRpcBidiStream(ctx, md)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var sendErr atomic.Value
|
||||||
|
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Concurrently upload each request message in the stream
|
||||||
|
var err error
|
||||||
|
for err == nil {
|
||||||
|
err = requestData(req)
|
||||||
|
|
||||||
|
if err == io.EOF {
|
||||||
|
err = str.CloseSend()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error getting request data: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = str.SendMsg(req)
|
||||||
|
|
||||||
|
req.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
sendErr.Store(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if respHeaders, err := str.Header(); err == nil {
|
||||||
|
handler.OnReceiveHeaders(respHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download each response message
|
||||||
|
for err == nil {
|
||||||
|
var resp proto.Message
|
||||||
|
resp, err = str.RecvMsg()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
handler.OnReceiveResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if se, ok := sendErr.Load().(error); ok && se != io.EOF {
|
||||||
|
err = se
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
// Error codes sent from the server will get printed differently below.
|
||||||
|
// So just bail for other kinds of errors here.
|
||||||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type notFoundError string
|
||||||
|
|
||||||
|
func notFound(kind, name string) error {
|
||||||
|
return notFoundError(fmt.Sprintf("%s not found: %s", kind, name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e notFoundError) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNotFoundError(err error) bool {
|
||||||
|
if grpcreflect.IsElementNotFoundError(err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := err.(notFoundError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSymbol(svcAndMethod string) (string, string) {
|
||||||
|
pos := strings.LastIndex(svcAndMethod, "/")
|
||||||
|
if pos < 0 {
|
||||||
|
pos = strings.LastIndex(svcAndMethod, ".")
|
||||||
|
if pos < 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return svcAndMethod[:pos], svcAndMethod[pos+1:]
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue