From cbf1d36242ebbf592ba4fa7f8ee2b2cfc5672a8e Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Thu, 23 Jan 2020 09:09:55 -0500 Subject: [PATCH] use wrapped TransportCredentials instead of handling handshake in Dialer, so that grpc library will use correct :scheme --- grpcurl.go | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/grpcurl.go b/grpcurl.go index 3c5c607..521fb1c 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -605,33 +605,44 @@ func BlockingDial(ctx context.Context, network, address string, creds credential } } + // custom credentials and dialer will notify on error via the + // writeResult function + if creds != nil { + creds = &errSignalingCreds{ + TransportCredentials: creds, + writeResult: writeResult, + } + } dialer := func(ctx context.Context, address string) (net.Conn, error) { + // NB: We *could* handle the TLS handshake ourselves, in the custom + // dialer (instead of customizing both the dialer and the credentials). + // But that requires using WithInsecure dial option (so that the gRPC + // library doesn't *also* try to do a handshake). And that would mean + // that the library would send the wrong ":scheme" metaheader to + // servers: it would send "http" instead of "https" because it is + // unaware that TLS is actually in use. conn, err := (&net.Dialer{}).DialContext(ctx, network, address) if err != nil { writeResult(err) - return nil, err } - if creds != nil { - conn, _, err = creds.ClientHandshake(ctx, address, conn) - if err != nil { - writeResult(err) - return nil, err - } - } - return conn, nil + return conn, err } // Even with grpc.FailOnNonTempDialError, this call will usually timeout in // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to // know when we're done. So we run it in a goroutine and then use result - // channel to either get the channel or fail-fast. + // channel to either get the connection or fail-fast. go func() { opts = append(opts, grpc.WithBlock(), grpc.FailOnNonTempDialError(true), grpc.WithContextDialer(dialer), - grpc.WithInsecure(), // we are handling TLS, so tell grpc not to ) + if creds == nil { + opts = append(opts, grpc.WithInsecure()) + } else { + opts = append(opts, grpc.WithTransportCredentials(creds)) + } conn, err := grpc.DialContext(ctx, address, opts...) var res interface{} if err != nil { @@ -652,3 +663,18 @@ func BlockingDial(ctx context.Context, network, address string, creds credential return nil, ctx.Err() } } + +// errSignalingCreds is a wrapper around a TransportCredentials value, but +// it will use the writeResult function to notify on error. +type errSignalingCreds struct { + credentials.TransportCredentials + writeResult func(res interface{}) +} + +func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn) + if err != nil { + c.writeResult(err) + } + return conn, auth, err +}