use custom dialer so we can show better error messages when things like TLS handshakes go awry; restore error message checks in tls tests

This commit is contained in:
Josh Humphries 2017-12-13 15:30:57 -05:00
parent 45e17ae10b
commit 6c05311fb9
4 changed files with 115 additions and 19 deletions

View File

@ -17,6 +17,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
@ -188,7 +189,9 @@ func main() {
} }
dialTime = time.Duration(t * float64(time.Second)) dialTime = time.Duration(t * float64(time.Second))
} }
opts := []grpc.DialOption{grpc.WithTimeout(dialTime), grpc.WithBlock()} ctx, cancel := context.WithTimeout(ctx, dialTime)
defer cancel()
var opts []grpc.DialOption
if *keepaliveTime != "" { if *keepaliveTime != "" {
t, err := strconv.ParseFloat(*keepaliveTime, 64) t, err := strconv.ParseFloat(*keepaliveTime, 64)
if err != nil { if err != nil {
@ -200,16 +203,15 @@ func main() {
Timeout: timeout, Timeout: timeout,
})) }))
} }
if *plaintext { var creds credentials.TransportCredentials
opts = append(opts, grpc.WithInsecure()) if !*plaintext {
} else { var err error
creds, err := grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key) creds, err = grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key)
if err != nil { if err != nil {
fail(err, "Failed to configure transport credentials") fail(err, "Failed to configure transport credentials")
} }
opts = append(opts, grpc.WithTransportCredentials(creds))
} }
cc, err := grpc.Dial(target, opts...) cc, err := grpcurl.BlockingDial(ctx, target, creds, opts...)
if err != nil { if err != nil {
fail(err, "Failed to dial target host %q", target) fail(err, "Failed to dial target host %q", target)
} }

View File

@ -15,10 +15,12 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"sort" "sort"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -844,3 +846,75 @@ func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string
return credentials.NewTLS(&tlsConf), nil return credentials.NewTLS(&tlsConf), nil
} }
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
func BlockingDial(ctx context.Context, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
var wg sync.WaitGroup
wg.Add(1)
var once sync.Once
result := make(chan interface{}, 1)
writeResult := func(res interface{}) {
// non-blocking write: we only need the first result
select {
case result <- res:
default:
}
}
dialer := func(address string, timeout time.Duration) (net.Conn, error) {
once.Do(func() { wg.Done() })
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", address)
if err != nil {
writeResult(err)
return nil, err
}
if creds != nil {
conn, _, err = creds.ClientHandshake(ctx, address, conn)
if err != nil {
writeResult(err)
return nil, err
}
}
return conn, nil
}
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result
// channel to either get the channel or fail-fast.
go func() {
opts = append(opts,
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
grpc.WithDialer(dialer),
grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
)
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
res = err
} else {
res = conn
}
writeResult(res)
}()
select {
case res := <-result:
if conn, ok := res.(*grpc.ClientConn); ok {
return conn, nil
} else {
return nil, res.(error)
}
case <-ctx.Done():
return nil, ctx.Err()
}
}

View File

@ -57,8 +57,10 @@ func TestMain(m *testing.M) {
defer svrReflect.Stop() defer svrReflect.Stop()
// And a corresponding client // And a corresponding client
if ccReflect, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portReflect), ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil { defer cancel()
if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect),
grpc.WithInsecure(), grpc.WithBlock()); err != nil {
panic(err) panic(err)
} }
defer ccReflect.Close() defer ccReflect.Close()
@ -80,8 +82,10 @@ func TestMain(m *testing.M) {
defer svrProtoset.Stop() defer svrProtoset.Stop()
// And a corresponding client // And a corresponding client
if ccProtoset, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portProtoset), ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil { defer cancel()
if ccProtoset, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset),
grpc.WithInsecure(), grpc.WithBlock()); err != nil {
panic(err) panic(err)
} }
defer ccProtoset.Close() defer ccProtoset.Close()

View File

@ -125,7 +125,8 @@ func TestBrokenTLS_ClientPlainText(t *testing.T) {
// various errors possible when server closes connection // various errors possible when server closes connection
if !strings.Contains(err.Error(), "transport is closing") && if !strings.Contains(err.Error(), "transport is closing") &&
!strings.Contains(err.Error(), "connection is unavailable") && !strings.Contains(err.Error(), "connection is unavailable") &&
!strings.Contains(err.Error(), "use of closed network connection") { !strings.Contains(err.Error(), "use of closed network connection") &&
!strings.Contains(err.Error(), "all SubConns are in TransientFailure") {
t.Fatalf("expecting transport failure, got: %v", err) t.Fatalf("expecting transport failure, got: %v", err)
} }
@ -142,6 +143,9 @@ func TestBrokenTLS_ServerPlainText(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "first record does not look like a TLS handshake") {
t.Fatalf("expecting TLS handshake failure, got: %v", err)
}
} }
func TestBrokenTLS_ServerUsesWrongCert(t *testing.T) { func TestBrokenTLS_ServerUsesWrongCert(t *testing.T) {
@ -159,6 +163,9 @@ func TestBrokenTLS_ServerUsesWrongCert(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "certificate is valid for") {
t.Fatalf("expecting TLS certificate error, got: %v", err)
}
} }
func TestBrokenTLS_ClientHasExpiredCert(t *testing.T) { func TestBrokenTLS_ClientHasExpiredCert(t *testing.T) {
@ -176,6 +183,9 @@ func TestBrokenTLS_ClientHasExpiredCert(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "bad certificate") {
t.Fatalf("expecting TLS certificate error, got: %v", err)
}
} }
func TestBrokenTLS_ServerHasExpiredCert(t *testing.T) { func TestBrokenTLS_ServerHasExpiredCert(t *testing.T) {
@ -193,6 +203,9 @@ func TestBrokenTLS_ServerHasExpiredCert(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "certificate has expired or is not yet valid") {
t.Fatalf("expecting TLS certificate expired, got: %v", err)
}
} }
func TestBrokenTLS_ClientNotTrusted(t *testing.T) { func TestBrokenTLS_ClientNotTrusted(t *testing.T) {
@ -210,6 +223,9 @@ func TestBrokenTLS_ClientNotTrusted(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "bad certificate") {
t.Fatalf("expecting TLS certificate error, got: %v", err)
}
} }
func TestBrokenTLS_ServerNotTrusted(t *testing.T) { func TestBrokenTLS_ServerNotTrusted(t *testing.T) {
@ -227,6 +243,9 @@ func TestBrokenTLS_ServerNotTrusted(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "certificate signed by unknown authority") {
t.Fatalf("expecting TLS certificate error, got: %v", err)
}
} }
func TestBrokenTLS_RequireClientCertButNonePresented(t *testing.T) { func TestBrokenTLS_RequireClientCertButNonePresented(t *testing.T) {
@ -244,6 +263,9 @@ func TestBrokenTLS_RequireClientCertButNonePresented(t *testing.T) {
t.Fatal("expecting TLS failure setting up server and client") t.Fatal("expecting TLS failure setting up server and client")
e.Close() e.Close()
} }
if !strings.Contains(err.Error(), "bad certificate") {
t.Fatalf("expecting TLS certificate error, got: %v", err)
}
} }
func simpleTest(t *testing.T, cc *grpc.ClientConn) { func simpleTest(t *testing.T, cc *grpc.ClientConn) {
@ -279,13 +301,7 @@ func createTestServerAndClient(serverCreds, clientCreds credentials.TransportCre
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel() defer cancel()
var tlsOpt grpc.DialOption cc, err := BlockingDial(ctx, fmt.Sprintf("127.0.0.1:%d", port), clientCreds)
if clientCreds != nil {
tlsOpt = grpc.WithTransportCredentials(clientCreds)
} else {
tlsOpt = grpc.WithInsecure()
}
cc, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithBlock(), tlsOpt)
if err != nil { if err != nil {
return e, err return e, err
} }