883 lines
29 KiB
Go
883 lines
29 KiB
Go
package grpcurl_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 we have to import this because it appears in exported API
|
|
"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
|
|
"github.com/jhump/protoreflect/desc"
|
|
"github.com/jhump/protoreflect/grpcreflect"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/reflection"
|
|
reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
|
"google.golang.org/grpc/status"
|
|
|
|
. "github.com/fullstorydev/grpcurl"
|
|
grpcurl_testing "github.com/fullstorydev/grpcurl/internal/testing"
|
|
jsonpbtest "github.com/fullstorydev/grpcurl/internal/testing/jsonpb_test_proto"
|
|
)
|
|
|
|
var (
|
|
sourceProtoset DescriptorSource
|
|
sourceProtoFiles DescriptorSource
|
|
ccNoReflect *grpc.ClientConn
|
|
|
|
sourceReflect DescriptorSource
|
|
ccReflect *grpc.ClientConn
|
|
|
|
descSources []descSourceCase
|
|
)
|
|
|
|
// descSourceCase pairs a DescriptorSource under test with a subtest name and
// records whether the source is backed by server reflection, which changes
// the services, methods and files the tests expect to see.
type descSourceCase struct {
	name        string           // subtest name passed to t.Run
	source      DescriptorSource // descriptor source under test
	includeRefl bool             // true when source comes from server reflection
}
|
|
|
|
// NB: These tests intentionally use the deprecated InvokeRpc since it
// calls the other (non-deprecated) InvokeRPC. That allows the tests to
// easily exercise both functions.
|
|
|
|
func TestMain(m *testing.M) {
|
|
var err error
|
|
sourceProtoset, err = DescriptorSourceFromProtoSets("internal/testing/test.protoset")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sourceProtoFiles, err = DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Create a server that includes the reflection service
|
|
svrReflect := grpc.NewServer()
|
|
grpcurl_testing.RegisterTestServiceServer(svrReflect, grpcurl_testing.TestServer{})
|
|
reflection.Register(svrReflect)
|
|
var portReflect int
|
|
if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil {
|
|
panic(err)
|
|
} else {
|
|
portReflect = l.Addr().(*net.TCPAddr).Port
|
|
go svrReflect.Serve(l)
|
|
}
|
|
defer svrReflect.Stop()
|
|
|
|
// And a corresponding client
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()); err != nil {
|
|
panic(err)
|
|
}
|
|
defer ccReflect.Close()
|
|
refClient := grpcreflect.NewClientV1Alpha(context.Background(), reflectpb.NewServerReflectionClient(ccReflect))
|
|
defer refClient.Reset()
|
|
|
|
sourceReflect = DescriptorSourceFromServer(context.Background(), refClient)
|
|
|
|
// Also create a server that does *not* include the reflection service
|
|
svrProtoset := grpc.NewServer()
|
|
grpcurl_testing.RegisterTestServiceServer(svrProtoset, grpcurl_testing.TestServer{})
|
|
var portProtoset int
|
|
if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil {
|
|
panic(err)
|
|
} else {
|
|
portProtoset = l.Addr().(*net.TCPAddr).Port
|
|
go svrProtoset.Serve(l)
|
|
}
|
|
defer svrProtoset.Stop()
|
|
|
|
// And a corresponding client
|
|
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
if ccNoReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()); err != nil {
|
|
panic(err)
|
|
}
|
|
defer ccNoReflect.Close()
|
|
|
|
descSources = []descSourceCase{
|
|
{"protoset", sourceProtoset, false},
|
|
{"proto", sourceProtoFiles, false},
|
|
{"reflect", sourceReflect, true},
|
|
}
|
|
|
|
os.Exit(m.Run())
|
|
}
|
|
|
|
func TestServerDoesNotSupportReflection(t *testing.T) {
|
|
refClient := grpcreflect.NewClientV1Alpha(context.Background(), reflectpb.NewServerReflectionClient(ccNoReflect))
|
|
defer refClient.Reset()
|
|
|
|
refSource := DescriptorSourceFromServer(context.Background(), refClient)
|
|
|
|
_, err := ListServices(refSource)
|
|
if err != ErrReflectionNotSupported {
|
|
t.Errorf("ListServices should have returned ErrReflectionNotSupported; instead got %v", err)
|
|
}
|
|
|
|
_, err = ListMethods(refSource, "SomeService")
|
|
if err != ErrReflectionNotSupported {
|
|
t.Errorf("ListMethods should have returned ErrReflectionNotSupported; instead got %v", err)
|
|
}
|
|
|
|
err = InvokeRpc(context.Background(), refSource, ccNoReflect, "FooService/Method", nil, nil, nil)
|
|
// InvokeRpc wraps the error, so we just verify the returned error includes the right message
|
|
if err == nil || !strings.Contains(err.Error(), ErrReflectionNotSupported.Error()) {
|
|
t.Errorf("InvokeRpc should have returned ErrReflectionNotSupported; instead got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestProtosetWithImports(t *testing.T) {
|
|
sourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/example.protoset")
|
|
if err != nil {
|
|
t.Fatalf("failed to load protoset: %v", err)
|
|
}
|
|
// really shallow check of the loaded descriptors
|
|
if sd, err := sourceProtoset.FindSymbol("TestService"); err != nil {
|
|
t.Errorf("failed to find TestService in protoset: %v", err)
|
|
} else if sd == nil {
|
|
t.Errorf("FindSymbol returned nil for TestService")
|
|
} else if _, ok := sd.(*desc.ServiceDescriptor); !ok {
|
|
t.Errorf("FindSymbol returned wrong kind of descriptor for TestService: %T", sd)
|
|
}
|
|
if md, err := sourceProtoset.FindSymbol("TestRequest"); err != nil {
|
|
t.Errorf("failed to find TestRequest in protoset: %v", err)
|
|
} else if md == nil {
|
|
t.Errorf("FindSymbol returned nil for TestRequest")
|
|
} else if _, ok := md.(*desc.MessageDescriptor); !ok {
|
|
t.Errorf("FindSymbol returned wrong kind of descriptor for TestRequest: %T", md)
|
|
}
|
|
}
|
|
|
|
func TestListServices(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestListServices(t, ds.source, ds.includeRefl)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestListServices(t *testing.T, source DescriptorSource, includeReflection bool) {
|
|
names, err := ListServices(source)
|
|
if err != nil {
|
|
t.Fatalf("failed to list services: %v", err)
|
|
}
|
|
var expected []string
|
|
if includeReflection {
|
|
// when using server reflection, we see the TestService as well as the ServerReflection service
|
|
expected = []string{"grpc.reflection.v1alpha.ServerReflection", "testing.TestService"}
|
|
} else {
|
|
// without reflection, we see all services defined in the same test.proto file, which is the
|
|
// TestService as well as UnimplementedService
|
|
expected = []string{"testing.TestService", "testing.UnimplementedService"}
|
|
}
|
|
if !reflect.DeepEqual(expected, names) {
|
|
t.Errorf("ListServices returned wrong results: wanted %v, got %v", expected, names)
|
|
}
|
|
}
|
|
|
|
func TestListMethods(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestListMethods(t, ds.source, ds.includeRefl)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestListMethods(t *testing.T, source DescriptorSource, includeReflection bool) {
|
|
names, err := ListMethods(source, "testing.TestService")
|
|
if err != nil {
|
|
t.Fatalf("failed to list methods for TestService: %v", err)
|
|
}
|
|
expected := []string{
|
|
"testing.TestService.EmptyCall",
|
|
"testing.TestService.FullDuplexCall",
|
|
"testing.TestService.HalfDuplexCall",
|
|
"testing.TestService.StreamingInputCall",
|
|
"testing.TestService.StreamingOutputCall",
|
|
"testing.TestService.UnaryCall",
|
|
}
|
|
if !reflect.DeepEqual(expected, names) {
|
|
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
|
}
|
|
|
|
if includeReflection {
|
|
// when using server reflection, we see the TestService as well as the ServerReflection service
|
|
names, err = ListMethods(source, "grpc.reflection.v1alpha.ServerReflection")
|
|
if err != nil {
|
|
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
|
}
|
|
expected = []string{"grpc.reflection.v1alpha.ServerReflection.ServerReflectionInfo"}
|
|
} else {
|
|
// without reflection, we see all services defined in the same test.proto file, which is the
|
|
// TestService as well as UnimplementedService
|
|
names, err = ListMethods(source, "testing.UnimplementedService")
|
|
if err != nil {
|
|
t.Fatalf("failed to list methods for ServerReflection: %v", err)
|
|
}
|
|
expected = []string{"testing.UnimplementedService.UnimplementedCall"}
|
|
}
|
|
if !reflect.DeepEqual(expected, names) {
|
|
t.Errorf("ListMethods returned wrong results: wanted %v, got %v", expected, names)
|
|
}
|
|
|
|
// force an error
|
|
_, err = ListMethods(source, "FooService")
|
|
if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") {
|
|
t.Errorf("ListMethods should have returned 'not found' error but instead returned %v", err)
|
|
}
|
|
}
|
|
|
|
func TestGetAllFiles(t *testing.T) {
|
|
expectedFiles := []string{"test.proto"}
|
|
// server reflection picks up filename from linked in Go package,
|
|
// which indicates "grpc_testing/test.proto", not our local copy.
|
|
expectedFilesWithReflection := [][]string{
|
|
{"grpc_reflection_v1alpha/reflection.proto", "test.proto"},
|
|
// depending on the version of grpc, the filenames could be prefixed with "interop/" and "reflection/"
|
|
{"reflection/grpc_reflection_v1alpha/reflection.proto", "test.proto"},
|
|
}
|
|
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
files, err := GetAllFiles(ds.source)
|
|
if err != nil {
|
|
t.Fatalf("failed to get all files: %v", err)
|
|
}
|
|
names := fileNames(files)
|
|
match := false
|
|
var expected []string
|
|
if ds.includeRefl {
|
|
for _, expectedNames := range expectedFilesWithReflection {
|
|
expected = expectedNames
|
|
if reflect.DeepEqual(expected, names) {
|
|
match = true
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
expected = expectedFiles
|
|
match = reflect.DeepEqual(expected, names)
|
|
}
|
|
if !match {
|
|
t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expected, names)
|
|
}
|
|
})
|
|
}
|
|
|
|
// try cases with more complicated set of files
|
|
otherSourceProtoset, err := DescriptorSourceFromProtoSets("internal/testing/test.protoset", "internal/testing/example.protoset")
|
|
if err != nil {
|
|
t.Fatal(err.Error())
|
|
}
|
|
otherSourceProtoFiles, err := DescriptorSourceFromProtoFiles([]string{"internal/testing"}, "test.proto", "example.proto")
|
|
if err != nil {
|
|
t.Fatal(err.Error())
|
|
}
|
|
otherDescSources := []descSourceCase{
|
|
{"protoset[b]", otherSourceProtoset, false},
|
|
{"proto[b]", otherSourceProtoFiles, false},
|
|
}
|
|
expectedFiles = []string{
|
|
"example.proto",
|
|
"example2.proto",
|
|
"google/protobuf/any.proto",
|
|
"google/protobuf/descriptor.proto",
|
|
"google/protobuf/empty.proto",
|
|
"google/protobuf/timestamp.proto",
|
|
"test.proto",
|
|
}
|
|
for _, ds := range otherDescSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
files, err := GetAllFiles(ds.source)
|
|
if err != nil {
|
|
t.Fatalf("failed to get all files: %v", err)
|
|
}
|
|
names := fileNames(files)
|
|
if !reflect.DeepEqual(expectedFiles, names) {
|
|
t.Errorf("GetAllFiles returned wrong results: wanted %v, got %v", expectedFiles, names)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExpandHeaders(t *testing.T) {
|
|
inHeaders := []string{"key1: ${value}", "key2: bar", "key3: ${woo", "key4: woo}", "key5: ${TEST}",
|
|
"key6: ${TEST_VAR}", "${TEST}: ${TEST_VAR}", "key8: ${EMPTY}"}
|
|
os.Setenv("value", "value")
|
|
os.Setenv("TEST", "value5")
|
|
os.Setenv("TEST_VAR", "value6")
|
|
os.Setenv("EMPTY", "")
|
|
expectedHeaders := map[string]bool{"key1: value": true, "key2: bar": true, "key3: ${woo": true, "key4: woo}": true,
|
|
"key5: value5": true, "key6: value6": true, "value5: value6": true, "key8: ": true}
|
|
|
|
outHeaders, err := ExpandHeaders(inHeaders)
|
|
if err != nil {
|
|
t.Errorf("The ExpandHeaders function generated an unexpected error %s", err)
|
|
}
|
|
for _, expandedHeader := range outHeaders {
|
|
if _, ok := expectedHeaders[expandedHeader]; !ok {
|
|
t.Errorf("The ExpandHeaders function has returned an unexpected header. Received unexpected header %s", expandedHeader)
|
|
}
|
|
}
|
|
|
|
badHeaders := []string{"key: ${DNE}"}
|
|
_, err = ExpandHeaders(badHeaders)
|
|
if err == nil {
|
|
t.Errorf("The ExpandHeaders function should return an error for missing environment variables %q", badHeaders)
|
|
}
|
|
}
|
|
|
|
func fileNames(files []*desc.FileDescriptor) []string {
|
|
names := make([]string, len(files))
|
|
for i, f := range files {
|
|
names[i] = f.GetName()
|
|
}
|
|
return names
|
|
}
|
|
|
|
// expectKnownType is the JSON that TestMakeTemplateKnownTypes expects from
// marshaling MakeTemplate's output for jsonpbtest.KnownTypes with
// EmitDefaults. That test compares parsed JSON values, not raw text, so
// whitespace in this literal is not significant.
const expectKnownType = `{
"dur": "0s",
"ts": "1970-01-01T00:00:00Z",
"dbl": 0,
"flt": 0,
"i64": "0",
"u64": "0",
"i32": 0,
"u32": 0,
"bool": false,
"str": "",
"bytes": null,
"st": {"google.protobuf.Struct": "supports arbitrary JSON objects"},
"an": {"@type": "type.googleapis.com/google.protobuf.Empty", "value": {}},
"lv": [{"google.protobuf.ListValue": "is an array of arbitrary JSON values"}],
"val": {"google.protobuf.Value": "supports arbitrary JSON"}
}`
|
|
|
|
func TestMakeTemplateKnownTypes(t *testing.T) {
|
|
descriptor, err := desc.LoadMessageDescriptorForMessage((*jsonpbtest.KnownTypes)(nil))
|
|
if err != nil {
|
|
t.Fatalf("failed to load descriptor: %v", err)
|
|
}
|
|
message := MakeTemplate(descriptor)
|
|
|
|
jsm := jsonpb.Marshaler{EmitDefaults: true}
|
|
out, err := jsm.MarshalToString(message)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal to JSON: %v", err)
|
|
}
|
|
|
|
// make sure template JSON matches expected
|
|
var actual, expected interface{}
|
|
if err := json.Unmarshal([]byte(out), &actual); err != nil {
|
|
t.Fatalf("failed to parse actual JSON: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(expectKnownType), &expected); err != nil {
|
|
t.Fatalf("failed to parse expected JSON: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(actual, expected) {
|
|
t.Errorf("template message is not as expected; want:\n%s\ngot:\n%s", expectKnownType, out)
|
|
}
|
|
}
|
|
|
|
func TestDescribe(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestDescribe(t, ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestDescribe(t *testing.T, source DescriptorSource) {
|
|
sym := "testing.TestService.EmptyCall"
|
|
dsc, err := source.FindSymbol(sym)
|
|
if err != nil {
|
|
t.Fatalf("failed to get descriptor for %q: %v", sym, err)
|
|
}
|
|
if _, ok := dsc.(*desc.MethodDescriptor); !ok {
|
|
t.Fatalf("descriptor for %q was a %T (expecting a MethodDescriptor)", sym, dsc)
|
|
}
|
|
txt := proto.MarshalTextString(dsc.AsProto())
|
|
expected :=
|
|
`name: "EmptyCall"
|
|
input_type: ".testing.Empty"
|
|
output_type: ".testing.Empty"
|
|
`
|
|
if expected != txt {
|
|
t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt)
|
|
}
|
|
|
|
sym = "testing.StreamingOutputCallResponse"
|
|
dsc, err = source.FindSymbol(sym)
|
|
if err != nil {
|
|
t.Fatalf("failed to get descriptor for %q: %v", sym, err)
|
|
}
|
|
if _, ok := dsc.(*desc.MessageDescriptor); !ok {
|
|
t.Fatalf("descriptor for %q was a %T (expecting a MessageDescriptor)", sym, dsc)
|
|
}
|
|
txt = proto.MarshalTextString(dsc.AsProto())
|
|
expected =
|
|
`name: "StreamingOutputCallResponse"
|
|
field: <
|
|
name: "payload"
|
|
number: 1
|
|
label: LABEL_OPTIONAL
|
|
type: TYPE_MESSAGE
|
|
type_name: ".testing.Payload"
|
|
json_name: "payload"
|
|
>
|
|
`
|
|
if expected != txt {
|
|
t.Errorf("descriptor mismatch: expected %s, got %s", expected, txt)
|
|
}
|
|
|
|
_, err = source.FindSymbol("FooService")
|
|
if err != nil && !strings.Contains(err.Error(), "Symbol not found: FooService") {
|
|
t.Errorf("FindSymbol should have returned 'not found' error but instead returned %v", err)
|
|
}
|
|
}
|
|
|
|
// Request messages sent to the TestService RPCs, in JSON. doTestUnary and
// doTestHalfDuplexStream also compare responses (as formatted by
// handler.OnReceiveResponse) against these exact strings, so their
// formatting matters. The decoded bodies are 18, 12 and 31 bytes long
// (61 in total).
const (
	// type == COMPRESSABLE, but that is default (since it has
	// numeric value == 0) and thus doesn't actually get included
	// on the wire
	payload1 = `{
"payload": {
"body": "SXQncyBCdXNpbmVzcyBUaW1l"
}
}`
	// Explicit RANDOM payload type.
	payload2 = `{
"payload": {
"type": "RANDOM",
"body": "Rm91eCBkdSBGYUZh"
}
}`
	// Explicit UNCOMPRESSABLE payload type.
	payload3 = `{
"payload": {
"type": "UNCOMPRESSABLE",
"body": "SGlwaG9wb3BvdGFtdXMgdnMuIFJoeW1lbm9jZXJvcw=="
}
}`
)
|
|
|
|
func getCC(includeRefl bool) *grpc.ClientConn {
|
|
if includeRefl {
|
|
return ccReflect
|
|
} else {
|
|
return ccNoReflect
|
|
}
|
|
}
|
|
|
|
func TestUnary(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestUnary(t, getCC(ds.includeRefl), ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestUnary(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
|
|
// Success
|
|
h := &handler{reqMessages: []string{payload1}}
|
|
err := InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.OK), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
if h.check(t, "testing.TestService.UnaryCall", codes.OK, 1, 1) {
|
|
if h.respMessages[0] != payload1 {
|
|
t.Errorf("unexpected response from RPC: expecting %s; got %s", payload1, h.respMessages[0])
|
|
}
|
|
}
|
|
|
|
// Failure
|
|
h = &handler{reqMessages: []string{payload1}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/UnaryCall", makeHeaders(codes.NotFound), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.UnaryCall", codes.NotFound, 1, 0)
|
|
}
|
|
|
|
func TestClientStream(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestClientStream(t, getCC(ds.includeRefl), ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestClientStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
|
|
// Success
|
|
h := &handler{reqMessages: []string{payload1, payload2, payload3}}
|
|
err := InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.OK), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
if h.check(t, "testing.TestService.StreamingInputCall", codes.OK, 3, 1) {
|
|
expected :=
|
|
`{
|
|
"aggregatedPayloadSize": 61
|
|
}`
|
|
if h.respMessages[0] != expected {
|
|
t.Errorf("unexpected response from RPC: expecting %s; got %s", expected, h.respMessages[0])
|
|
}
|
|
}
|
|
|
|
// Fail fast (server rejects as soon as possible)
|
|
h = &handler{reqMessages: []string{payload1, payload2, payload3}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.InvalidArgument), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.StreamingInputCall", codes.InvalidArgument, -3, 0)
|
|
|
|
// Fail late (server waits until stream is complete to reject)
|
|
h = &handler{reqMessages: []string{payload1, payload2, payload3}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingInputCall", makeHeaders(codes.Internal, true), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.StreamingInputCall", codes.Internal, 3, 0)
|
|
}
|
|
|
|
func TestServerStream(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestServerStream(t, getCC(ds.includeRefl), ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestServerStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
|
|
req := &grpcurl_testing.StreamingOutputCallRequest{
|
|
ResponseType: grpcurl_testing.PayloadType_COMPRESSABLE,
|
|
ResponseParameters: []*grpcurl_testing.ResponseParameters{
|
|
{Size: 10}, {Size: 20}, {Size: 30}, {Size: 40}, {Size: 50},
|
|
},
|
|
}
|
|
payload, err := (&jsonpb.Marshaler{}).MarshalToString(req)
|
|
if err != nil {
|
|
t.Fatalf("failed to construct request: %v", err)
|
|
}
|
|
|
|
// Success
|
|
h := &handler{reqMessages: []string{payload}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.OK), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
if h.check(t, "testing.TestService.StreamingOutputCall", codes.OK, 1, 5) {
|
|
resp := &grpcurl_testing.StreamingOutputCallResponse{}
|
|
for i, msg := range h.respMessages {
|
|
if err := jsonpb.UnmarshalString(msg, resp); err != nil {
|
|
t.Errorf("failed to parse response %d: %v", i+1, err)
|
|
}
|
|
if resp.Payload.GetType() != grpcurl_testing.PayloadType_COMPRESSABLE {
|
|
t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_COMPRESSABLE, resp.Payload.Type)
|
|
}
|
|
if len(resp.Payload.Body) != (i+1)*10 {
|
|
t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (i+1)*10, len(resp.Payload.Body))
|
|
}
|
|
resp.Reset()
|
|
}
|
|
}
|
|
|
|
// Fail fast (server rejects as soon as possible)
|
|
h = &handler{reqMessages: []string{payload}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.Aborted), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.StreamingOutputCall", codes.Aborted, 1, 0)
|
|
|
|
// Fail late (server waits until stream is complete to reject)
|
|
h = &handler{reqMessages: []string{payload}}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/StreamingOutputCall", makeHeaders(codes.AlreadyExists, true), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.StreamingOutputCall", codes.AlreadyExists, 1, 5)
|
|
}
|
|
|
|
func TestHalfDuplexStream(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestHalfDuplexStream(t, getCC(ds.includeRefl), ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestHalfDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
|
|
reqs := []string{payload1, payload2, payload3}
|
|
|
|
// Success
|
|
h := &handler{reqMessages: reqs}
|
|
err := InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.OK), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
if h.check(t, "testing.TestService.HalfDuplexCall", codes.OK, 3, 3) {
|
|
for i, resp := range h.respMessages {
|
|
if resp != reqs[i] {
|
|
t.Errorf("unexpected response %d from RPC:\nexpecting %q\ngot %q", i, reqs[i], resp)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fail fast (server rejects as soon as possible)
|
|
h = &handler{reqMessages: reqs}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.Canceled), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.HalfDuplexCall", codes.Canceled, -3, 0)
|
|
|
|
// Fail late (server waits until stream is complete to reject)
|
|
h = &handler{reqMessages: reqs}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/HalfDuplexCall", makeHeaders(codes.DataLoss, true), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.HalfDuplexCall", codes.DataLoss, 3, 3)
|
|
}
|
|
|
|
func TestFullDuplexStream(t *testing.T) {
|
|
for _, ds := range descSources {
|
|
t.Run(ds.name, func(t *testing.T) {
|
|
doTestFullDuplexStream(t, getCC(ds.includeRefl), ds.source)
|
|
})
|
|
}
|
|
}
|
|
|
|
func doTestFullDuplexStream(t *testing.T, cc *grpc.ClientConn, source DescriptorSource) {
|
|
reqs := make([]string, 3)
|
|
req := &grpcurl_testing.StreamingOutputCallRequest{
|
|
ResponseType: grpcurl_testing.PayloadType_RANDOM,
|
|
}
|
|
for i := range reqs {
|
|
req.ResponseParameters = append(req.ResponseParameters, &grpcurl_testing.ResponseParameters{Size: int32((i + 1) * 10)})
|
|
payload, err := (&jsonpb.Marshaler{}).MarshalToString(req)
|
|
if err != nil {
|
|
t.Fatalf("failed to construct request %d: %v", i, err)
|
|
}
|
|
reqs[i] = payload
|
|
}
|
|
|
|
// Success
|
|
h := &handler{reqMessages: reqs}
|
|
err := InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.OK), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
if h.check(t, "testing.TestService.FullDuplexCall", codes.OK, 3, 6) {
|
|
resp := &grpcurl_testing.StreamingOutputCallResponse{}
|
|
i := 0
|
|
for j := 1; j < 3; j++ {
|
|
// three requests
|
|
for k := 0; k < j; k++ {
|
|
// 1 response for first request, 2 for second, etc
|
|
msg := h.respMessages[i]
|
|
if err := jsonpb.UnmarshalString(msg, resp); err != nil {
|
|
t.Errorf("failed to parse response %d: %v", i+1, err)
|
|
}
|
|
if resp.Payload.GetType() != grpcurl_testing.PayloadType_RANDOM {
|
|
t.Errorf("response %d has wrong payload type; expecting %v, got %v", i, grpcurl_testing.PayloadType_RANDOM, resp.Payload.Type)
|
|
}
|
|
if len(resp.Payload.Body) != (k+1)*10 {
|
|
t.Errorf("response %d has wrong payload size; expecting %d, got %d", i, (k+1)*10, len(resp.Payload.Body))
|
|
}
|
|
resp.Reset()
|
|
|
|
i++
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fail fast (server rejects as soon as possible)
|
|
h = &handler{reqMessages: reqs}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.PermissionDenied), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.FullDuplexCall", codes.PermissionDenied, -3, 0)
|
|
|
|
// Fail late (server waits until stream is complete to reject)
|
|
h = &handler{reqMessages: reqs}
|
|
err = InvokeRpc(context.Background(), source, cc, "testing.TestService/FullDuplexCall", makeHeaders(codes.ResourceExhausted, true), h, h.getRequestData)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during RPC: %v", err)
|
|
}
|
|
|
|
h.check(t, "testing.TestService.FullDuplexCall", codes.ResourceExhausted, 3, 6)
|
|
}
|
|
|
|
// handler records the callbacks grpcurl makes during one RPC, and how many
// times each was made, so the test can assert on them afterwards via check.
// It also supplies the request messages through getRequestData.
type handler struct {
	method      *desc.MethodDescriptor // last value passed to OnResolveMethod
	methodCount int                    // number of OnResolveMethod calls

	reqHeaders      metadata.MD // last value passed to OnSendHeaders
	reqHeadersCount int         // number of OnSendHeaders calls

	reqMessages      []string // JSON request messages handed out by getRequestData
	reqMessagesCount int      // number of getRequestData calls, including the one that returns io.EOF

	respHeaders      metadata.MD // last value passed to OnReceiveHeaders
	respHeadersCount int         // number of OnReceiveHeaders calls

	respMessages []string // JSON form of each message passed to OnReceiveResponse

	respTrailers      metadata.MD    // trailers from the last OnReceiveTrailers call
	respStatus        *status.Status // status from the last OnReceiveTrailers call
	respTrailersCount int            // number of OnReceiveTrailers calls
}
|
|
|
|
func (h *handler) getRequestData() ([]byte, error) {
|
|
// we don't use a mutex, though this method will be called from different goroutine
|
|
// than other methods for bidi calls, because this method does not share any state
|
|
// with the other methods.
|
|
h.reqMessagesCount++
|
|
if h.reqMessagesCount > len(h.reqMessages) {
|
|
return nil, io.EOF
|
|
}
|
|
if h.reqMessagesCount > 1 {
|
|
// insert delay between messages in request stream
|
|
time.Sleep(time.Millisecond * 50)
|
|
}
|
|
return []byte(h.reqMessages[h.reqMessagesCount-1]), nil
|
|
}
|
|
|
|
func (h *handler) OnResolveMethod(md *desc.MethodDescriptor) {
|
|
h.methodCount++
|
|
h.method = md
|
|
}
|
|
|
|
func (h *handler) OnSendHeaders(md metadata.MD) {
|
|
h.reqHeadersCount++
|
|
h.reqHeaders = md
|
|
}
|
|
|
|
func (h *handler) OnReceiveHeaders(md metadata.MD) {
|
|
h.respHeadersCount++
|
|
h.respHeaders = md
|
|
}
|
|
|
|
func (h *handler) OnReceiveResponse(msg proto.Message) {
|
|
jsm := jsonpb.Marshaler{Indent: " "}
|
|
respStr, err := jsm.MarshalToString(msg)
|
|
if err != nil {
|
|
panic(fmt.Errorf("failed to generate JSON form of response message: %v", err))
|
|
}
|
|
h.respMessages = append(h.respMessages, respStr)
|
|
}
|
|
|
|
func (h *handler) OnReceiveTrailers(stat *status.Status, md metadata.MD) {
|
|
h.respTrailersCount++
|
|
h.respTrailers = md
|
|
h.respStatus = stat
|
|
}
|
|
|
|
func (h *handler) check(t *testing.T, expectedMethod string, expectedCode codes.Code, expectedRequestQueries, expectedResponses int) bool {
|
|
// verify a few things were only ever called once
|
|
if h.methodCount != 1 {
|
|
t.Errorf("expected grpcurl to invoke OnResolveMethod once; was %d", h.methodCount)
|
|
}
|
|
if h.reqHeadersCount != 1 {
|
|
t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount)
|
|
}
|
|
if h.reqHeadersCount != 1 {
|
|
t.Errorf("expected grpcurl to invoke OnSendHeaders once; was %d", h.reqHeadersCount)
|
|
}
|
|
if h.respHeadersCount != 1 {
|
|
t.Errorf("expected grpcurl to invoke OnReceiveHeaders once; was %d", h.respHeadersCount)
|
|
}
|
|
if h.respTrailersCount != 1 {
|
|
t.Errorf("expected grpcurl to invoke OnReceiveTrailers once; was %d", h.respTrailersCount)
|
|
}
|
|
|
|
// check other stuff against given expectations
|
|
if h.method.GetFullyQualifiedName() != expectedMethod {
|
|
t.Errorf("wrong method: expecting %v, got %v", expectedMethod, h.method.GetFullyQualifiedName())
|
|
}
|
|
if h.respStatus.Code() != expectedCode {
|
|
t.Errorf("wrong code: expecting %v, got %v", expectedCode, h.respStatus.Code())
|
|
}
|
|
if expectedRequestQueries < 0 {
|
|
// negative expectation means "negate and expect up to that number; could be fewer"
|
|
if h.reqMessagesCount > -expectedRequestQueries+1 {
|
|
// the + 1 is because there will be an extra query that returns EOF
|
|
t.Errorf("wrong number of messages queried: expecting no more than %v, got %v", -expectedRequestQueries, h.reqMessagesCount-1)
|
|
}
|
|
} else {
|
|
if h.reqMessagesCount != expectedRequestQueries+1 {
|
|
// the + 1 is because there will be an extra query that returns EOF
|
|
t.Errorf("wrong number of messages queried: expecting %v, got %v", expectedRequestQueries, h.reqMessagesCount-1)
|
|
}
|
|
}
|
|
if len(h.respMessages) != expectedResponses {
|
|
t.Errorf("wrong number of messages received: expecting %v, got %v", expectedResponses, len(h.respMessages))
|
|
}
|
|
|
|
// also check headers and trailers came through as expected
|
|
v := h.respHeaders["some-fake-header-1"]
|
|
if len(v) != 1 || v[0] != "val1" {
|
|
t.Errorf("wrong request header for %q: %v", "some-fake-header-1", v)
|
|
}
|
|
v = h.respHeaders["some-fake-header-2"]
|
|
if len(v) != 1 || v[0] != "val2" {
|
|
t.Errorf("wrong request header for %q: %v", "some-fake-header-2", v)
|
|
}
|
|
v = h.respTrailers["some-fake-trailer-1"]
|
|
if len(v) != 1 || v[0] != "valA" {
|
|
t.Errorf("wrong request header for %q: %v", "some-fake-trailer-1", v)
|
|
}
|
|
v = h.respTrailers["some-fake-trailer-2"]
|
|
if len(v) != 1 || v[0] != "valB" {
|
|
t.Errorf("wrong request header for %q: %v", "some-fake-trailer-2", v)
|
|
}
|
|
|
|
return len(h.respMessages) == expectedResponses
|
|
}
|
|
|
|
func makeHeaders(code codes.Code, failLate ...bool) []string {
|
|
if len(failLate) > 1 {
|
|
panic("incorrect use of makeContext; should be at most one failLate flag")
|
|
}
|
|
|
|
hdrs := append(make([]string, 0, 5),
|
|
fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-1: val1"),
|
|
fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyHeaders, "some-fake-header-2: val2"),
|
|
fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-1: valA"),
|
|
fmt.Sprintf("%s: %s", grpcurl_testing.MetadataReplyTrailers, "some-fake-trailer-2: valB"))
|
|
if code != codes.OK {
|
|
if len(failLate) > 0 && failLate[0] {
|
|
hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailLate, code))
|
|
} else {
|
|
hdrs = append(hdrs, fmt.Sprintf("%s: %d", grpcurl_testing.MetadataFailEarly, code))
|
|
}
|
|
}
|
|
|
|
return hdrs
|
|
}
|