diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index b0e69a9..82b7bca 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -122,6 +122,9 @@ var ( probe is sent. If the connection remains idle and no keepalive response is received for this same period then the connection is closed and the operation fails.`)) + disableHalfClose = flags.Bool("disable-half-close", false, prettify(` + If true, the client will not call CloseSend() on the stream after all + request messages have been sent.`)) maxTime = flags.Float64("max-time", 0, prettify(` The maximum total time the operation can take, in seconds. This is useful for preventing batch jobs that use grpcurl from hanging due to @@ -696,7 +699,7 @@ func main() { VerbosityLevel: verbosityLevel, } - err = grpcurl.InvokeRPC(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), h, rf.Next) + err = grpcurl.InvokeRPC(ctx, descSource, cc, symbol, append(addlHeaders, rpcHeaders...), h, rf.Next, !(*disableHalfClose)) if err != nil { if errStatus, ok := status.FromError(err); ok && *formatError { h.Status = errStatus diff --git a/invoke.go b/invoke.go index 0db362c..b843433 100644 --- a/invoke.go +++ b/invoke.go @@ -61,7 +61,7 @@ func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn return err } return jsonpb.Unmarshal(bytes.NewReader(data), m) - }) + }, false) } // RequestSupplier is a function that is called to populate messages for a gRPC operation. The @@ -85,7 +85,7 @@ type RequestSupplier func(proto.Message) error // than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where // one goroutine sends request messages and another consumes the response messages). func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, - headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { + headers []string, handler InvocationEventHandler, requestData RequestSupplier, bidiShouldHalfClose bool) error { md := MetadataFromHeaders(headers) @@ -140,7 +140,7 @@ func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Chan defer cancel() if mtd.IsClientStreaming() && mtd.IsServerStreaming() { - return invokeBidi(ctx, stub, mtd, handler, requestData, req) + return invokeBidi(ctx, stub, mtd, handler, requestData, req, bidiShouldHalfClose) } else if mtd.IsClientStreaming() { return invokeClientStream(ctx, stub, mtd, handler, requestData, req) } else if mtd.IsServerStreaming() { @@ -290,7 +290,7 @@ func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met } func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, - requestData RequestSupplier, req proto.Message) error { + requestData RequestSupplier, req proto.Message, shouldHalfClose bool) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -314,7 +314,9 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr err = requestData(req) if err == io.EOF { - err = str.CloseSend() + if shouldHalfClose { + err = str.CloseSend() + } break } if err != nil {