diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index 1a9cc67..28603ef 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -17,6 +17,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" @@ -188,7 +189,9 @@ func main() { } dialTime = time.Duration(t * float64(time.Second)) } - opts := []grpc.DialOption{grpc.WithTimeout(dialTime), grpc.WithBlock()} + ctx, cancel := context.WithTimeout(ctx, dialTime) + defer cancel() + var opts []grpc.DialOption if *keepaliveTime != "" { t, err := strconv.ParseFloat(*keepaliveTime, 64) if err != nil { @@ -200,16 +203,15 @@ func main() { Timeout: timeout, })) } - if *plaintext { - opts = append(opts, grpc.WithInsecure()) - } else { - creds, err := grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key) + var creds credentials.TransportCredentials + if !*plaintext { + var err error + creds, err = grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key) if err != nil { fail(err, "Failed to configure transport credentials") } - opts = append(opts, grpc.WithTransportCredentials(creds)) } - cc, err := grpc.Dial(target, opts...) + cc, err := grpcurl.BlockingDial(ctx, target, creds, opts...) if err != nil { fail(err, "Failed to dial target host %q", target) } diff --git a/grpcurl.go b/grpcurl.go index 4e132c5..1112e1c 100644 --- a/grpcurl.go +++ b/grpcurl.go @@ -15,10 +15,12 @@ import ( "fmt" "io" "io/ioutil" + "net" "sort" "strings" "sync" "sync/atomic" + "time" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -844,3 +846,75 @@ func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string return credentials.NewTLS(&tlsConf), nil } + +// BlockingDial is a helper method to dial the given address, using optional TLS credentials, +// and blocking until the returned connection is ready. If the given credentials are nil, the +// connection will be insecure (plain-text). +func BlockingDial(ctx context.Context, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + // grpc.Dial doesn't provide any information on permanent connection errors (like + // TLS handshake failures). So in order to provide good error messages, we need a + // custom dialer that can provide that info. That means we manage the TLS handshake. + var wg sync.WaitGroup + wg.Add(1) + var once sync.Once + result := make(chan interface{}, 1) + + writeResult := func(res interface{}) { + // non-blocking write: we only need the first result + select { + case result <- res: + default: + } + } + + dialer := func(address string, timeout time.Duration) (net.Conn, error) { + once.Do(func() { wg.Done() }) + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", address) + if err != nil { + writeResult(err) + return nil, err + } + if creds != nil { + conn, _, err = creds.ClientHandshake(ctx, address, conn) + if err != nil { + writeResult(err) + return nil, err + } + } + return conn, nil + } + + // Even with grpc.FailOnNonTempDialError, this call will usually timeout in + // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to + // know when we're done. So we run it in a goroutine and then use result + // channel to either get the channel or fail-fast. + go func() { + opts = append(opts, + grpc.WithBlock(), + grpc.FailOnNonTempDialError(true), + grpc.WithDialer(dialer), + grpc.WithInsecure(), // we are handling TLS, so tell grpc not to + ) + conn, err := grpc.DialContext(ctx, address, opts...) + var res interface{} + if err != nil { + res = err + } else { + res = conn + } + writeResult(res) + }() + + select { + case res := <-result: + if conn, ok := res.(*grpc.ClientConn); ok { + return conn, nil + } else { + return nil, res.(error) + } + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/grpcurl_test.go b/grpcurl_test.go index d06798a..13613b7 100644 --- a/grpcurl_test.go +++ b/grpcurl_test.go @@ -57,8 +57,10 @@ func TestMain(m *testing.M) { defer svrReflect.Stop() // And a corresponding client - if ccReflect, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portReflect), - grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect), + grpc.WithInsecure(), grpc.WithBlock()); err != nil { panic(err) } defer ccReflect.Close() @@ -80,8 +82,10 @@ func TestMain(m *testing.M) { defer svrProtoset.Stop() // And a corresponding client - if ccProtoset, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portProtoset), - grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil { + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if ccProtoset, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset), + grpc.WithInsecure(), grpc.WithBlock()); err != nil { panic(err) } defer ccProtoset.Close() diff --git a/tls_settings_test.go b/tls_settings_test.go index 65813b5..f396f8f 100644 --- a/tls_settings_test.go +++ b/tls_settings_test.go @@ -125,7 +125,8 @@ func TestBrokenTLS_ClientPlainText(t *testing.T) { // various errors possible when server closes connection if !strings.Contains(err.Error(), "transport is closing") && !strings.Contains(err.Error(), "connection is unavailable") && - !strings.Contains(err.Error(), "use of closed network connection") { + !strings.Contains(err.Error(), "use of closed network connection") && + !strings.Contains(err.Error(), "all SubConns are in TransientFailure") { t.Fatalf("expecting transport failure, got: %v", err) } @@ -142,6 +143,9 @@ func TestBrokenTLS_ServerPlainText(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "first record does not look like a TLS handshake") { + t.Fatalf("expecting TLS handshake failure, got: %v", err) + } } func TestBrokenTLS_ServerUsesWrongCert(t *testing.T) { @@ -159,6 +163,9 @@ func TestBrokenTLS_ServerUsesWrongCert(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "certificate is valid for") { + t.Fatalf("expecting TLS certificate error, got: %v", err) + } } func TestBrokenTLS_ClientHasExpiredCert(t *testing.T) { @@ -176,6 +183,9 @@ func TestBrokenTLS_ClientHasExpiredCert(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("expecting TLS certificate error, got: %v", err) + } } func TestBrokenTLS_ServerHasExpiredCert(t *testing.T) { @@ -193,6 +203,9 @@ func TestBrokenTLS_ServerHasExpiredCert(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "certificate has expired or is not yet valid") { + t.Fatalf("expecting TLS certificate expired, got: %v", err) + } } func TestBrokenTLS_ClientNotTrusted(t *testing.T) { @@ -210,6 +223,9 @@ func TestBrokenTLS_ClientNotTrusted(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("expecting TLS certificate error, got: %v", err) + } } func TestBrokenTLS_ServerNotTrusted(t *testing.T) { @@ -227,6 +243,9 @@ func TestBrokenTLS_ServerNotTrusted(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "certificate signed by unknown authority") { + t.Fatalf("expecting TLS certificate error, got: %v", err) + } } func TestBrokenTLS_RequireClientCertButNonePresented(t *testing.T) { @@ -244,6 +263,9 @@ func TestBrokenTLS_RequireClientCertButNonePresented(t *testing.T) { t.Fatal("expecting TLS failure setting up server and client") e.Close() } + if !strings.Contains(err.Error(), "bad certificate") { + t.Fatalf("expecting TLS certificate error, got: %v", err) + } } func simpleTest(t *testing.T, cc *grpc.ClientConn) { @@ -279,13 +301,7 @@ func createTestServerAndClient(serverCreds, clientCreds credentials.TransportCre ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - var tlsOpt grpc.DialOption - if clientCreds != nil { - tlsOpt = grpc.WithTransportCredentials(clientCreds) - } else { - tlsOpt = grpc.WithInsecure() - } - cc, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithBlock(), tlsOpt) + cc, err := BlockingDial(ctx, fmt.Sprintf("127.0.0.1:%d", port), clientCreds) if err != nil { return e, err }