use wrapped TransportCredentials instead of handling handshake in Dialer (#130)

* use wrapped TransportCredentials instead of handling handshake in Dialer, so that grpc library will use correct :scheme
* support -authority for TLS conns; now effectively supercedes -servername flag
This commit is contained in:
Joshua Humphries 2020-01-27 12:15:47 -05:00 committed by GitHub
parent 9572bd4525
commit 0d669e78d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 20 deletions

View File

@ -71,8 +71,11 @@ var (
performed. This can be used to supply credentials/secrets without having performed. This can be used to supply credentials/secrets without having
to put them in command-line arguments.`)) to put them in command-line arguments.`))
authority = flags.String("authority", "", prettify(` authority = flags.String("authority", "", prettify(`
Value of :authority pseudo-header to be use with underlying HTTP/2 The authoritative name of the remote server. This value is passed as the
requests. It defaults to the given address.`)) value of the ":authority" pseudo-header in the HTTP/2 protocol. When TLS
is used, this will also be used as the server name when verifying the
server's certificate. It defaults to the address that is provided in the
positional arguments.`))
data = flags.String("d", "", prettify(` data = flags.String("d", "", prettify(`
Data for request contents. If the value is '@' then the request contents Data for request contents. If the value is '@' then the request contents
are read from stdin. For calls that accept a stream of requests, the are read from stdin. For calls that accept a stream of requests, the
@ -117,7 +120,12 @@ var (
verbose = flags.Bool("v", false, prettify(` verbose = flags.Bool("v", false, prettify(`
Enable verbose output.`)) Enable verbose output.`))
serverName = flags.String("servername", "", prettify(` serverName = flags.String("servername", "", prettify(`
Override server name when validating TLS certificate.`)) Override server name when validating TLS certificate. This flag is
ignored if -plaintext or -insecure is used.
NOTE: Prefer -authority. This flag may be removed in the future. It is
an error to use both -authority and -servername (though this will be
permitted if they are both set to the same value, to increase backwards
compatibility with earlier releases that allowed both to be set).`))
) )
func init() { func init() {
@ -305,9 +313,6 @@ func main() {
if *maxMsgSz > 0 { if *maxMsgSz > 0 {
opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(*maxMsgSz))) opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(*maxMsgSz)))
} }
if *authority != "" {
opts = append(opts, grpc.WithAuthority(*authority))
}
var creds credentials.TransportCredentials var creds credentials.TransportCredentials
if !*plaintext { if !*plaintext {
var err error var err error
@ -315,11 +320,27 @@ func main() {
if err != nil { if err != nil {
fail(err, "Failed to configure transport credentials") fail(err, "Failed to configure transport credentials")
} }
if *serverName != "" {
if err := creds.OverrideServerName(*serverName); err != nil { // can use either -servername or -authority; but not both
fail(err, "Failed to override server name as %q", *serverName) if *serverName != "" && *authority != "" {
if *serverName == *authority {
warn("Both -servername and -authority are present; prefer only -authority.")
} else {
fail(nil, "Cannot specify different values for -servername and -authority.")
} }
} }
overrideName := *serverName
if overrideName == "" {
overrideName = *authority
}
if overrideName != "" {
if err := creds.OverrideServerName(overrideName); err != nil {
fail(err, "Failed to override server name as %q", overrideName)
}
}
} else if *authority != "" {
opts = append(opts, grpc.WithAuthority(*authority))
} }
network := "tcp" network := "tcp"
if isUnixSocket != nil && isUnixSocket() { if isUnixSocket != nil && isUnixSocket() {

View File

@ -605,33 +605,44 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
} }
} }
// custom credentials and dialer will notify on error via the
// writeResult function
if creds != nil {
creds = &errSignalingCreds{
TransportCredentials: creds,
writeResult: writeResult,
}
}
dialer := func(ctx context.Context, address string) (net.Conn, error) { dialer := func(ctx context.Context, address string) (net.Conn, error) {
// NB: We *could* handle the TLS handshake ourselves, in the custom
// dialer (instead of customizing both the dialer and the credentials).
// But that requires using WithInsecure dial option (so that the gRPC
// library doesn't *also* try to do a handshake). And that would mean
// that the library would send the wrong ":scheme" metaheader to
// servers: it would send "http" instead of "https" because it is
// unaware that TLS is actually in use.
conn, err := (&net.Dialer{}).DialContext(ctx, network, address) conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
if err != nil { if err != nil {
writeResult(err) writeResult(err)
return nil, err
} }
if creds != nil { return conn, err
conn, _, err = creds.ClientHandshake(ctx, address, conn)
if err != nil {
writeResult(err)
return nil, err
}
}
return conn, nil
} }
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result // know when we're done. So we run it in a goroutine and then use result
// channel to either get the channel or fail-fast. // channel to either get the connection or fail-fast.
go func() { go func() {
opts = append(opts, opts = append(opts,
grpc.WithBlock(), grpc.WithBlock(),
grpc.FailOnNonTempDialError(true), grpc.FailOnNonTempDialError(true),
grpc.WithContextDialer(dialer), grpc.WithContextDialer(dialer),
grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
) )
if creds == nil {
opts = append(opts, grpc.WithInsecure())
} else {
opts = append(opts, grpc.WithTransportCredentials(creds))
}
conn, err := grpc.DialContext(ctx, address, opts...) conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{} var res interface{}
if err != nil { if err != nil {
@ -652,3 +663,18 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
// errSignalingCreds is a wrapper around a TransportCredentials value, but
// it will use the writeResult function to notify on error.
type errSignalingCreds struct {
credentials.TransportCredentials
writeResult func(res interface{})
}
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
if err != nil {
c.writeResult(err)
}
return conn, auth, err
}