use custom dialer so we can show better error messages when things like TLS handshakes go awry; restore error message checks in tls tests

This commit is contained in:
Josh Humphries
2017-12-13 15:30:57 -05:00
parent 45e17ae10b
commit 6c05311fb9
4 changed files with 115 additions and 19 deletions

View File

@@ -15,10 +15,12 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
@@ -844,3 +846,75 @@ func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string
return credentials.NewTLS(&tlsConf), nil
}
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
func BlockingDial(ctx context.Context, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
var wg sync.WaitGroup
wg.Add(1)
var once sync.Once
result := make(chan interface{}, 1)
writeResult := func(res interface{}) {
// non-blocking write: we only need the first result
select {
case result <- res:
default:
}
}
dialer := func(address string, timeout time.Duration) (net.Conn, error) {
once.Do(func() { wg.Done() })
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", address)
if err != nil {
writeResult(err)
return nil, err
}
if creds != nil {
conn, _, err = creds.ClientHandshake(ctx, address, conn)
if err != nil {
writeResult(err)
return nil, err
}
}
return conn, nil
}
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
// know when we're done. So we run it in a goroutine and then use result
// channel to either get the channel or fail-fast.
go func() {
opts = append(opts,
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
grpc.WithDialer(dialer),
grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
)
conn, err := grpc.DialContext(ctx, address, opts...)
var res interface{}
if err != nil {
res = err
} else {
res = conn
}
writeResult(res)
}()
select {
case res := <-result:
if conn, ok := res.(*grpc.ClientConn); ok {
return conn, nil
} else {
return nil, res.(error)
}
case <-ctx.Done():
return nil, ctx.Err()
}
}