diff --git a/grpcurl.go b/grpcurl.go index 3c5c607..521fb1c 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -605,33 +605,44 @@ func BlockingDial(ctx context.Context, network, address string, creds credential } } + // custom credentials and dialer will notify on error via the + // writeResult function + if creds != nil { + creds = &errSignalingCreds{ + TransportCredentials: creds, + writeResult: writeResult, + } + } dialer := func(ctx context.Context, address string) (net.Conn, error) { + // NB: We *could* handle the TLS handshake ourselves, in the custom + // dialer (instead of customizing both the dialer and the credentials). + // But that requires using WithInsecure dial option (so that the gRPC + // library doesn't *also* try to do a handshake). And that would mean + // that the library would send the wrong ":scheme" metaheader to + // servers: it would send "http" instead of "https" because it is + // unaware that TLS is actually in use. conn, err := (&net.Dialer{}).DialContext(ctx, network, address) if err != nil { writeResult(err) - return nil, err } - if creds != nil { - conn, _, err = creds.ClientHandshake(ctx, address, conn) - if err != nil { - writeResult(err) - return nil, err - } - } - return conn, nil + return conn, err } // Even with grpc.FailOnNonTempDialError, this call will usually timeout in // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to // know when we're done. So we run it in a goroutine and then use result - // channel to either get the channel or fail-fast. + // channel to either get the connection or fail-fast. go func() { opts = append(opts, grpc.WithBlock(), grpc.FailOnNonTempDialError(true), grpc.WithContextDialer(dialer), - grpc.WithInsecure(), // we are handling TLS, so tell grpc not to ) + if creds == nil { + opts = append(opts, grpc.WithInsecure()) + } else { + opts = append(opts, grpc.WithTransportCredentials(creds)) + } conn, err := grpc.DialContext(ctx, address, opts...) var res interface{} if err != nil { @@ -652,3 +663,18 @@ func BlockingDial(ctx context.Context, network, address string, creds credential return nil, ctx.Err() } } + +// errSignalingCreds is a wrapper around a TransportCredentials value, but +// it will use the writeResult function to notify on error. +type errSignalingCreds struct { + credentials.TransportCredentials + writeResult func(res interface{}) +} + +func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn) + if err != nil { + c.writeResult(err) + } + return conn, auth, err +}