use wrapped TransportCredentials instead of handling handshake in Dialer, so that grpc library will use correct :scheme

This commit is contained in:
Josh Humphries 2020-01-23 09:09:55 -05:00
parent 9572bd4525
commit cbf1d36242
1 changed files with 37 additions and 11 deletions

View File

@ -605,33 +605,44 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
}
}
// custom credentials and dialer will notify on error via the
// writeResult function
if creds != nil {
creds = &errSignalingCreds{
TransportCredentials: creds,
writeResult: writeResult,
}
}
dialer := func(ctx context.Context, address string) (net.Conn, error) {
// NB: We *could* handle the TLS handshake ourselves, in the custom
// dialer (instead of customizing both the dialer and the credentials).
// But that requires using WithInsecure dial option (so that the gRPC
// library doesn't *also* try to do a handshake). And that would mean
// that the library would send the wrong ":scheme" metaheader to
// servers: it would send "http" instead of "https" because it is
// unaware that TLS is actually in use.
conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
if err != nil {
writeResult(err)
return nil, err
}
if creds != nil {
conn, _, err = creds.ClientHandshake(ctx, address, conn)
if err != nil {
writeResult(err)
return nil, err
}
}
return conn, nil
return conn, err
}
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result
// channel to either get the channel or fail-fast.
// channel to either get the connection or fail-fast.
go func() {
opts = append(opts,
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
grpc.WithContextDialer(dialer),
grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
)
if creds == nil {
opts = append(opts, grpc.WithInsecure())
} else {
opts = append(opts, grpc.WithTransportCredentials(creds))
}
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
@ -652,3 +663,18 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
return nil, ctx.Err()
}
}
// errSignalingCreds is a wrapper around a TransportCredentials value, but
// it will use the writeResult function to notify on error.
type errSignalingCreds struct {
credentials.TransportCredentials
writeResult func(res interface{})
}
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
if err != nil {
c.writeResult(err)
}
return conn, auth, err
}