diff --git a/cmd/grpcurl/formatting_test.go b/cmd/grpcurl/formatting_test.go new file mode 100644 index 0000000..bc69079 --- /dev/null +++ b/cmd/grpcurl/formatting_test.go @@ -0,0 +1,303 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "strings" + "testing" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/struct" + "github.com/jhump/protoreflect/desc" + "google.golang.org/grpc/metadata" + + "github.com/fullstorydev/grpcurl" +) + +func TestRequestFactory(t *testing.T) { + source, err := grpcurl.DescriptorSourceFromProtoSets("../../testing/example.protoset") + if err != nil { + t.Fatalf("failed to create descriptor source: %v", err) + } + + msg, err := makeProto() + if err != nil { + t.Fatalf("failed to create message: %v", err) + } + + testCases := []struct { + format string + input string + expectedOutput []proto.Message + }{ + { + format: "json", + input: "", + }, + { + format: "json", + input: messageAsJSON, + expectedOutput: []proto.Message{msg}, + }, + { + format: "json", + input: messageAsJSON + messageAsJSON + messageAsJSON, + expectedOutput: []proto.Message{msg, msg, msg}, + }, + { + // unlike JSON, empty input yields one empty message (vs. zero messages) + format: "text", + input: "", + expectedOutput: []proto.Message{&structpb.Value{}}, + }, + { + format: "text", + input: messageAsText, + expectedOutput: []proto.Message{msg}, + }, + { + format: "text", + input: messageAsText + string(textSeparatorChar), + expectedOutput: []proto.Message{msg, &structpb.Value{}}, + }, + { + format: "text", + input: messageAsText + string(textSeparatorChar) + messageAsText + string(textSeparatorChar) + messageAsText, + expectedOutput: []proto.Message{msg, msg, msg}, + }, + } + + for i, tc := range testCases { + name := fmt.Sprintf("#%d, %s, %d message(s)", i+1, tc.format, len(tc.expectedOutput)) + rf, _ := formatDetails(tc.format, source, false, strings.NewReader(tc.input)) + numReqs := 0 + for { + var req structpb.Value + err := rf.next(&req) + if err == io.EOF { + break + } else if err != nil { + t.Errorf("%s, msg %d: unexpected error: %v", name, numReqs, err) + } + if !proto.Equal(&req, tc.expectedOutput[numReqs]) { + t.Errorf("%s, msg %d: incorrect message;\nexpecting:\n%v\ngot:\n%v", name, numReqs, tc.expectedOutput[numReqs], &req) + } + numReqs++ + } + if rf.numRequests() != numReqs { + t.Errorf("%s: factory reported wrong number of requests: expecting %d, got %d", name, numReqs, rf.numRequests()) + } + } +} + +// Handler prints response data (and headers/trailers in verbose mode). +// This verifies that we get the right output in both JSON and proto text modes. +func TestHandler(t *testing.T) { + source, err := grpcurl.DescriptorSourceFromProtoSets("../../testing/example.protoset") + if err != nil { + t.Fatalf("failed to create descriptor source: %v", err) + } + d, err := source.FindSymbol("TestService.GetFiles") + if err != nil { + t.Fatalf("failed to find method 'TestService.GetFiles': %v", err) + } + md, ok := d.(*desc.MethodDescriptor) + if !ok { + t.Fatalf("wrong kind of descriptor found: %T", d) + } + + reqHeaders := metadata.Pairs("foo", "123", "bar", "456") + respHeaders := metadata.Pairs("foo", "abc", "bar", "def", "baz", "xyz") + respTrailers := metadata.Pairs("a", "1", "b", "2", "c", "3") + rsp, err := makeProto() + if err != nil { + t.Fatalf("failed to create response message: %v", err) + } + + for _, format := range []string{"json", "text"} { + for _, numMessages := range []int{1, 3} { + for _, verbose := range []bool{true, false} { + name := fmt.Sprintf("%s, %d message(s)", format, numMessages) + if verbose { + name += ", verbose" + } + + _, formatter := formatDetails(format, source, verbose, nil) + + var buf bytes.Buffer + h := handler{ + out: &buf, + descSource: source, + verbose: verbose, + formatter: formatter, + } + + h.OnResolveMethod(md) + h.OnSendHeaders(reqHeaders) + h.OnReceiveHeaders(respHeaders) + for i := 0; i < numMessages; i++ { + h.OnReceiveResponse(rsp) + } + h.OnReceiveTrailers(nil, respTrailers) + + expectedOutput := "" + if verbose { + expectedOutput += verbosePrefix + } + for i := 0; i < numMessages; i++ { + if verbose { + expectedOutput += verboseResponseHeader + } + if format == "json" { + expectedOutput += messageAsJSON + } else { + if i > 0 && !verbose { + expectedOutput += string(textSeparatorChar) + } + expectedOutput += messageAsText + } + } + if verbose { + expectedOutput += verboseSuffix + } + + out := buf.String() + if !compare(out, expectedOutput) { + t.Errorf("%s: Incorrect output.", name) // Expected:\n%s\nGot:\n%s", name, expectedOutput, out) + } + } + } + } +} + +// compare checks that actual and expected are equal, returning true if so. +// A simple equality check (==) does not suffice because jsonpb formats +// structpb.Value strangely. So if that formatting gets fixed, we don't +// want this test in grpcurl to suddenly start failing. So we check each +// line and compare the lines after stripping whitespace (which removes +// the jsonpb format anomalies). +func compare(actual, expected string) bool { + actualLines := strings.Split(actual, "\n") + expectedLines := strings.Split(expected, "\n") + if len(actualLines) != len(expectedLines) { + return false + } + for i := 0; i < len(actualLines); i++ { + if strings.TrimSpace(actualLines[i]) != strings.TrimSpace(expectedLines[i]) { + return false + } + } + return true +} + +func makeProto() (proto.Message, error) { + var rsp structpb.Value + err := jsonpb.UnmarshalString(`{ + "foo": ["abc", "def", "ghi"], + "bar": { "a": 1, "b": 2 }, + "baz": true, + "null": null + }`, &rsp) + if err != nil { + return nil, err + } + return &rsp, nil +} + +var ( + verbosePrefix = ` +Resolved method descriptor: +{ + "name": "GetFiles", + "inputType": ".TestRequest", + "outputType": ".TestResponse", + "options": { + + } +} + +Request metadata to send: +bar: 456 +foo: 123 + +Response headers received: +bar: def +baz: xyz +foo: abc +` + verboseSuffix = ` +Response trailers received: +a: 1 +b: 2 +c: 3 +` + verboseResponseHeader = ` +Response contents: +` + messageAsJSON = `{ + "bar": { + "a": 1, + "b": 2 + }, + "baz": true, + "foo": [ + "abc", + "def", + "ghi" + ], + "null": null +} +` + messageAsText = `struct_value: < + fields: < + key: "bar" + value: < + struct_value: < + fields: < + key: "a" + value: < + number_value: 1 + > + > + fields: < + key: "b" + value: < + number_value: 2 + > + > + > + > + > + fields: < + key: "baz" + value: < + bool_value: true + > + > + fields: < + key: "foo" + value: < + list_value: < + values: < + string_value: "abc" + > + values: < + string_value: "def" + > + values: < + string_value: "ghi" + > + > + > + > + fields: < + key: "null" + value: < + null_value: NULL_VALUE + > + > +> +` +) diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index 6d4d053..a1347b4 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -1,12 +1,15 @@ -// Command grpcurl makes GRPC requests (a la cURL, but HTTP/2). It can use a supplied descriptor file or -// service reflection to translate JSON request data into the appropriate protobuf request data and vice -// versa for presenting the response contents. +// Command grpcurl makes GRPC requests (a la cURL, but HTTP/2). It can use a supplied descriptor +// file, protobuf sources, or service reflection to translate JSON or text request data into the +// appropriate protobuf messages and vice versa for presenting the response contents. package main import ( + "bufio" + "bytes" "encoding/json" "flag" "fmt" + "io" "os" "strconv" "strings" @@ -62,12 +65,22 @@ var ( rpcHeaders multiString reflHeaders multiString authority = flag.String("authority", "", - ":authority pseudo header value to be passed along with underlying HTTP/2 requests. It defaults to `host [ \":\" port ]` part of the target url.") + `:authority pseudo header value to be passed along with underlying HTTP/2 + requests. It defaults to 'host [ ":" port ]' part of the target url.`) data = flag.String("d", "", - `JSON request contents. If the value is '@' then the request contents are - read from stdin. For calls that accept a stream of requests, the + `Data for request contents. If the value is '@' then the request contents + are read from stdin. For calls that accept a stream of requests, the contents should include all such request messages concatenated together (optionally separated by whitespace).`) + format = flag.String("format", "json", + `The format of request data. The allowed values are 'json' or 'text'. For + 'json', the input data must be in JSON format. Multiple request values may + be concatenated (messages with a JSON representation other than object + must be separated by whitespace, such as a newline). For 'text', the input + data must be in the protobuf text format, in which case multiple request + values must be separated by the "record separate" ASCII character: 0x1E. + The stream should not end in a record separator. If it does, it will be + interpreted as a final, blank message after the separator.`) connectTimeout = flag.String("connect-timeout", "", `The maximum time, in seconds, to wait for connection to be established. Defaults to 10 seconds.`) @@ -81,9 +94,9 @@ var ( preventing batch jobs that use grpcurl from hanging due to slow or bad network links or due to incorrect stream method usage.`) emitDefaults = flag.Bool("emit-defaults", false, - `Emit default values from JSON-encoded responses.`) + `Emit default values for JSON-encoded responses.`) msgTemplate = flag.Bool("msg-template", false, - `When describing messages, show a JSON template for the message type.`) + `When describing messages, show a template of input data.`) verbose = flag.Bool("v", false, `Enable verbose output.`) serverName = flag.String("servername", "", "Override servername when validating TLS certificate.") @@ -168,6 +181,9 @@ func main() { if (*key == "") != (*cert == "") { fail(nil, "The -cert and -key arguments must be used together and both be present.") } + if *format != "json" && *format != "text" { + fail(nil, "The -format option must be 'json' or 'text.") + } args := flag.Args() @@ -417,10 +433,17 @@ func main() { // create a request to invoke an RPC tmpl := makeTemplate(dynamic.NewMessage(dsc)) fmt.Println("\nMessage template:") - jsm := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} - err := jsm.Marshal(os.Stdout, tmpl) - if err != nil { - fail(err, "Failed to print template for message %s", s) + if *format == "json" { + jsm := jsonpb.Marshaler{Indent: " ", EmitDefaults: true} + err := jsm.Marshal(os.Stdout, tmpl) + if err != nil { + fail(err, "Failed to print template for message %s", s) + } + } else /* *format == "text" */ { + err := proto.MarshalText(os.Stdout, tmpl) + if err != nil { + fail(err, "Failed to print template for message %s", s) + } } fmt.Println() } @@ -431,28 +454,36 @@ func main() { if cc == nil { cc = dial() } - var dec *json.Decoder + var in io.Reader if *data == "@" { - dec = json.NewDecoder(os.Stdin) + in = os.Stdin } else { - dec = json.NewDecoder(strings.NewReader(*data)) + in = strings.NewReader(*data) } - h := &handler{dec: dec, descSource: descSource} - err := grpcurl.InvokeRpc(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), h, h.getRequestData) + rf, formatter := formatDetails(*format, descSource, *verbose, in) + h := handler{ + out: os.Stdout, + descSource: descSource, + formatter: formatter, + verbose: *verbose, + } + + err := grpcurl.InvokeRPC(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), &h, rf.next) if err != nil { fail(err, "Error invoking method %q", symbol) } reqSuffix := "" respSuffix := "" - if h.reqCount != 1 { + reqCount := rf.numRequests() + if reqCount != 1 { reqSuffix = "s" } if h.respCount != 1 { respSuffix = "s" } if *verbose { - fmt.Printf("Sent %d request%s and received %d response%s\n", h.reqCount, reqSuffix, h.respCount, respSuffix) + fmt.Printf("Sent %d request%s and received %d response%s\n", reqCount, reqSuffix, h.respCount, respSuffix) } if h.stat.Code() != codes.OK { fmt.Fprintf(os.Stderr, "ERROR:\n Code: %s\n Message: %s\n", h.stat.Code().String(), h.stat.Message()) @@ -512,64 +543,88 @@ func fail(err error, msg string, args ...interface{}) { } } +func anyResolver(source grpcurl.DescriptorSource) (jsonpb.AnyResolver, error) { + files, err := grpcurl.GetAllFiles(source) + if err != nil { + return nil, err + } + + var er dynamic.ExtensionRegistry + for _, fd := range files { + er.AddExtensionsFromFile(fd) + } + mf := dynamic.NewMessageFactoryWithExtensionRegistry(&er) + return dynamic.AnyResolver(mf, files...), nil +} + +func formatDetails(format string, descSource grpcurl.DescriptorSource, verbose bool, in io.Reader) (requestFactory, func(proto.Message) (string, error)) { + if format == "json" { + resolver, err := anyResolver(descSource) + if err != nil { + fail(err, "Error creating message resolver") + } + marshaler := jsonpb.Marshaler{ + EmitDefaults: *emitDefaults, + Indent: " ", + AnyResolver: resolver, + } + return newJsonFactory(in, resolver), marshaler.MarshalToString + } + /* else *format == "text" */ + + // if not verbose output, then also include record delimiters + // before each message (other than the first) so output could + // potentially piped to another grpcurl process + tf := textFormatter{useSeparator: !verbose} + return newTextFactory(in), tf.format +} + type handler struct { - dec *json.Decoder + out io.Writer descSource grpcurl.DescriptorSource - reqCount int respCount int stat *status.Status + formatter func(proto.Message) (string, error) + verbose bool } func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) { - if *verbose { + if h.verbose { txt, err := grpcurl.GetDescriptorText(md, h.descSource) if err == nil { - fmt.Printf("\nResolved method descriptor:\n%s\n", txt) + fmt.Fprintf(h.out, "\nResolved method descriptor:\n%s\n", txt) } } } -func (*handler) OnSendHeaders(md metadata.MD) { - if *verbose { - fmt.Printf("\nRequest metadata to send:\n%s\n", grpcurl.MetadataToString(md)) +func (h *handler) OnSendHeaders(md metadata.MD) { + if h.verbose { + fmt.Fprintf(h.out, "\nRequest metadata to send:\n%s\n", grpcurl.MetadataToString(md)) } } -func (h *handler) getRequestData() ([]byte, error) { - // we don't use a mutex, though this methods will be called from different goroutine - // than other methods for bidi calls, because this method does not share any state - // with the other methods. - var msg json.RawMessage - if err := h.dec.Decode(&msg); err != nil { - return nil, err - } - h.reqCount++ - return msg, nil -} - -func (*handler) OnReceiveHeaders(md metadata.MD) { - if *verbose { - fmt.Printf("\nResponse headers received:\n%s\n", grpcurl.MetadataToString(md)) +func (h *handler) OnReceiveHeaders(md metadata.MD) { + if h.verbose { + fmt.Fprintf(h.out, "\nResponse headers received:\n%s\n", grpcurl.MetadataToString(md)) } } func (h *handler) OnReceiveResponse(resp proto.Message) { h.respCount++ - if *verbose { - fmt.Print("\nResponse contents:\n") + if h.verbose { + fmt.Fprint(h.out, "\nResponse contents:\n") } - jsm := jsonpb.Marshaler{EmitDefaults: *emitDefaults, Indent: " "} - respStr, err := jsm.MarshalToString(resp) + respStr, err := h.formatter(resp) if err != nil { - fail(err, "failed to generate JSON form of response message") + fail(err, "failed to generate %s form of response message", *format) } - fmt.Println(respStr) + fmt.Fprintln(h.out, respStr) } func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) { h.stat = stat - if *verbose { - fmt.Printf("\nResponse trailers received:\n%s\n", grpcurl.MetadataToString(md)) + if h.verbose { + fmt.Fprintf(h.out, "\nResponse trailers received:\n%s\n", grpcurl.MetadataToString(md)) } } @@ -633,3 +688,116 @@ func makeTemplate(msg proto.Message) proto.Message { } return dm } + +type requestFactory interface { + next(proto.Message) error + numRequests() int +} + +type jsonFactory struct { + dec *json.Decoder + unmarshaler jsonpb.Unmarshaler + requestCount int +} + +func newJsonFactory(in io.Reader, resolver jsonpb.AnyResolver) *jsonFactory { + return &jsonFactory{ + dec: json.NewDecoder(in), + unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver}, + } +} + +func (f *jsonFactory) next(m proto.Message) error { + var msg json.RawMessage + if err := f.dec.Decode(&msg); err != nil { + return err + } + f.requestCount++ + return f.unmarshaler.Unmarshal(bytes.NewReader(msg), m) +} + +func (f *jsonFactory) numRequests() int { + return f.requestCount +} + +const ( + textSeparatorChar = 0x1e +) + +type textFactory struct { + r *bufio.Reader + err error + requestCount int +} + +func newTextFactory(in io.Reader) *textFactory { + return &textFactory{r: bufio.NewReader(in)} +} + +func (f *textFactory) next(m proto.Message) error { + if f.err != nil { + return f.err + } + + var b []byte + b, f.err = f.r.ReadBytes(textSeparatorChar) + if f.err != nil && f.err != io.EOF { + return f.err + } + // remove delimiter + if len(b) > 0 && b[len(b)-1] == textSeparatorChar { + b = b[:len(b)-1] + } + + f.requestCount++ + + return proto.UnmarshalText(string(b), m) +} + +func (f *textFactory) numRequests() int { + return f.requestCount +} + +type textFormatter struct { + useSeparator bool + numFormatted int +} + +func (tf *textFormatter) format(m proto.Message) (string, error) { + var buf bytes.Buffer + if tf.useSeparator && tf.numFormatted > 0 { + if err := buf.WriteByte(textSeparatorChar); err != nil { + return "", err + } + } + + // If message implements MarshalText method (such as a *dynamic.Message), + // it won't get details about whether or not to format to text compactly + // or with indentation. So first see if the message also implements a + // MarshalTextIndent method and use that instead if available. + type indentMarshaler interface { + MarshalTextIndent() ([]byte, error) + } + + if indenter, ok := m.(indentMarshaler); ok { + b, err := indenter.MarshalTextIndent() + if err != nil { + return "", err + } + if _, err := buf.Write(b); err != nil { + return "", err + } + } else if err := proto.MarshalText(&buf, m); err != nil { + return "", err + } + + // no trailing newline needed + str := buf.String() + if str[len(str)-1] == '\n' { + str = str[:len(str)-1] + } + + tf.numFormatted++ + + return str, nil +} diff --git a/grpcurl.go b/grpcurl.go index 6a33103..9419c02 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -48,7 +48,7 @@ var ErrReflectionNotSupported = errors.New("server does not support the reflecti // proto (like a file generated by protoc) or a remote server that supports the reflection API. type DescriptorSource interface { // ListServices returns a list of fully-qualified service names. It will be all services in a set of - // descriptor files or the set of all services exposed by a GRPC server. + // descriptor files or the set of all services exposed by a gRPC server. ListServices() ([]string, error) // FindSymbol returns a descriptor for the given fully-qualified symbol name. FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) @@ -181,6 +181,20 @@ func (fs *fileSource) ListServices() ([]string, error) { return sl, nil } +// GetAllFiles returns all of the underlying file descriptors. This is +// more thorough and more efficient than the fallback strategy used by +// the GetAllFiles package method, for enumerating all files from a +// descriptor source. +func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) { + files := make([]*desc.FileDescriptor, len(fs.files)) + i := 0 + for _, fd := range fs.files { + files[i] = fd + i++ + } + return files, nil +} + func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) { for _, fd := range fs.files { if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil { @@ -200,7 +214,7 @@ func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescri return fs.er.AllExtensionsForType(typeName), nil } -// DescriptorSourceFromServer creates a DescriptorSource that uses the given GRPC reflection client +// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client // to interrogate a server for descriptor information. If the server does not support the reflection // API then the various DescriptorSource methods will return ErrReflectionNotSupported func DescriptorSourceFromServer(ctx context.Context, refClient *grpcreflect.Client) DescriptorSource { @@ -265,6 +279,75 @@ func ListServices(source DescriptorSource) ([]string, error) { return svcs, nil } +type sourceWithFiles interface { + GetAllFiles() ([]*desc.FileDescriptor, error) +} + +var _ sourceWithFiles = (*fileSource)(nil) + +// GetAllFiles uses the given descriptor source to return a list of file descriptors. +func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) { + var files []*desc.FileDescriptor + srcFiles, ok := source.(sourceWithFiles) + + if ok { + var err error + files, err = srcFiles.GetAllFiles() + if err != nil { + return nil, err + } + } else { + // Source does not implement GetAllFiles method, so use ListServices + // and grab files from there. + allFiles := map[string]*desc.FileDescriptor{} + svcNames, err := source.ListServices() + if err != nil { + return nil, err + } + for _, name := range svcNames { + d, err := source.FindSymbol(name) + if err != nil { + return nil, err + } + addAllFilesToSet(d.GetFile(), allFiles) + } + files = make([]*desc.FileDescriptor, len(allFiles)) + i := 0 + for _, fd := range allFiles { + files[i] = fd + i++ + } + } + + sort.Sort(filesByName(files)) + return files, nil +} + +type filesByName []*desc.FileDescriptor + +func (f filesByName) Len() int { + return len(f) +} + +func (f filesByName) Less(i, j int) bool { + return f[i].GetName() < f[j].GetName() +} + +func (f filesByName) Swap(i, j int) { + f[i], f[j] = f[j], f[i] +} + +func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) { + if _, ok := all[fd.GetName()]; ok { + // already added + return + } + all[fd.GetName()] = fd + for _, dep := range fd.GetDependencies() { + addAllFilesToSet(dep, all) + } +} + // ListMethods uses the given descriptor source to return a sorted list of method names // for the specified fully-qualified service name. func ListMethods(source DescriptorSource, serviceName string) ([]string, error) { @@ -319,27 +402,54 @@ type InvocationEventHandler interface { } // RequestMessageSupplier is a function that is called to retrieve request -// messages for a GRPC operation. The message contents must be valid JSON. If -// the supplier has no more messages, it should return nil, io.EOF. +// messages for a GRPC operation. This type is deprecated and will be removed in +// a future release. +// +// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use +// RequestSupplier with InvokeRPC. type RequestMessageSupplier func() ([]byte, error) -// InvokeRpc uses the given GRPC connection to invoke the given method. The given descriptor source +// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated +// and will be removed in a future release. It just delegates to the similarly named InvokeRPC +// method, whose signature is only slightly different. +// +// Deprecated: use InvokeRPC instead. +func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, + headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { + + return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error { + // New function is almost identical, but the request supplier function works differently. + // So we adapt the logic here to maintain compatibility. + data, err := requestData() + if err != nil { + return err + } + return jsonpb.Unmarshal(bytes.NewReader(data), m) + }) +} + +// RequestSupplier is a function that is called to populate messages for a gRPC operation. The +// function should populate the given message or return a non-nil error. If the supplier has no +// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way +// modify the given message argument. +type RequestSupplier func(proto.Message) error + +// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source // is used to determine the type of method and the type of request and response message. The given // headers are sent as request metadata. Methods on the given event handler are called as the // invocation proceeds. // // The given requestData function supplies the actual data to send. It should return io.EOF when -// there is no more request data. If it returns a nil error then the returned JSON message should -// not be blank. If the method being invoked is a unary or server-streaming RPC (e.g. exactly one -// request message) and there is no request data (e.g. the first invocation of the function returns -// io.EOF), then a blank request message is sent, as if the request data were an empty object: "{}". +// there is no more request data. If the method being invoked is a unary or server-streaming RPC +// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of +// the function returns io.EOF), then an empty request message is sent. // // If the requestData function and the given event handler coordinate or share any state, they should // be thread-safe. This is because the requestData function may be called from a different goroutine // than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where // one goroutine sends request messages and another consumes the response messages). -func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, - headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { +func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, + headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { md := MetadataFromHeaders(headers) @@ -381,7 +491,7 @@ func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn handler.OnSendHeaders(md) ctx = metadata.NewOutgoingContext(ctx, md) - stub := grpcdynamic.NewStubWithMessageFactory(cc, msgFactory) + stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -397,21 +507,15 @@ func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn } func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestMessageSupplier, req proto.Message) error { + requestData RequestSupplier, req proto.Message) error { - data, err := requestData() + err := requestData(req) if err != nil && err != io.EOF { return fmt.Errorf("error getting request data: %v", err) } - if len(data) != 0 { - err = jsonpb.UnmarshalString(string(data), req) - if err != nil { - return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err) - } - } if err != io.EOF { // verify there is no second message, which is a usage error - _, err := requestData() + err := requestData(req) if err == nil { return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) } else if err != io.EOF { @@ -443,7 +547,7 @@ func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDesc } func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestMessageSupplier, req proto.Message) error { + requestData RequestSupplier, req proto.Message) error { // invoke the RPC! str, err := stub.InvokeRpcClientStream(ctx, md) @@ -451,8 +555,7 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met // Upload each request message in the stream var resp proto.Message for err == nil { - var data []byte - data, err = requestData() + err = requestData(req) if err == io.EOF { resp, err = str.CloseAndReceive() break @@ -460,12 +563,6 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met if err != nil { return fmt.Errorf("error getting request data: %v", err) } - if len(data) != 0 { - err = jsonpb.UnmarshalString(string(data), req) - if err != nil { - return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err) - } - } err = str.SendMsg(req) if err == io.EOF { @@ -500,21 +597,15 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met } func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestMessageSupplier, req proto.Message) error { + requestData RequestSupplier, req proto.Message) error { - data, err := requestData() + err := requestData(req) if err != nil && err != io.EOF { return fmt.Errorf("error getting request data: %v", err) } - if len(data) != 0 { - err = jsonpb.UnmarshalString(string(data), req) - if err != nil { - return fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err) - } - } if err != io.EOF { // verify there is no second message, which is a usage error - _, err := requestData() + err := requestData(req) if err == nil { return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) } else if err != io.EOF { @@ -555,7 +646,7 @@ func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met } func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestMessageSupplier, req proto.Message) error { + requestData RequestSupplier, req proto.Message) error { // invoke the RPC! str, err := stub.InvokeRpcBidiStream(ctx, md) @@ -572,9 +663,8 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr // Concurrently upload each request message in the stream var err error - var data []byte for err == nil { - data, err = requestData() + err = requestData(req) if err == io.EOF { err = str.CloseSend() @@ -584,13 +674,6 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr err = fmt.Errorf("error getting request data: %v", err) break } - if len(data) != 0 { - err = jsonpb.UnmarshalString(string(data), req) - if err != nil { - err = fmt.Errorf("could not parse given request body as message of type %q: %v", md.GetInputType().GetFullyQualifiedName(), err) - break - } - } err = str.SendMsg(req) @@ -700,16 +783,29 @@ func MetadataToString(md metadata.MD) string { if len(md) == 0 { return "(empty)" } + + keys := make([]string, 0, len(md)) + for k := range md { + keys = append(keys, k) + } + sort.Strings(keys) + var b bytes.Buffer - for k, vs := range md { + first := true + for _, k := range keys { + vs := md[k] for _, v := range vs { + if first { + first = false + } else { + b.WriteString("\n") + } b.WriteString(k) b.WriteString(": ") if strings.HasSuffix(k, "-bin") { v = base64.StdEncoding.EncodeToString([]byte(v)) } b.WriteString(v) - b.WriteString("\n") } } return b.String() @@ -840,7 +936,7 @@ func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) ( return dm, nil } -// ClientTransportCredentials builds transport credentials for a GRPC client using the +// ClientTransportCredentials builds transport credentials for a gRPC client using the // given properties. If cacertFile is blank, only standard trusted certs are used to // verify the server certs. If clientCertFile is blank, the client will not use a client // certificate. If clientCertFile is not blank then clientKeyFile must not be blank. @@ -877,7 +973,7 @@ func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertF return credentials.NewTLS(&tlsConf), nil } -// ServerTransportCredentials builds transport credentials for a GRPC server using the +// ServerTransportCredentials builds transport credentials for a gRPC server using the // given properties. If cacertFile is blank, the server will not request client certs // unless requireClientCerts is true. When requireClientCerts is false and cacertFile is // not blank, the server will verify client certs when presented, but will not require diff --git a/grpcurl_test.go b/grpcurl_test.go index fb429ee..932137f 100644 --- a/grpcurl_test.go +++ b/grpcurl_test.go @@ -44,6 +44,10 @@ type descSourceCase struct { includeRefl bool } +// NB: These tests intentionally use the deprecated InvokeRpc since that +// calls the other (non-deprecated InvokeRPC). That allows the tests to +// easily exercise both functions. + func TestMain(m *testing.M) { var err error sourceProtoset, err = DescriptorSourceFromProtoSets("testing/test.protoset") @@ -235,6 +239,73 @@ func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection } } +func TestGetAllFiles(t *testing.T) { + expectedFiles := []string{"testing/test.proto"} + // server reflection picks up filename from linked in Go package, + // which indicates "grpc_testing/test.proto", not our local copy. + expectedFilesWithReflection := []string{"grpc_reflection_v1alpha/reflection.proto", "grpc_testing/test.proto"} + + for _, ds := range descSources { + t.Run(ds.name, func(t *testing.T) { + files, err := GetAllFiles(ds.source) + if err != nil { + t.Fatalf("failed to get all files: %v", err) + } + names := fileNames(files) + expected := expectedFiles + if ds.includeRefl { + expected = expectedFilesWithReflection + } + if !reflect.DeepEqual(expected, names) { + t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names) + } + }) + } + + // try cases with more complicated set of files + otherSourceProtoset, err := DescriptorSourceFromProtoSets("testing/test.protoset", "testing/example.protoset") + if err != nil { + t.Fatal(err.Error()) + } + otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles(nil, "testing/test.proto", "testing/example.proto") + if err != nil { + t.Fatal(err.Error()) + } + otherDescSources := []descSourceCase{ + {"protoset[b]", otherSourceProtoset, false}, + {"proto[b]", otherSourceProtoFiles, false}, + } + expectedFiles = []string{ + "google/protobuf/any.proto", + "google/protobuf/descriptor.proto", + "google/protobuf/empty.proto", + "google/protobuf/timestamp.proto", + "testing/example.proto", + "testing/example2.proto", + "testing/test.proto", + } + for _, ds := range otherDescSources { + t.Run(ds.name, func(t *testing.T) { + files, err := GetAllFiles(ds.source) + if err != nil { + t.Fatalf("failed to get all files: %v", err) + } + names := fileNames(files) + if !reflect.DeepEqual(expectedFiles, names) { + t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names) + } + }) + } +} + +func fileNames(files []*desc.FileDescriptor) []string { + names := make([]string, len(files)) + for i, f := range files { + names[i] = f.GetName() + } + return names +} + func TestDescribe(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { diff --git a/mk-test-files.sh b/mk-test-files.sh index 94e1027..407f7dc 100755 --- a/mk-test-files.sh +++ b/mk-test-files.sh @@ -7,8 +7,8 @@ cd "$(dirname $0)" # Run this script to generate files used by tests. echo "Creating protosets..." -protoc ../../../google.golang.org/grpc/interop/grpc_testing/test.proto \ - -I../../../ --include_imports \ +protoc testing/test.proto \ + --include_imports \ --descriptor_set_out=testing/test.protoset protoc testing/example.proto \ diff --git a/testing/example.proto b/testing/example.proto index 1229267..bfb3f3f 100644 --- a/testing/example.proto +++ b/testing/example.proto @@ -3,9 +3,11 @@ syntax = "proto3"; import "google/protobuf/descriptor.proto"; import "google/protobuf/empty.proto"; import "google/protobuf/timestamp.proto"; +import "testing/example2.proto"; message TestRequest { repeated string file_names = 1; + repeated Extension extensions = 2; } message TestResponse { diff --git a/testing/example.protoset b/testing/example.protoset index b3b19fe..cad3361 100644 Binary files a/testing/example.protoset and b/testing/example.protoset differ diff --git a/testing/example2.proto b/testing/example2.proto new file mode 100644 index 0000000..ee4e9be --- /dev/null +++ b/testing/example2.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; + +message Extension { + uint64 id = 1; + google.protobuf.Any data = 2; +} diff --git a/testing/test.protoset b/testing/test.protoset index 4b6d521..6915f3d 100644 Binary files a/testing/test.protoset and b/testing/test.protoset differ