add option to support text format (#54)

* augments grpcurl package API in order to handle multiple formats
* deprecates old signature for InvokeRpc
* add command-line flag to use protobuf text format instead of JSON
* use AnyResolver when marshaling to/from JSON
This commit is contained in:
Joshua Humphries 2018-10-16 21:26:16 -04:00 committed by GitHub
parent 397a8c18ca
commit e00ef3eb7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 752 additions and 104 deletions

View File

@ -0,0 +1,303 @@
package main
import (
"bytes"
"fmt"
"io"
"strings"
"testing"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/struct"
"github.com/jhump/protoreflect/desc"
"google.golang.org/grpc/metadata"
"github.com/fullstorydev/grpcurl"
)
func TestRequestFactory(t *testing.T) {
source, err := grpcurl.DescriptorSourceFromProtoSets("../../testing/example.protoset")
if err != nil {
t.Fatalf("failed to create descriptor source: %v", err)
}
msg, err := makeProto()
if err != nil {
t.Fatalf("failed to create message: %v", err)
}
testCases := []struct {
format string
input string
expectedOutput []proto.Message
}{
{
format: "json",
input: "",
},
{
format: "json",
input: messageAsJSON,
expectedOutput: []proto.Message{msg},
},
{
format: "json",
input: messageAsJSON + messageAsJSON + messageAsJSON,
expectedOutput: []proto.Message{msg, msg, msg},
},
{
// unlike JSON, empty input yields one empty message (vs. zero messages)
format: "text",
input: "",
expectedOutput: []proto.Message{&structpb.Value{}},
},
{
format: "text",
input: messageAsText,
expectedOutput: []proto.Message{msg},
},
{
format: "text",
input: messageAsText + string(textSeparatorChar),
expectedOutput: []proto.Message{msg, &structpb.Value{}},
},
{
format: "text",
input: messageAsText + string(textSeparatorChar) + messageAsText + string(textSeparatorChar) + messageAsText,
expectedOutput: []proto.Message{msg, msg, msg},
},
}
for i, tc := range testCases {
name := fmt.Sprintf("#%d, %s, %d message(s)", i+1, tc.format, len(tc.expectedOutput))
rf, _ := formatDetails(tc.format, source, false, strings.NewReader(tc.input))
numReqs := 0
for {
var req structpb.Value
err := rf.next(&req)
if err == io.EOF {
break
} else if err != nil {
t.Errorf("%s, msg %d: unexpected error: %v", name, numReqs, err)
}
if !proto.Equal(&req, tc.expectedOutput[numReqs]) {
t.Errorf("%s, msg %d: incorrect message;\nexpecting:\n%v\ngot:\n%v", name, numReqs, tc.expectedOutput[numReqs], &req)
}
numReqs++
}
if rf.numRequests() != numReqs {
t.Errorf("%s: factory reported wrong number of requests: expecting %d, got %d", name, numReqs, rf.numRequests())
}
}
}
// Handler prints response data (and headers/trailers in verbose mode).
// This verifies that we get the right output in both JSON and proto text modes.
func TestHandler(t *testing.T) {
source, err := grpcurl.DescriptorSourceFromProtoSets("../../testing/example.protoset")
if err != nil {
t.Fatalf("failed to create descriptor source: %v", err)
}
d, err := source.FindSymbol("TestService.GetFiles")
if err != nil {
t.Fatalf("failed to find method 'TestService.GetFiles': %v", err)
}
md, ok := d.(*desc.MethodDescriptor)
if !ok {
t.Fatalf("wrong kind of descriptor found: %T", d)
}
reqHeaders := metadata.Pairs("foo", "123", "bar", "456")
respHeaders := metadata.Pairs("foo", "abc", "bar", "def", "baz", "xyz")
respTrailers := metadata.Pairs("a", "1", "b", "2", "c", "3")
rsp, err := makeProto()
if err != nil {
t.Fatalf("failed to create response message: %v", err)
}
for _, format := range []string{"json", "text"} {
for _, numMessages := range []int{1, 3} {
for _, verbose := range []bool{true, false} {
name := fmt.Sprintf("%s, %d message(s)", format, numMessages)
if verbose {
name += ", verbose"
}
_, formatter := formatDetails(format, source, verbose, nil)
var buf bytes.Buffer
h := handler{
out: &buf,
descSource: source,
verbose: verbose,
formatter: formatter,
}
h.OnResolveMethod(md)
h.OnSendHeaders(reqHeaders)
h.OnReceiveHeaders(respHeaders)
for i := 0; i < numMessages; i++ {
h.OnReceiveResponse(rsp)
}
h.OnReceiveTrailers(nil, respTrailers)
expectedOutput := ""
if verbose {
expectedOutput += verbosePrefix
}
for i := 0; i < numMessages; i++ {
if verbose {
expectedOutput += verboseResponseHeader
}
if format == "json" {
expectedOutput += messageAsJSON
} else {
if i > 0 && !verbose {
expectedOutput += string(textSeparatorChar)
}
expectedOutput += messageAsText
}
}
if verbose {
expectedOutput += verboseSuffix
}
out := buf.String()
if !compare(out, expectedOutput) {
t.Errorf("%s: Incorrect output.", name) // Expected:\n%s\nGot:\n%s", name, expectedOutput, out)
}
}
}
}
}
// compare checks that actual and expected are equal, returning true if so.
// A simple equality check (==) does not suffice because jsonpb formats
// structpb.Value strangely. So if that formatting gets fixed, we don't
// want this test in grpcurl to suddenly start failing. So we check each
// line and compare the lines after stripping whitespace (which removes
// the jsonpb format anomalies).
func compare(actual, expected string) bool {
actualLines := strings.Split(actual, "\n")
expectedLines := strings.Split(expected, "\n")
if len(actualLines) != len(expectedLines) {
return false
}
for i := 0; i < len(actualLines); i++ {
if strings.TrimSpace(actualLines[i]) != strings.TrimSpace(expectedLines[i]) {
return false
}
}
return true
}
func makeProto() (proto.Message, error) {
var rsp structpb.Value
err := jsonpb.UnmarshalString(`{
"foo": ["abc", "def", "ghi"],
"bar": { "a": 1, "b": 2 },
"baz": true,
"null": null
}`, &rsp)
if err != nil {
return nil, err
}
return &rsp, nil
}
var (
verbosePrefix = `
Resolved method descriptor:
{
"name": "GetFiles",
"inputType": ".TestRequest",
"outputType": ".TestResponse",
"options": {
}
}
Request metadata to send:
bar: 456
foo: 123
Response headers received:
bar: def
baz: xyz
foo: abc
`
verboseSuffix = `
Response trailers received:
a: 1
b: 2
c: 3
`
verboseResponseHeader = `
Response contents:
`
messageAsJSON = `{
"bar": {
"a": 1,
"b": 2
},
"baz": true,
"foo": [
"abc",
"def",
"ghi"
],
"null": null
}
`
messageAsText = `struct_value: <
fields: <
key: "bar"
value: <
struct_value: <
fields: <
key: "a"
value: <
number_value: 1
>
>
fields: <
key: "b"
value: <
number_value: 2
>
>
>
>
>
fields: <
key: "baz"
value: <
bool_value: true
>
>
fields: <
key: "foo"
value: <
list_value: <
values: <
string_value: "abc"
>
values: <
string_value: "def"
>
values: <
string_value: "ghi"
>
>
>
>
fields: <
key: "null"
value: <
null_value: NULL_VALUE
>
>
>
`
)

View File

@ -1,12 +1,15 @@
// Command grpcurl makes GRPC requests (a la cURL, but HTTP/2). It can use a supplied descriptor file or // Command grpcurl makes GRPC requests (a la cURL, but HTTP/2). It can use a supplied descriptor
// service reflection to translate JSON request data into the appropriate protobuf request data and vice // file, protobuf sources, or service reflection to translate JSON or text request data into the
// versa for presenting the response contents. // appropriate protobuf messages and vice versa for presenting the response contents.
package main package main
import ( import (
"bufio"
"bytes"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -62,12 +65,22 @@ var (
rpcHeaders multiString rpcHeaders multiString
reflHeaders multiString reflHeaders multiString
authority = flag.String("authority", "", authority = flag.String("authority", "",
":authority pseudo header value to be passed along with underlying HTTP/2 requests. It defaults to `host [ \":\" port ]` part of the target url.") `:authority pseudo header value to be passed along with underlying HTTP/2
requests. It defaults to 'host [ ":" port ]' part of the target url.`)
data = flag.String("d", "", data = flag.String("d", "",
`JSON request contents. If the value is '@' then the request contents are `Data for request contents. If the value is '@' then the request contents
read from stdin. For calls that accept a stream of requests, the are read from stdin. For calls that accept a stream of requests, the
contents should include all such request messages concatenated together contents should include all such request messages concatenated together
(optionally separated by whitespace).`) (optionally separated by whitespace).`)
format = flag.String("format", "json",
`The format of request data. The allowed values are 'json' or 'text'. For
'json', the input data must be in JSON format. Multiple request values may
be concatenated (messages with a JSON representation other than object
must be separated by whitespace, such as a newline). For 'text', the input
data must be in the protobuf text format, in which case multiple request
values must be separated by the "record separate" ASCII character: 0x1E.
The stream should not end in a record separator. If it does, it will be
interpreted as a final, blank message after the separator.`)
connectTimeout = flag.String("connect-timeout", "", connectTimeout = flag.String("connect-timeout", "",
`The maximum time, in seconds, to wait for connection to be established. `The maximum time, in seconds, to wait for connection to be established.
Defaults to 10 seconds.`) Defaults to 10 seconds.`)
@ -81,9 +94,9 @@ var (
preventing batch jobs that use grpcurl from hanging due to slow or bad preventing batch jobs that use grpcurl from hanging due to slow or bad
network links or due to incorrect stream method usage.`) network links or due to incorrect stream method usage.`)
emitDefaults = flag.Bool("emit-defaults", false, emitDefaults = flag.Bool("emit-defaults", false,
`Emit default values from JSON-encoded responses.`) `Emit default values for JSON-encoded responses.`)
msgTemplate = flag.Bool("msg-template", false, msgTemplate = flag.Bool("msg-template", false,
`When describing messages, show a JSON template for the message type.`) `When describing messages, show a template of input data.`)
verbose = flag.Bool("v", false, verbose = flag.Bool("v", false,
`Enable verbose output.`) `Enable verbose output.`)
serverName = flag.String("servername", "", "Override servername when validating TLS certificate.") serverName = flag.String("servername", "", "Override servername when validating TLS certificate.")
@ -168,6 +181,9 @@ func main() {
if (*key == "") != (*cert == "") { if (*key == "") != (*cert == "") {
fail(nil, "The -cert and -key arguments must be used together and both be present.") fail(nil, "The -cert and -key arguments must be used together and both be present.")
} }
if *format != "json" && *format != "text" {
fail(nil, "The -format option must be 'json' or 'text.")
}
args := flag.Args() args := flag.Args()
@ -417,11 +433,18 @@ func main() {
// create a request to invoke an RPC // create a request to invoke an RPC
tmpl := makeTemplate(dynamic.NewMessage(dsc)) tmpl := makeTemplate(dynamic.NewMessage(dsc))
fmt.Println("\nMessage template:") fmt.Println("\nMessage template:")
if *format == "json" {
jsm := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} jsm := jsonpb.Marshaler{Indent: " ", EmitDefaults: true}
err := jsm.Marshal(os.Stdout, tmpl) err := jsm.Marshal(os.Stdout, tmpl)
if err != nil { if err != nil {
fail(err, "Failed to print template for message %s", s) fail(err, "Failed to print template for message %s", s)
} }
} else /* *format == "text" */ {
err := proto.MarshalText(os.Stdout, tmpl)
if err != nil {
fail(err, "Failed to print template for message %s", s)
}
}
fmt.Println() fmt.Println()
} }
} }
@ -431,28 +454,36 @@ func main() {
if cc == nil { if cc == nil {
cc = dial() cc = dial()
} }
var dec *json.Decoder var in io.Reader
if *data == "@" { if *data == "@" {
dec = json.NewDecoder(os.Stdin) in = os.Stdin
} else { } else {
dec = json.NewDecoder(strings.NewReader(*data)) in = strings.NewReader(*data)
} }
h := &handler{dec: dec, descSource: descSource} rf, formatter := formatDetails(*format, descSource, *verbose, in)
err := grpcurl.InvokeRpc(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), h, h.getRequestData) h := handler{
out: os.Stdout,
descSource: descSource,
formatter: formatter,
verbose: *verbose,
}
err := grpcurl.InvokeRPC(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), &h, rf.next)
if err != nil { if err != nil {
fail(err, "Error invoking method %q", symbol) fail(err, "Error invoking method %q", symbol)
} }
reqSuffix := "" reqSuffix := ""
respSuffix := "" respSuffix := ""
if h.reqCount != 1 { reqCount := rf.numRequests()
if reqCount != 1 {
reqSuffix = "s" reqSuffix = "s"
} }
if h.respCount != 1 { if h.respCount != 1 {
respSuffix = "s" respSuffix = "s"
} }
if *verbose { if *verbose {
fmt.Printf("Sent %d request%s and received %d response%s\n", h.reqCount, reqSuffix, h.respCount, respSuffix) fmt.Printf("Sent %d request%s and received %d response%s\n", reqCount, reqSuffix, h.respCount, respSuffix)
} }
if h.stat.Code() != codes.OK { if h.stat.Code() != codes.OK {
fmt.Fprintf(os.Stderr, "ERROR:\n Code: %s\n Message: %s\n", h.stat.Code().String(), h.stat.Message()) fmt.Fprintf(os.Stderr, "ERROR:\n Code: %s\n Message: %s\n", h.stat.Code().String(), h.stat.Message())
@ -512,64 +543,88 @@ func fail(err error, msg string, args ...interface{}) {
} }
} }
func anyResolver(source grpcurl.DescriptorSource) (jsonpb.AnyResolver, error) {
files, err := grpcurl.GetAllFiles(source)
if err != nil {
return nil, err
}
var er dynamic.ExtensionRegistry
for _, fd := range files {
er.AddExtensionsFromFile(fd)
}
mf := dynamic.NewMessageFactoryWithExtensionRegistry(&er)
return dynamic.AnyResolver(mf, files...), nil
}
func formatDetails(format string, descSource grpcurl.DescriptorSource, verbose bool, in io.Reader) (requestFactory, func(proto.Message) (string, error)) {
if format == "json" {
resolver, err := anyResolver(descSource)
if err != nil {
fail(err, "Error creating message resolver")
}
marshaler := jsonpb.Marshaler{
EmitDefaults: *emitDefaults,
Indent: " ",
AnyResolver: resolver,
}
return newJsonFactory(in, resolver), marshaler.MarshalToString
}
/* else *format == "text" */
// if not verbose output, then also include record delimiters
// before each message (other than the first) so output could
// potentially piped to another grpcurl process
tf := textFormatter{useSeparator: !verbose}
return newTextFactory(in), tf.format
}
type handler struct { type handler struct {
dec *json.Decoder out io.Writer
descSource grpcurl.DescriptorSource descSource grpcurl.DescriptorSource
reqCount int
respCount int respCount int
stat *status.Status stat *status.Status
formatter func(proto.Message) (string, error)
verbose bool
} }
func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) { func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) {
if *verbose { if h.verbose {
txt, err := grpcurl.GetDescriptorText(md, h.descSource) txt, err := grpcurl.GetDescriptorText(md, h.descSource)
if err == nil { if err == nil {
fmt.Printf("\nResolved method descriptor:\n%s\n", txt) fmt.Fprintf(h.out, "\nResolved method descriptor:\n%s\n", txt)
} }
} }
} }
func (*handler) OnSendHeaders(md metadata.MD) { func (h *handler) OnSendHeaders(md metadata.MD) {
if *verbose { if h.verbose {
fmt.Printf("\nRequest metadata to send:\n%s\n", grpcurl.MetadataToString(md)) fmt.Fprintf(h.out, "\nRequest metadata to send:\n%s\n", grpcurl.MetadataToString(md))
} }
} }
func (h *handler) getRequestData() ([]byte, error) { func (h *handler) OnReceiveHeaders(md metadata.MD) {
// we don't use a mutex, though this methods will be called from different goroutine if h.verbose {
// than other methods for bidi calls, because this method does not share any state fmt.Fprintf(h.out, "\nResponse headers received:\n%s\n", grpcurl.MetadataToString(md))
// with the other methods.
var msg json.RawMessage
if err := h.dec.Decode(&msg); err != nil {
return nil, err
}
h.reqCount++
return msg, nil
}
func (*handler) OnReceiveHeaders(md metadata.MD) {
if *verbose {
fmt.Printf("\nResponse headers received:\n%s\n", grpcurl.MetadataToString(md))
} }
} }
func (h *handler) OnReceiveResponse(resp proto.Message) { func (h *handler) OnReceiveResponse(resp proto.Message) {
h.respCount++ h.respCount++
if *verbose { if h.verbose {
fmt.Print("\nResponse contents:\n") fmt.Fprint(h.out, "\nResponse contents:\n")
} }
jsm := jsonpb.Marshaler{EmitDefaults: *emitDefaults, Indent: " "} respStr, err := h.formatter(resp)
respStr, err := jsm.MarshalToString(resp)
if err != nil { if err != nil {
fail(err, "failed to generate JSON form of response message") fail(err, "failed to generate %s form of response message", *format)
} }
fmt.Println(respStr) fmt.Fprintln(h.out, respStr)
} }
func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) { func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) {
h.stat = stat h.stat = stat
if *verbose { if h.verbose {
fmt.Printf("\nResponse trailers received:\n%s\n", grpcurl.MetadataToString(md)) fmt.Fprintf(h.out, "\nResponse trailers received:\n%s\n", grpcurl.MetadataToString(md))
} }
} }
@ -633,3 +688,116 @@ func makeTemplate(msg proto.Message) proto.Message {
} }
return dm return dm
} }
type requestFactory interface {
next(proto.Message) error
numRequests() int
}
type jsonFactory struct {
dec *json.Decoder
unmarshaler jsonpb.Unmarshaler
requestCount int
}
func newJsonFactory(in io.Reader, resolver jsonpb.AnyResolver) *jsonFactory {
return &jsonFactory{
dec: json.NewDecoder(in),
unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver},
}
}
func (f *jsonFactory) next(m proto.Message) error {
var msg json.RawMessage
if err := f.dec.Decode(&msg); err != nil {
return err
}
f.requestCount++
return f.unmarshaler.Unmarshal(bytes.NewReader(msg), m)
}
func (f *jsonFactory) numRequests() int {
return f.requestCount
}
const (
textSeparatorChar = 0x1e
)
type textFactory struct {
r *bufio.Reader
err error
requestCount int
}
func newTextFactory(in io.Reader) *textFactory {
return &textFactory{r: bufio.NewReader(in)}
}
func (f *textFactory) next(m proto.Message) error {
if f.err != nil {
return f.err
}
var b []byte
b, f.err = f.r.ReadBytes(textSeparatorChar)
if f.err != nil && f.err != io.EOF {
return f.err
}
// remove delimiter
if len(b) > 0 && b[len(b)-1] == textSeparatorChar {
b = b[:len(b)-1]
}
f.requestCount++
return proto.UnmarshalText(string(b), m)
}
func (f *textFactory) numRequests() int {
return f.requestCount
}
type textFormatter struct {
useSeparator bool
numFormatted int
}
func (tf *textFormatter) format(m proto.Message) (string, error) {
var buf bytes.Buffer
if tf.useSeparator && tf.numFormatted > 0 {
if err := buf.WriteByte(textSeparatorChar); err != nil {
return "", err
}
}
// If message implements MarshalText method (such as a *dynamic.Message),
// it won't get details about whether or not to format to text compactly
// or with indentation. So first see if the message also implements a
// MarshalTextIndent method and use that instead if available.
type indentMarshaler interface {
MarshalTextIndent() ([]byte, error)
}
if indenter, ok := m.(indentMarshaler); ok {
b, err := indenter.MarshalTextIndent()
if err != nil {
return "", err
}
if _, err := buf.Write(b); err != nil {
return "", err
}
} else if err := proto.MarshalText(&buf, m); err != nil {
return "", err
}
// no trailing newline needed
str := buf.String()
if str[len(str)-1] == '\n' {
str = str[:len(str)-1]
}
tf.numFormatted++
return str, nil
}

View File

@ -48,7 +48,7 @@ var ErrReflectionNotSupported = errors.New("server does not support the reflecti
// proto (like a file generated by protoc) or a remote server that supports the reflection API. // proto (like a file generated by protoc) or a remote server that supports the reflection API.
type DescriptorSource interface { type DescriptorSource interface {
// ListServices returns a list of fully-qualified service names. It will be all services in a set of // ListServices returns a list of fully-qualified service names. It will be all services in a set of
// descriptor files or the set of all services exposed by a GRPC server. // descriptor files or the set of all services exposed by a gRPC server.
ListServices() ([]string, error) ListServices() ([]string, error)
// FindSymbol returns a descriptor for the given fully-qualified symbol name. // FindSymbol returns a descriptor for the given fully-qualified symbol name.
FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error)
@ -181,6 +181,20 @@ func (fs *fileSource) ListServices() ([]string, error) {
return sl, nil return sl, nil
} }
// GetAllFiles returns all of the underlying file descriptors. This is
// more thorough and more efficient than the fallback strategy used by
// the GetAllFiles package method, for enumerating all files from a
// descriptor source.
func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) {
files := make([]*desc.FileDescriptor, len(fs.files))
i := 0
for _, fd := range fs.files {
files[i] = fd
i++
}
return files, nil
}
func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
for _, fd := range fs.files { for _, fd := range fs.files {
if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil { if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil {
@ -200,7 +214,7 @@ func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescri
return fs.er.AllExtensionsForType(typeName), nil return fs.er.AllExtensionsForType(typeName), nil
} }
// DescriptorSourceFromServer creates a DescriptorSource that uses the given GRPC reflection client // DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client
// to interrogate a server for descriptor information. If the server does not support the reflection // to interrogate a server for descriptor information. If the server does not support the reflection
// API then the various DescriptorSource methods will return ErrReflectionNotSupported // API then the various DescriptorSource methods will return ErrReflectionNotSupported
func DescriptorSourceFromServer(ctx context.Context, refClient *grpcreflect.Client) DescriptorSource { func DescriptorSourceFromServer(ctx context.Context, refClient *grpcreflect.Client) DescriptorSource {
@ -265,6 +279,75 @@ func ListServices(source DescriptorSource) ([]string, error) {
return svcs, nil return svcs, nil
} }
type sourceWithFiles interface {
GetAllFiles() ([]*desc.FileDescriptor, error)
}
var _ sourceWithFiles = (*fileSource)(nil)
// GetAllFiles uses the given descriptor source to return a list of file descriptors.
func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
var files []*desc.FileDescriptor
srcFiles, ok := source.(sourceWithFiles)
if ok {
var err error
files, err = srcFiles.GetAllFiles()
if err != nil {
return nil, err
}
} else {
// Source does not implement GetAllFiles method, so use ListServices
// and grab files from there.
allFiles := map[string]*desc.FileDescriptor{}
svcNames, err := source.ListServices()
if err != nil {
return nil, err
}
for _, name := range svcNames {
d, err := source.FindSymbol(name)
if err != nil {
return nil, err
}
addAllFilesToSet(d.GetFile(), allFiles)
}
files = make([]*desc.FileDescriptor, len(allFiles))
i := 0
for _, fd := range allFiles {
files[i] = fd
i++
}
}
sort.Sort(filesByName(files))
return files, nil
}
type filesByName []*desc.FileDescriptor
func (f filesByName) Len() int {
return len(f)
}
func (f filesByName) Less(i, j int) bool {
return f[i].GetName() < f[j].GetName()
}
func (f filesByName) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
if _, ok := all[fd.GetName()]; ok {
// already added
return
}
all[fd.GetName()] = fd
for _, dep := range fd.GetDependencies() {
addAllFilesToSet(dep, all)
}
}
// ListMethods uses the given descriptor source to return a sorted list of method names // ListMethods uses the given descriptor source to return a sorted list of method names
// for the specified fully-qualified service name. // for the specified fully-qualified service name.
func ListMethods(source DescriptorSource, serviceName string) ([]string, error) { func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
@ -319,27 +402,54 @@ type InvocationEventHandler interface {
} }
// RequestMessageSupplier is a function that is called to retrieve request // RequestMessageSupplier is a function that is called to retrieve request
// messages for a GRPC operation. The message contents must be valid JSON. If // messages for a GRPC operation. This type is deprecated and will be removed in
// the supplier has no more messages, it should return nil, io.EOF. // a future release.
//
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
// RequestSupplier with InvokeRPC.
type RequestMessageSupplier func() ([]byte, error) type RequestMessageSupplier func() ([]byte, error)
// InvokeRpc uses the given GRPC connection to invoke the given method. The given descriptor source // InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated
// and will be removed in a future release. It just delegates to the similarly named InvokeRPC
// method, whose signature is only slightly different.
//
// Deprecated: use InvokeRPC instead.
func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string,
headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error {
return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error {
// New function is almost identical, but the request supplier function works differently.
// So we adapt the logic here to maintain compatibility.
data, err := requestData()
if err != nil {
return err
}
return jsonpb.Unmarshal(bytes.NewReader(data), m)
})
}
// RequestSupplier is a function that is called to populate messages for a gRPC operation. The
// function should populate the given message or return a non-nil error. If the supplier has no
// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way
// modify the given message argument.
type RequestSupplier func(proto.Message) error
// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source
// is used to determine the type of method and the type of request and response message. The given // is used to determine the type of method and the type of request and response message. The given
// headers are sent as request metadata. Methods on the given event handler are called as the // headers are sent as request metadata. Methods on the given event handler are called as the
// invocation proceeds. // invocation proceeds.
// //
// The given requestData function supplies the actual data to send. It should return io.EOF when // The given requestData function supplies the actual data to send. It should return io.EOF when
// there is no more request data. If it returns a nil error then the returned JSON message should // there is no more request data. If the method being invoked is a unary or server-streaming RPC
// not be blank. If the method being invoked is a unary or server-streaming RPC (e.g. exactly one // (e.g. exactly one request message) and there is no request data (e.g. the first invocation of
// request message) and there is no request data (e.g. the first invocation of the function returns // the function returns io.EOF), then an empty request message is sent.
// io.EOF), then a blank request message is sent, as if the request data were an empty object: "{}".
// //
// If the requestData function and the given event handler coordinate or share any state, they should // If the requestData function and the given event handler coordinate or share any state, they should
// be thread-safe. This is because the requestData function may be called from a different goroutine // be thread-safe. This is because the requestData function may be called from a different goroutine
// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where // than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where
// one goroutine sends request messages and another consumes the response messages). // one goroutine sends request messages and another consumes the response messages).
func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string,
headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { headers []string, handler InvocationEventHandler, requestData RequestSupplier) error {
md := MetadataFromHeaders(headers) md := MetadataFromHeaders(headers)
@ -381,7 +491,7 @@ func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn
handler.OnSendHeaders(md) handler.OnSendHeaders(md)
ctx = metadata.NewOutgoingContext(ctx, md) ctx = metadata.NewOutgoingContext(ctx, md)
stub := grpcdynamic.NewStubWithMessageFactory(cc, msgFactory) stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -397,21 +507,15 @@ func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn
} }
func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
requestData RequestMessageSupplier, req proto.Message) error { requestData RequestSupplier, req proto.Message) error {
data, err := requestData() err := requestData(req)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return fmt.Errorf("error getting request data: %v", err) return fmt.Errorf("error getting request data: %v", err)
} }
if len(data) != 0 {
err = jsonpb.UnmarshalString(string(data), req)
if err != nil {
return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err)
}
}
if err != io.EOF { if err != io.EOF {
// verify there is no second message, which is a usage error // verify there is no second message, which is a usage error
_, err := requestData() err := requestData(req)
if err == nil { if err == nil {
return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
} else if err != io.EOF { } else if err != io.EOF {
@ -443,7 +547,7 @@ func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDesc
} }
func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
requestData RequestMessageSupplier, req proto.Message) error { requestData RequestSupplier, req proto.Message) error {
// invoke the RPC! // invoke the RPC!
str, err := stub.InvokeRpcClientStream(ctx, md) str, err := stub.InvokeRpcClientStream(ctx, md)
@ -451,8 +555,7 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met
// Upload each request message in the stream // Upload each request message in the stream
var resp proto.Message var resp proto.Message
for err == nil { for err == nil {
var data []byte err = requestData(req)
data, err = requestData()
if err == io.EOF { if err == io.EOF {
resp, err = str.CloseAndReceive() resp, err = str.CloseAndReceive()
break break
@ -460,12 +563,6 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met
if err != nil { if err != nil {
return fmt.Errorf("error getting request data: %v", err) return fmt.Errorf("error getting request data: %v", err)
} }
if len(data) != 0 {
err = jsonpb.UnmarshalString(string(data), req)
if err != nil {
return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err)
}
}
err = str.SendMsg(req) err = str.SendMsg(req)
if err == io.EOF { if err == io.EOF {
@ -500,21 +597,15 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met
} }
func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
requestData RequestMessageSupplier, req proto.Message) error { requestData RequestSupplier, req proto.Message) error {
data, err := requestData() err := requestData(req)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return fmt.Errorf("error getting request data: %v", err) return fmt.Errorf("error getting request data: %v", err)
} }
if len(data) != 0 {
err = jsonpb.UnmarshalString(string(data), req)
if err != nil {
return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err)
}
}
if err != io.EOF { if err != io.EOF {
// verify there is no second message, which is a usage error // verify there is no second message, which is a usage error
_, err := requestData() err := requestData(req)
if err == nil { if err == nil {
return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
} else if err != io.EOF { } else if err != io.EOF {
@ -555,7 +646,7 @@ func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met
} }
func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
requestData RequestMessageSupplier, req proto.Message) error { requestData RequestSupplier, req proto.Message) error {
// invoke the RPC! // invoke the RPC!
str, err := stub.InvokeRpcBidiStream(ctx, md) str, err := stub.InvokeRpcBidiStream(ctx, md)
@ -572,9 +663,8 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr
// Concurrently upload each request message in the stream // Concurrently upload each request message in the stream
var err error var err error
var data []byte
for err == nil { for err == nil {
data, err = requestData() err = requestData(req)
if err == io.EOF { if err == io.EOF {
err = str.CloseSend() err = str.CloseSend()
@ -584,13 +674,6 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr
err = fmt.Errorf("error getting request data: %v", err) err = fmt.Errorf("error getting request data: %v", err)
break break
} }
if len(data) != 0 {
err = jsonpb.UnmarshalString(string(data), req)
if err != nil {
err = fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err)
break
}
}
err = str.SendMsg(req) err = str.SendMsg(req)
@ -700,16 +783,29 @@ func MetadataToString(md metadata.MD) string {
if len(md) == 0 { if len(md) == 0 {
return "(empty)" return "(empty)"
} }
keys := make([]string, 0, len(md))
for k := range md {
keys = append(keys, k)
}
sort.Strings(keys)
var b bytes.Buffer var b bytes.Buffer
for k, vs := range md { first := true
for _, k := range keys {
vs := md[k]
for _, v := range vs { for _, v := range vs {
if first {
first = false
} else {
b.WriteString("\n")
}
b.WriteString(k) b.WriteString(k)
b.WriteString(": ") b.WriteString(": ")
if strings.HasSuffix(k, "-bin") { if strings.HasSuffix(k, "-bin") {
v = base64.StdEncoding.EncodeToString([]byte(v)) v = base64.StdEncoding.EncodeToString([]byte(v))
} }
b.WriteString(v) b.WriteString(v)
b.WriteString("\n")
} }
} }
return b.String() return b.String()
@ -840,7 +936,7 @@ func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (
return dm, nil return dm, nil
} }
// ClientTransportCredentials builds transport credentials for a GRPC client using the // ClientTransportCredentials builds transport credentials for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to // given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client // verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank. // certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
@ -877,7 +973,7 @@ func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertF
return credentials.NewTLS(&tlsConf), nil return credentials.NewTLS(&tlsConf), nil
} }
// ServerTransportCredentials builds transport credentials for a GRPC server using the // ServerTransportCredentials builds transport credentials for a gRPC server using the
// given properties. If cacertFile is blank, the server will not request client certs // given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is // unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require // not blank, the server will verify client certs when presented, but will not require

View File

@ -44,6 +44,10 @@ type descSourceCase struct {
includeRefl bool includeRefl bool
} }
// NB: These tests intentionally use the deprecated InvokeRpc since that
// calls the other (non-deprecated InvokeRPC). That allows the tests to
// easily exercise both functions.
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
var err error var err error
sourceProtoset, err = DescriptorSourceFromProtoSets("testing/test.protoset") sourceProtoset, err = DescriptorSourceFromProtoSets("testing/test.protoset")
@ -235,6 +239,73 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection
} }
} }
func TestGetAllFiles(t *testing.T) {
expectedFiles := []string{"testing/test.proto"}
// server reflection picks up filename from linked in Go package,
// which indicates "grpc_testing/test.proto", not our local copy.
expectedFilesWithReflection := []string{"grpc_reflection_v1alpha/reflection.proto", "grpc_testing/test.proto"}
for _, ds := range descSources {
t.Run(ds.name, func(t *testing.T) {
files, err := GetAllFiles(ds.source)
if err != nil {
t.Fatalf("failed to get all files: %v", err)
}
names := fileNames(files)
expected := expectedFiles
if ds.includeRefl {
expected = expectedFilesWithReflection
}
if !reflect.DeepEqual(expected, names) {
t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names)
}
})
}
// try cases with more complicated set of files
otherSourceProtoset, err := DescriptorSourceFromProtoSets("testing/test.protoset", "testing/example.protoset")
if err != nil {
t.Fatal(err.Error())
}
otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles(nil, "testing/test.proto", "testing/example.proto")
if err != nil {
t.Fatal(err.Error())
}
otherDescSources := []descSourceCase{
{"protoset[b]", otherSourceProtoset, false},
{"proto[b]", otherSourceProtoFiles, false},
}
expectedFiles = []string{
"google/protobuf/any.proto",
"google/protobuf/descriptor.proto",
"google/protobuf/empty.proto",
"google/protobuf/timestamp.proto",
"testing/example.proto",
"testing/example2.proto",
"testing/test.proto",
}
for _, ds := range otherDescSources {
t.Run(ds.name, func(t *testing.T) {
files, err := GetAllFiles(ds.source)
if err != nil {
t.Fatalf("failed to get all files: %v", err)
}
names := fileNames(files)
if !reflect.DeepEqual(expectedFiles, names) {
t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names)
}
})
}
}
func fileNames(files []*desc.FileDescriptor) []string {
names := make([]string, len(files))
for i, f := range files {
names[i] = f.GetName()
}
return names
}
func TestDescribe(t *testing.T) { func TestDescribe(t *testing.T) {
for _, ds := range descSources { for _, ds := range descSources {
t.Run(ds.name, func(t *testing.T) { t.Run(ds.name, func(t *testing.T) {

View File

@ -7,8 +7,8 @@ cd "$(dirname $0)"
# Run this script to generate files used by tests. # Run this script to generate files used by tests.
echo "Creating protosets..." echo "Creating protosets..."
protoc ../../../google.golang.org/grpc/interop/grpc_testing/test.proto \ protoc testing/test.proto \
-I../../../ --include_imports \ --include_imports \
--descriptor_set_out=testing/test.protoset --descriptor_set_out=testing/test.protoset
protoc testing/example.proto \ protoc testing/example.proto \

View File

@ -3,9 +3,11 @@ syntax = "proto3";
import "google/protobuf/descriptor.proto"; import "google/protobuf/descriptor.proto";
import "google/protobuf/empty.proto"; import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
import "testing/example2.proto";
message TestRequest { message TestRequest {
repeated string file_names = 1; repeated string file_names = 1;
repeated Extension extensions = 2;
} }
message TestResponse { message TestResponse {

Binary file not shown.

8
testing/example2.proto Normal file
View File

@ -0,0 +1,8 @@
syntax = "proto3";
import "google/protobuf/any.proto";
message Extension {
uint64 id = 1;
google.protobuf.Any data = 2;
}

Binary file not shown.