Merge pull request #1 from fullstorydev/jh/fix-tls-test
provide custom dialer to yield good error messages for dial errors, including TLS handshakes
This commit is contained in:
commit
58aba8cee5
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
reflectpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
||||||
|
|
@ -188,7 +189,9 @@ func main() {
|
||||||
}
|
}
|
||||||
dialTime = time.Duration(t * float64(time.Second))
|
dialTime = time.Duration(t * float64(time.Second))
|
||||||
}
|
}
|
||||||
opts := []grpc.DialOption{grpc.WithTimeout(dialTime), grpc.WithBlock()}
|
ctx, cancel := context.WithTimeout(ctx, dialTime)
|
||||||
|
defer cancel()
|
||||||
|
var opts []grpc.DialOption
|
||||||
if *keepaliveTime != "" {
|
if *keepaliveTime != "" {
|
||||||
t, err := strconv.ParseFloat(*keepaliveTime, 64)
|
t, err := strconv.ParseFloat(*keepaliveTime, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -200,16 +203,15 @@ func main() {
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
if *plaintext {
|
var creds credentials.TransportCredentials
|
||||||
opts = append(opts, grpc.WithInsecure())
|
if !*plaintext {
|
||||||
} else {
|
var err error
|
||||||
creds, err := grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key)
|
creds, err = grpcurl.ClientTransportCredentials(*insecure, *cacert, *cert, *key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(err, "Failed to configure transport credentials")
|
fail(err, "Failed to configure transport credentials")
|
||||||
}
|
}
|
||||||
opts = append(opts, grpc.WithTransportCredentials(creds))
|
|
||||||
}
|
}
|
||||||
cc, err := grpc.Dial(target, opts...)
|
cc, err := grpcurl.BlockingDial(ctx, target, creds, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(err, "Failed to dial target host %q", target)
|
fail(err, "Failed to dial target host %q", target)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
72
grpcurl.go
72
grpcurl.go
|
|
@ -15,10 +15,12 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/jsonpb"
|
"github.com/golang/protobuf/jsonpb"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
|
@ -685,7 +687,7 @@ func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry,
|
||||||
fds, err := source.AllExtensionsForType(msgTypeName)
|
fds, err := source.AllExtensionsForType(msgTypeName)
|
||||||
for _, fd := range fds {
|
for _, fd := range fds {
|
||||||
if err = ext.AddExtension(fd); err != nil {
|
if err = ext.AddExtension(fd); err != nil {
|
||||||
return fmt.Errorf("could not register extension %d of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
|
return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -844,3 +846,71 @@ func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string
|
||||||
|
|
||||||
return credentials.NewTLS(&tlsConf), nil
|
return credentials.NewTLS(&tlsConf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
|
||||||
|
// and blocking until the returned connection is ready. If the given credentials are nil, the
|
||||||
|
// connection will be insecure (plain-text).
|
||||||
|
func BlockingDial(ctx context.Context, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||||
|
// grpc.Dial doesn't provide any information on permanent connection errors (like
|
||||||
|
// TLS handshake failures). So in order to provide good error messages, we need a
|
||||||
|
// custom dialer that can provide that info. That means we manage the TLS handshake.
|
||||||
|
result := make(chan interface{}, 1)
|
||||||
|
|
||||||
|
writeResult := func(res interface{}) {
|
||||||
|
// non-blocking write: we only need the first result
|
||||||
|
select {
|
||||||
|
case result <- res:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := func(address string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
writeResult(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if creds != nil {
|
||||||
|
conn, _, err = creds.ClientHandshake(ctx, address, conn)
|
||||||
|
if err != nil {
|
||||||
|
writeResult(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
|
||||||
|
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
|
||||||
|
// know when we're done. So we run it in a goroutine and then use result
|
||||||
|
// channel to either get the channel or fail-fast.
|
||||||
|
go func() {
|
||||||
|
opts = append(opts,
|
||||||
|
grpc.WithBlock(),
|
||||||
|
grpc.FailOnNonTempDialError(true),
|
||||||
|
grpc.WithDialer(dialer),
|
||||||
|
grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
|
||||||
|
)
|
||||||
|
conn, err := grpc.DialContext(ctx, address, opts...)
|
||||||
|
var res interface{}
|
||||||
|
if err != nil {
|
||||||
|
res = err
|
||||||
|
} else {
|
||||||
|
res = conn
|
||||||
|
}
|
||||||
|
writeResult(res)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-result:
|
||||||
|
if conn, ok := res.(*grpc.ClientConn); ok {
|
||||||
|
return conn, nil
|
||||||
|
} else {
|
||||||
|
return nil, res.(error)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,10 @@ func TestMain(m *testing.M) {
|
||||||
defer svrReflect.Stop()
|
defer svrReflect.Stop()
|
||||||
|
|
||||||
// And a corresponding client
|
// And a corresponding client
|
||||||
if ccReflect, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portReflect),
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil {
|
defer cancel()
|
||||||
|
if ccReflect, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portReflect),
|
||||||
|
grpc.WithInsecure(), grpc.WithBlock()); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer ccReflect.Close()
|
defer ccReflect.Close()
|
||||||
|
|
@ -80,8 +82,10 @@ func TestMain(m *testing.M) {
|
||||||
defer svrProtoset.Stop()
|
defer svrProtoset.Stop()
|
||||||
|
|
||||||
// And a corresponding client
|
// And a corresponding client
|
||||||
if ccProtoset, err = grpc.Dial(fmt.Sprintf("127.0.0.1:%d", portProtoset),
|
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
grpc.WithInsecure(), grpc.WithTimeout(10*time.Second), grpc.WithBlock()); err != nil {
|
defer cancel()
|
||||||
|
if ccProtoset, err = grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", portProtoset),
|
||||||
|
grpc.WithInsecure(), grpc.WithBlock()); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer ccProtoset.Close()
|
defer ccProtoset.Close()
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,9 @@ func TestBrokenTLS_ClientPlainText(t *testing.T) {
|
||||||
// various errors possible when server closes connection
|
// various errors possible when server closes connection
|
||||||
if !strings.Contains(err.Error(), "transport is closing") &&
|
if !strings.Contains(err.Error(), "transport is closing") &&
|
||||||
!strings.Contains(err.Error(), "connection is unavailable") &&
|
!strings.Contains(err.Error(), "connection is unavailable") &&
|
||||||
!strings.Contains(err.Error(), "use of closed network connection") {
|
!strings.Contains(err.Error(), "use of closed network connection") &&
|
||||||
|
!strings.Contains(err.Error(), "all SubConns are in TransientFailure") &&
|
||||||
|
!strings.Contains(err.Error(), "deadline exceeded") {
|
||||||
|
|
||||||
t.Fatalf("expecting transport failure, got: %v", err)
|
t.Fatalf("expecting transport failure, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -297,13 +299,10 @@ func createTestServerAndClient(serverCreds, clientCreds credentials.TransportCre
|
||||||
port := l.Addr().(*net.TCPAddr).Port
|
port := l.Addr().(*net.TCPAddr).Port
|
||||||
go svr.Serve(l)
|
go svr.Serve(l)
|
||||||
|
|
||||||
cliOpts := []grpc.DialOption{grpc.WithTimeout(2 * time.Second), grpc.WithBlock()}
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
if clientCreds != nil {
|
defer cancel()
|
||||||
cliOpts = append(cliOpts, grpc.WithTransportCredentials(clientCreds))
|
|
||||||
} else {
|
cc, err := BlockingDial(ctx, fmt.Sprintf("127.0.0.1:%d", port), clientCreds)
|
||||||
cliOpts = append(cliOpts, grpc.WithInsecure())
|
|
||||||
}
|
|
||||||
cc, err := grpc.Dial(fmt.Sprintf("127.0.0.1:%d", port), cliOpts...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue