support unix sockets (#26)

This commit is contained in:
Joshua Humphries 2018-03-27 11:24:35 -04:00 committed by GitHub
parent 09a863d763
commit ca5693f42c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 70 additions and 16 deletions

View File

@ -33,6 +33,8 @@ import (
var (
exit = os.Exit
isUnixSocket func() bool // nil when run on non-unix platform
help = flag.Bool("help", false,
`Print usage instructions and exit.`)
plaintext = flag.Bool("plaintext", false,
@ -240,7 +242,11 @@ func main() {
}
}
}
cc, err := grpcurl.BlockingDial(ctx, target, creds, opts...)
network := "tcp"
if isUnixSocket != nil && isUnixSocket() {
network = "unix"
}
cc, err := grpcurl.BlockingDial(ctx, network, target, creds, opts...)
if err != nil {
fail(err, "Failed to dial target host %q", target)
}
@ -407,7 +413,7 @@ func main() {
func usage() {
fmt.Fprintf(os.Stderr, `Usage:
%s [flags] [host:port] [list|describe] [symbol]
%s [flags] [address] [list|describe] [symbol]
The 'host:port' is only optional when used with 'list' or 'describe' and a
protoset flag is provided.
@ -425,6 +431,11 @@ If neither verb is present, the symbol must be a fully-qualified method name in
be used to invoke the named method. If no body is given, an empty instance of
the method's request type will be sent.
The address will typically be in the form "host:port" where host can be an IP
address or a hostname and port is a numeric port or service name. If an IPv6
address is given, it must be surrounded by brackets, like "[2001:db8::1]". For
Unix variants, if a -unix=true flag is present, then the address must be the
path to the domain socket.
`, os.Args[0])
flag.PrintDefaults()

16
cmd/grpcurl/unix.go Normal file
View File

@ -0,0 +1,16 @@
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package main
import "flag"
var (
unix = flag.Bool("unix", false,
`Indicates that the server address is the path to a Unix domain socket.`)
)
func init() {
isUnixSocket = func() bool {
return *unix
}
}

View File

@ -23,19 +23,21 @@ import (
)
var (
getUnixSocket func() string // nil when run on non-unix platforms
help = flag.Bool("help", false, "Print usage instructions and exit.")
cacert = flag.String("cacert", "",
"File containing trusted root certificates for verifying client certs. Ignored if\n"+
" TLS is not in use (e.g. no -cert or -key specified).")
`File containing trusted root certificates for verifying client certs. Ignored
if TLS is not in use (e.g. no -cert or -key specified).`)
cert = flag.String("cert", "",
"File containing server certificate (public key). Must also provide -key option.\n"+
" Server uses plain-text if no -cert and -key options are given.")
`File containing server certificate (public key). Must also provide -key option.
Server uses plain-text if no -cert and -key options are given.`)
key = flag.String("key", "",
"File containing server private key. Must also provide -cert option. Server uses\n"+
" plain-text if no -cert and -key options are given.")
`File containing server private key. Must also provide -cert option. Server uses
plain-text if no -cert and -key options are given.`)
requirecert = flag.Bool("requirecert", false,
"Require clients to authenticate via client certs. Must be using TLS (e.g. must also\n"+
" provide -cert and -key options).")
`Require clients to authenticate via client certs. Must be using TLS (e.g. must
also provide -cert and -key options).`)
port = flag.Int("p", 0, "Port on which to listen. Ephemeral port used if not specified.")
noreflect = flag.Bool("noreflect", false, "Indicates that server should not support server reflection.")
quiet = flag.Bool("q", false, "Suppresses server request and stream logging.")
@ -77,13 +79,20 @@ func main() {
opts = append(opts, grpc.UnaryInterceptor(unaryLogger), grpc.StreamInterceptor(streamLogger))
}
l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", *port))
var network, addr string
if getUnixSocket != nil && getUnixSocket() != "" {
network = "unix"
addr = getUnixSocket()
} else {
network = "tcp"
addr = fmt.Sprintf("127.0.0.1:%d", *port)
}
l, err := net.Listen(network, addr)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to listen on socket: %v\n", err)
os.Exit(1)
}
p := l.Addr().(*net.TCPAddr).Port
fmt.Printf("Listening on 127.0.0.1:%d\n", p)
fmt.Printf("Listening on %v\n", l.Addr())
svr := grpc.NewServer(opts...)

17
cmd/testserver/unix.go Normal file
View File

@ -0,0 +1,17 @@
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package main
import "flag"
var (
unix = flag.String("unix", "",
`Use instead of -p to indicate listening on a Unix domain socket instead of a
TCP port. If present, must be the path to a domain socket.`)
)
func init() {
getUnixSocket = func() string {
return *unix
}
}

View File

@ -878,7 +878,7 @@ func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
func BlockingDial(ctx context.Context, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
// grpc.Dial doesn't provide any information on permanent connection errors (like
// TLS handshake failures). So in order to provide good error messages, we need a
// custom dialer that can provide that info. That means we manage the TLS handshake.
@ -895,7 +895,8 @@ func BlockingDial(ctx context.Context, address string, creds credentials.Transpo
dialer := func(address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial("tcp", address)
conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
if err != nil {
writeResult(err)
return nil, err

View File

@ -316,7 +316,7 @@ func createTestServerAndClient(serverCreds, clientCreds credentials.TransportCre
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
cc, err := BlockingDial(ctx, fmt.Sprintf("127.0.0.1:%d", port), clientCreds)
cc, err := BlockingDial(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port), clientCreds)
if err != nil {
return e, err
}