package grpcurl_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"os"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 we have to import these because some of their types appear in exported API
	"github.com/golang/protobuf/proto" //lint:ignore SA1019 same as above
	"github.com/jhump/protoreflect/desc" //lint:ignore SA1019 same as above
	"github.com/jhump/protoreflect/grpcreflect"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/reflection"
	"google.golang.org/grpc/status"

	. "github.com/fullstorydev/grpcurl"
	grpcurl_testing "github.com/fullstorydev/grpcurl/internal/testing"
	jsonpbtest "github.com/fullstorydev/grpcurl/internal/testing/jsonpb_test_proto"
)

// Shared fixtures. All of these are populated once by TestMain before any
// test runs.
var (
	// sourceProtoset is loaded from internal/testing/test.protoset.
	sourceProtoset DescriptorSource
	// sourceProtoFiles is loaded from test.proto, using internal/testing as
	// the import path.
	sourceProtoFiles DescriptorSource
	// ccNoReflect is connected to the test server that does not register the
	// reflection service.
	ccNoReflect *grpc.ClientConn

	// sourceReflect is backed by a reflection client that uses ccReflect.
	sourceReflect DescriptorSource
	// ccReflect is connected to the test server that registers the reflection
	// service.
	ccReflect *grpc.ClientConn

	// descSources lists every descriptor source that the table-driven tests
	// iterate over.
	descSources []descSourceCase
)

// descSourceCase is one entry of the descSources table.
type descSourceCase struct {
	// name is used as the sub-test name passed to t.Run.
	name string
	// source is the descriptor source under test.
	source DescriptorSource
	// includeRefl is true when source is backed by server reflection. Tests
	// use it to pick the expected results and the client connection.
	includeRefl bool
}

// NB: These tests intentionally use the deprecated InvokeRpc since that
// calls the other (non-deprecated InvokeRPC). That allows the tests to
// easily exercise both functions.
func TestMain(m *testing.M) { var err error sourceProtoset, err = DescriptorSourceFromProtoSets("internal/testing/test.protoset") if err != nil { panic(err) } sourceProtoFiles, err = DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto") if err != nil { panic(err) } // Create a server that includes the reflection service svrReflect := grpc.NewServer() grpcurl_testing.RegisterTestServiceServer(svrReflect, grpcurl_testing.TestServer{}) reflection.Register(svrReflect) var portReflect int if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { panic(err) } else { portReflect = l.Addr().(*net.TCPAddr).Port go svrReflect.Serve(l) } defer svrReflect.Stop() // And a corresponding client ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()); err != nil { panic(err) } defer ccReflect.Close() refClient := grpcreflect.NewClientAuto(context.Background(), ccReflect) defer refClient.Reset() sourceReflect = DescriptorSourceFromServer(context.Background(), refClient) // Also create a server that does *not* include the reflection service svrProtoset := grpc.NewServer() grpcurl_testing.RegisterTestServiceServer(svrProtoset, grpcurl_testing.TestServer{}) var portProtoset int if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { panic(err) } else { portProtoset = l.Addr().(*net.TCPAddr).Port go svrProtoset.Serve(l) } defer svrProtoset.Stop() // And a corresponding client ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if ccNoReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()); err != nil { panic(err) } defer ccNoReflect.Close() descSources = []descSourceCase{ {"protoset", sourceProtoset, false}, {"proto", sourceProtoFiles, 
false}, {"reflect", sourceReflect, true}, } os.Exit(m.Run()) } func TestServerDoesNotSupportReflection(t *testing.T) { refClient := grpcreflect.NewClientAuto(context.Background(), ccNoReflect) defer refClient.Reset() refSource := DescriptorSourceFromServer(context.Background(), refClient) _, err := ListServices(refSource) if err != ErrReflectionNotSupported { t.Errorf("ListServices should have returned ErrReflectionNotSupported; instead got %v", err) } _, err = ListMethods(refSource, "SomeService") if err != ErrReflectionNotSupported { t.Errorf("ListMethods should have returned ErrReflectionNotSupported; instead got %v", err) } err = InvokeRpc(context.Background(), refSource, ccNoReflect, "FooService/Method", nil, nil, nil) // InvokeRpc wraps the error, so we just verify the returned error includes the right message if err == nil || !strings.Contains(err.Error(), ErrReflectionNotSupported.Error()) { t.Errorf("InvokeRpc should have returned ErrReflectionNotSupported; instead got %v", err) } } func TestProtosetWithImports(t *testing.T) { sourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/example.protoset") if err != nil { t.Fatalf("failed to load protoset: %v", err) } // really shallow check of the loaded descriptors if sd, err := sourceProtoset.FindSymbol("TestService"); err != nil { t.Errorf("failed to find TestService in protoset: %v", err) } else if sd == nil { t.Errorf("FindSymbol returned nil for TestService") } else if _, ok := sd.(*desc.ServiceDescriptor); !ok { t.Errorf("FindSymbol returned wrong kind of descriptor for TestService: %T", sd) } if md, err := sourceProtoset.FindSymbol("TestRequest"); err != nil { t.Errorf("failed to find TestRequest in protoset: %v", err) } else if md == nil { t.Errorf("FindSymbol returned nil for TestRequest") } else if _, ok := md.(*desc.MessageDescriptor); !ok { t.Errorf("FindSymbol returned wrong kind of descriptor for TestRequest: %T", md) } } func TestListServices(t *testing.T) { for _, ds := range 
descSources { t.Run(ds.name, func(t *testing.T) { doTestListServices(t, ds.source, ds.includeRefl) }) } } func doTestListServices(t *testing.T, source DescriptorSource, includeReflection bool) { names, err := ListServices(source) if err != nil { t.Fatalf("failed to list services: %v", err) } var expected []string if includeReflection { // when using server reflection, we see the TestService as well as the ServerReflection service expected = []string{"grpc.reflection.v1.ServerReflection", "grpc.reflection.v1alpha.ServerReflection", "testing.TestService"} } else { // without reflection, we see all services defined in the same test.proto file, which is the // TestService as well as UnimplementedService expected = []string{"testing.TestService", "testing.UnimplementedService"} } if !reflect.DeepEqual(expected, names) { t.Errorf("ListServices returned wrong results: wanted %v, got %v", expected, names) } } func TestListMethods(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestListMethods(t, ds.source, ds.includeRefl) }) } } func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection bool) { names, err := ListMethods(source, "testing.TestService") if err != nil { t.Fatalf("failed to list methods for TestService: %v", err) } expected := []string{ "testing.TestService.EmptyCall", "testing.TestService.FullDuplexCall", "testing.TestService.HalfDuplexCall", "testing.TestService.StreamingInputCall", "testing.TestService.StreamingOutputCall", "testing.TestService.UnaryCall", } if !reflect.DeepEqual(expected, names) { t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) } if includeReflection { // when using server reflection, we see the TestService as well as the ServerReflection service names, err = ListMethods(source, "grpc.reflection.v1.ServerReflection") if err != nil { t.Fatalf("failed to list methods for ServerReflection: %v", err) } expected = 
[]string{"grpc.reflection.v1.ServerReflection.ServerReflectionInfo"} } else { // without reflection, we see all services defined in the same test.proto file, which is the // TestService as well as UnimplementedService names, err = ListMethods(source, "testing.UnimplementedService") if err != nil { t.Fatalf("failed to list methods for ServerReflection: %v", err) } expected = []string{"testing.UnimplementedService.UnimplementedCall"} } if !reflect.DeepEqual(expected, names) { t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names) } // force an error _, err = ListMethods(source, "FooService") if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") { t.Errorf("ListMethods should have returned 'not found' error but instead returned %v", err) } } func TestGetAllFiles(t *testing.T) { expectedFiles := []string{"test.proto"} expectedFilesWithReflection := []string{ "grpc/reflection/v1/reflection.proto", "grpc/reflection/v1alpha/reflection.proto", "test.proto", } for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { files, err := GetAllFiles(ds.source) if err != nil { t.Fatalf("failed to get all files: %v", err) } names := fileNames(files) match := false var expected []string if ds.includeRefl { expected = expectedFilesWithReflection } else { expected = expectedFiles } match = reflect.DeepEqual(expected, names) if !match { t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names) } }) } // try cases with more complicated set of files otherSourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/test.protoset", "internal/testing/example.protoset") if err != nil { t.Fatal(err.Error()) } otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto", "example.proto") if err != nil { t.Fatal(err.Error()) } otherDescSources := []descSourceCase{ {"protoset[b]", otherSourceProtoset, false}, {"proto[b]", otherSourceProtoFiles, false}, } 
expectedFiles = []string{ "example.proto", "example2.proto", "google/protobuf/any.proto", "google/protobuf/descriptor.proto", "google/protobuf/empty.proto", "google/protobuf/timestamp.proto", "test.proto", } for _, ds := range otherDescSources { t.Run(ds.name, func(t *testing.T) { files, err := GetAllFiles(ds.source) if err != nil { t.Fatalf("failed to get all files: %v", err) } names := fileNames(files) if !reflect.DeepEqual(expectedFiles, names) { t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names) } }) } } func TestExpandHeaders(t *testing.T) { inHeaders := []string{"key1: ${value}", "key2: bar", "key3: ${woo", "key4: woo}", "key5: ${TEST}", "key6: ${TEST_VAR}", "${TEST}: ${TEST_VAR}", "key8: ${EMPTY}"} os.Setenv("value", "value") os.Setenv("TEST", "value5") os.Setenv("TEST_VAR", "value6") os.Setenv("EMPTY", "") expectedHeaders := map[string]bool{"key1: value": true, "key2: bar": true, "key3: ${woo": true, "key4: woo}": true, "key5: value5": true, "key6: value6": true, "value5: value6": true, "key8: ": true} outHeaders, err := ExpandHeaders(inHeaders) if err != nil { t.Errorf("The ExpandHeaders function generated an unexpected error %s", err) } for _, expandedHeader := range outHeaders { if _, ok := expectedHeaders[expandedHeader]; !ok { t.Errorf("The ExpandHeaders function has returned an unexpected header. 
Received unexpected header %s", expandedHeader) } } badHeaders := []string{"key: ${DNE}"} _, err = ExpandHeaders(badHeaders) if err == nil { t.Errorf("The ExpandHeaders function should return an error for missing environment variables %q", badHeaders) } } func fileNames(files []*desc.FileDescriptor) []string { names := make([]string, len(files)) for i, f := range files { names[i] = f.GetName() } return names } const expectKnownType = `{ "dur": "0s", "ts": "1970-01-01T00:00:00Z", "dbl": 0, "flt": 0, "i64": "0", "u64": "0", "i32": 0, "u32": 0, "bool": false, "str": "", "bytes": null, "st": {"google.protobuf.Struct": "supports arbitrary JSON objects"}, "an": {"@type": "type.googleapis.com/google.protobuf.Empty", "value": {}}, "lv": [{"google.protobuf.ListValue": "is an array of arbitrary JSON values"}], "val": {"google.protobuf.Value": "supports arbitrary JSON"} }` func TestMakeTemplateKnownTypes(t *testing.T) { descriptor, err := desc.LoadMessageDescriptorForMessage((*jsonpbtest.KnownTypes)(nil)) if err != nil { t.Fatalf("failed to load descriptor: %v", err) } message := MakeTemplate(descriptor) jsm := jsonpb.Marshaler{EmitDefaults: true} out, err := jsm.MarshalToString(message) if err != nil { t.Fatalf("failed to marshal to JSON: %v", err) } // make sure template JSON matches expected var actual, expected interface{} if err := json.Unmarshal([]byte(out), &actual); err != nil { t.Fatalf("failed to parse actual JSON: %v", err) } if err := json.Unmarshal([]byte(expectKnownType), &expected); err != nil { t.Fatalf("failed to parse expected JSON: %v", err) } if !reflect.DeepEqual(actual, expected) { t.Errorf("template message is not as expected; want:\n%s\ngot:\n%s", expectKnownType, out) } } func TestDescribe(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestDescribe(t, ds.source) }) } } func doTestDescribe(t *testing.T, source DescriptorSource) { sym := "testing.TestService.EmptyCall" dsc, err := source.FindSymbol(sym) if err != 
nil { t.Fatalf("failed to get descriptor for %q: %v", sym, err) } if _, ok := dsc.(*desc.MethodDescriptor); !ok { t.Fatalf("descriptor for %q was a %T (expecting a MethodDescriptor)", sym, dsc) } txt := proto.MarshalTextString(dsc.AsProto()) expected := `name: "EmptyCall" input_type: ".testing.Empty" output_type: ".testing.Empty" ` if expected != txt { t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt) } sym = "testing.StreamingOutputCallResponse" dsc, err = source.FindSymbol(sym) if err != nil { t.Fatalf("failed to get descriptor for %q: %v", sym, err) } if _, ok := dsc.(*desc.MessageDescriptor); !ok { t.Fatalf("descriptor for %q was a %T (expecting a MessageDescriptor)", sym, dsc) } txt = proto.MarshalTextString(dsc.AsProto()) expected = `name: "StreamingOutputCallResponse" field: < name: "payload" number: 1 label: LABEL_OPTIONAL type: TYPE_MESSAGE type_name: ".testing.Payload" json_name: "payload" > ` if expected != txt { t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt) } _, err = source.FindSymbol("FooService") if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") { t.Errorf("FindSymbol should have returned 'not found' error but instead returned %v", err) } } const ( // type == COMPRESSABLE, but that is default (since it has // numeric value == 0) and thus doesn't actually get included // on the wire payload1 = `{ "payload": { "body": "SXQncyBCdXNpbmVzcyBUaW1l" } }` payload2 = `{ "payload": { "type": "RANDOM", "body": "Rm91eCBkdSBGYUZh" } }` payload3 = `{ "payload": { "type": "UNCOMPRESSABLE", "body": "SGlwaG9wb3BvdGFtdXMgdnMuIFJoeW1lbm9jZXJvcw==" } }` ) func getCC(includeRefl bool) *grpc.ClientConn { if includeRefl { return ccReflect } else { return ccNoReflect } } func TestUnary(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestUnary(t, getCC(ds.includeRefl), ds.source) }) } } func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { // 
Success h := &handler{reqMessages: []string{payload1}} err := InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.OK), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } if h.check(t, "testing.TestService.UnaryCall", codes.OK, 1, 1) { if h.respMessages[0] != payload1 { t.Errorf("unexpected response from RPC: expecting %s; got %s", payload1, h.respMessages[0]) } } // Failure h = &handler{reqMessages: []string{payload1}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.NotFound), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.UnaryCall", codes.NotFound, 1, 0) } func TestClientStream(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestClientStream(t, getCC(ds.includeRefl), ds.source) }) } } func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { // Success h := &handler{reqMessages: []string{payload1, payload2, payload3}} err := InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.OK), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } if h.check(t, "testing.TestService.StreamingInputCall", codes.OK, 3, 1) { expected := `{ "aggregatedPayloadSize": 61 }` if h.respMessages[0] != expected { t.Errorf("unexpected response from RPC: expecting %s; got %s", expected, h.respMessages[0]) } } // Fail fast (server rejects as soon as possible) h = &handler{reqMessages: []string{payload1, payload2, payload3}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.InvalidArgument), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.StreamingInputCall", codes.InvalidArgument, -3, 0) // Fail late (server 
waits until stream is complete to reject) h = &handler{reqMessages: []string{payload1, payload2, payload3}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.Internal, true), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.StreamingInputCall", codes.Internal, 3, 0) } func TestServerStream(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestServerStream(t, getCC(ds.includeRefl), ds.source) }) } } func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { req := &grpcurl_testing.StreamingOutputCallRequest{ ResponseType: grpcurl_testing.PayloadType_COMPRESSABLE, ResponseParameters: []*grpcurl_testing.ResponseParameters{ {Size: 10}, {Size: 20}, {Size: 30}, {Size: 40}, {Size: 50}, }, } payload, err := (&jsonpb.Marshaler{}).MarshalToString(req) if err != nil { t.Fatalf("failed to construct request: %v", err) } // Success h := &handler{reqMessages: []string{payload}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.OK), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } if h.check(t, "testing.TestService.StreamingOutputCall", codes.OK, 1, 5) { resp := &grpcurl_testing.StreamingOutputCallResponse{} for i, msg := range h.respMessages { if err := jsonpb.UnmarshalString(msg, resp); err != nil { t.Errorf("failed to parse response %d: %v", i+1, err) } if resp.Payload.GetType() != grpcurl_testing.PayloadType_COMPRESSABLE { t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_COMPRESSABLE, resp.Payload.Type) } if len(resp.Payload.Body) != (i+1)*10 { t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (i+1)*10, len(resp.Payload.Body)) } resp.Reset() } } // Fail fast (server rejects as soon as possible) h = 
&handler{reqMessages: []string{payload}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.Aborted), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.StreamingOutputCall", codes.Aborted, 1, 0) // Fail late (server waits until stream is complete to reject) h = &handler{reqMessages: []string{payload}} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.AlreadyExists, true), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.StreamingOutputCall", codes.AlreadyExists, 1, 5) } func TestHalfDuplexStream(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestHalfDuplexStream(t, getCC(ds.includeRefl), ds.source) }) } } func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { reqs := []string{payload1, payload2, payload3} // Success h := &handler{reqMessages: reqs} err := InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.OK), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } if h.check(t, "testing.TestService.HalfDuplexCall", codes.OK, 3, 3) { for i, resp := range h.respMessages { if resp != reqs[i] { t.Errorf("unexpected response %d from RPC:\nexpecting %q\ngot %q", i, reqs[i], resp) } } } // Fail fast (server rejects as soon as possible) h = &handler{reqMessages: reqs} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.Canceled), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.HalfDuplexCall", codes.Canceled, -3, 0) // Fail late (server waits until stream is complete to reject) h = &handler{reqMessages: reqs} err = 
InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.DataLoss, true), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.HalfDuplexCall", codes.DataLoss, 3, 3) } func TestFullDuplexStream(t *testing.T) { for _, ds := range descSources { t.Run(ds.name, func(t *testing.T) { doTestFullDuplexStream(t, getCC(ds.includeRefl), ds.source) }) } } func doTestFullDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) { reqs := make([]string, 3) req := &grpcurl_testing.StreamingOutputCallRequest{ ResponseType: grpcurl_testing.PayloadType_RANDOM, } for i := range reqs { req.ResponseParameters = append(req.ResponseParameters, &grpcurl_testing.ResponseParameters{Size: int32((i + 1) * 10)}) payload, err := (&jsonpb.Marshaler{}).MarshalToString(req) if err != nil { t.Fatalf("failed to construct request %d: %v", i, err) } reqs[i] = payload } // Success h := &handler{reqMessages: reqs} err := InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.OK), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } if h.check(t, "testing.TestService.FullDuplexCall", codes.OK, 3, 6) { resp := &grpcurl_testing.StreamingOutputCallResponse{} i := 0 for j := 1; j < 3; j++ { // three requests for k := 0; k < j; k++ { // 1 response for first request, 2 for second, etc msg := h.respMessages[i] if err := jsonpb.UnmarshalString(msg, resp); err != nil { t.Errorf("failed to parse response %d: %v", i+1, err) } if resp.Payload.GetType() != grpcurl_testing.PayloadType_RANDOM { t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_RANDOM, resp.Payload.Type) } if len(resp.Payload.Body) != (k+1)*10 { t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (k+1)*10, len(resp.Payload.Body)) } resp.Reset() i++ } } } // Fail fast (server rejects 
as soon as possible) h = &handler{reqMessages: reqs} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.PermissionDenied), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.FullDuplexCall", codes.PermissionDenied, -3, 0) // Fail late (server waits until stream is complete to reject) h = &handler{reqMessages: reqs} err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.ResourceExhausted, true), h, h.getRequestData) if err != nil { t.Fatalf("unexpected error during RPC: %v", err) } h.check(t, "testing.TestService.FullDuplexCall", codes.ResourceExhausted, 3, 6) } type handler struct { method *desc.MethodDescriptor methodCount int reqHeaders metadata.MD reqHeadersCount int reqMessages []string reqMessagesCount int respHeaders metadata.MD respHeadersCount int respMessages []string respTrailers metadata.MD respStatus *status.Status respTrailersCount int } func (h *handler) getRequestData() ([]byte, error) { // we don't use a mutex, though this method will be called from different goroutine // than other methods for bidi calls, because this method does not share any state // with the other methods. 
h.reqMessagesCount++ if h.reqMessagesCount > len(h.reqMessages) { return nil, io.EOF } if h.reqMessagesCount > 1 { // insert delay between messages in request stream time.Sleep(time.Millisecond * 50) } return []byte(h.reqMessages[h.reqMessagesCount-1]), nil } func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) { h.methodCount++ h.method = md } func (h *handler) OnSendHeaders(md metadata.MD) { h.reqHeadersCount++ h.reqHeaders = md } func (h *handler) OnReceiveHeaders(md metadata.MD) { h.respHeadersCount++ h.respHeaders = md } func (h *handler) OnReceiveResponse(msg proto.Message) { jsm := jsonpb.Marshaler{Indent: " "} respStr, err := jsm.MarshalToString(msg) if err != nil { panic(fmt.Errorf("failed to generate JSON form of response message: %v", err)) } h.respMessages = append(h.respMessages, respStr) } func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) { h.respTrailersCount++ h.respTrailers = md h.respStatus = stat } func (h *handler) check(t *testing.T, expectedMethod string, expectedCode codes.Code, expectedRequestQueries, expectedResponses int) bool { // verify a few things were only ever called once if h.methodCount != 1 { t.Errorf("expected grpcurl to invoke OnResolveMethod once; was %d", h.methodCount) } if h.reqHeadersCount != 1 { t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount) } if h.reqHeadersCount != 1 { t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount) } if h.respHeadersCount != 1 { t.Errorf("expected grpcurl to invoke OnReceiveHeaders once; was %d", h.respHeadersCount) } if h.respTrailersCount != 1 { t.Errorf("expected grpcurl to invoke OnReceiveTrailers once; was %d", h.respTrailersCount) } // check other stuff against given expectations if h.method.GetFullyQualifiedName() != expectedMethod { t.Errorf("wrong method: expecting %v, got %v", expectedMethod, h.method.GetFullyQualifiedName()) } if h.respStatus.Code() != expectedCode { t.Errorf("wrong code: 
expecting %v, got %v", expectedCode, h.respStatus.Code()) } if expectedRequestQueries < 0 { // negative expectation means "negate and expect up to that number; could be fewer" if h.reqMessagesCount > -expectedRequestQueries+1 { // the + 1 is because there will be an extra query that returns EOF t.Errorf("wrong number of messages queried: expecting no more than %v, got %v", -expectedRequestQueries, h.reqMessagesCount-1) } } else { if h.reqMessagesCount != expectedRequestQueries+1 { // the + 1 is because there will be an extra query that returns EOF t.Errorf("wrong number of messages queried: expecting %v, got %v", expectedRequestQueries, h.reqMessagesCount-1) } } if len(h.respMessages) != expectedResponses { t.Errorf("wrong number of messages received: expecting %v, got %v", expectedResponses, len(h.respMessages)) } // also check headers and trailers came through as expected v := h.respHeaders["some-fake-header-1"] if len(v) != 1 || v[0] != "val1" { t.Errorf("wrong request header for %q: %v", "some-fake-header-1", v) } v = h.respHeaders["some-fake-header-2"] if len(v) != 1 || v[0] != "val2" { t.Errorf("wrong request header for %q: %v", "some-fake-header-2", v) } v = h.respTrailers["some-fake-trailer-1"] if len(v) != 1 || v[0] != "valA" { t.Errorf("wrong request header for %q: %v", "some-fake-trailer-1", v) } v = h.respTrailers["some-fake-trailer-2"] if len(v) != 1 || v[0] != "valB" { t.Errorf("wrong request header for %q: %v", "some-fake-trailer-2", v) } return len(h.respMessages) == expectedResponses } func makeHeaders(code codes.Code, failLate ...bool) []string { if len(failLate) > 1 { panic("incorrect use of makeContext; should be at most one failLate flag") } hdrs := append(make([]string, 0, 5), fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-1: val1"), fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-2: val2"), fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-1: valA"), 
fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-2: valB")) if code != codes.OK { if len(failLate) > 0 && failLate[0] { hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailLate, code)) } else { hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailEarly, code)) } } return hdrs }