Merge pull request #2 from fullstorydev/jh/support-binary-headers

support binary headers
This commit is contained in:
Joshua Humphries 2017-12-13 22:50:11 -05:00 committed by GitHub
commit cf5e463f0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 4 deletions

View File

@ -10,6 +10,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -37,6 +38,10 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// ErrReflectionNotSupported is returned by DescriptorSource operations that
// rely on interacting with the reflection service when the source does not
// actually expose the reflection service. When this occurs, an alternate source
// (like file descriptor sets) must be used.
var ErrReflectionNotSupported = errors.New("server does not support the reflection API") var ErrReflectionNotSupported = errors.New("server does not support the reflection API")
// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet // DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet
@ -263,6 +268,9 @@ type InvocationEventHandler interface {
OnReceiveTrailers(*status.Status, metadata.MD) OnReceiveTrailers(*status.Status, metadata.MD)
} }
// RequestMessageSupplier is a function that is called to retrieve request
// messages for a GRPC operation. The message contents must be valid JSON. If
// the supplier has no more messages, it should return nil, io.EOF.
type RequestMessageSupplier func() (json.RawMessage, error) type RequestMessageSupplier func() (json.RawMessage, error)
// InvokeRpc uses te given GRPC connection to invoke the given method. The given descriptor source // InvokeRpc uses te given GRPC connection to invoke the given method. The given descriptor source
@ -605,6 +613,12 @@ func invokeBidi(ctx context.Context, cancel context.CancelFunc, stub grpcdynamic
return nil return nil
} }
// MetadataFromHeaders converts a list of header strings (each string in
// "Header-Name: Header-Value" form) into metadata. If a string has a header
// name without a value (e.g. does not contain a colon), the value is assumed
// to be blank. Binary headers (those whose names end in "-bin") should be
// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
// be in raw form and used as is.
func MetadataFromHeaders(headers []string) metadata.MD { func MetadataFromHeaders(headers []string) metadata.MD {
md := make(metadata.MD) md := make(metadata.MD)
for _, part := range headers { for _, part := range headers {
@ -614,12 +628,38 @@ func MetadataFromHeaders(headers []string) metadata.MD {
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
} }
headerName := strings.ToLower(strings.TrimSpace(pieces[0])) headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
md[headerName] = append(md[headerName], strings.TrimSpace(pieces[1])) val := strings.TrimSpace(pieces[1])
if strings.HasSuffix(headerName, "-bin") {
if v, err := decode(val); err == nil {
val = v
}
}
md[headerName] = append(md[headerName], val)
} }
} }
return md return md
} }
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}
func decode(val string) (string, error) {
var firstErr error
var b []byte
// we are lenient and can accept any of the flavors of base64 encoding
for _, d := range base64Codecs {
var err error
b, err = d.DecodeString(val)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return string(b), nil
}
return "", firstErr
}
func parseSymbol(svcAndMethod string) (string, string) { func parseSymbol(svcAndMethod string) (string, string) {
pos := strings.LastIndex(svcAndMethod, "/") pos := strings.LastIndex(svcAndMethod, "/")
if pos < 0 { if pos < 0 {
@ -631,6 +671,8 @@ func parseSymbol(svcAndMethod string) (string, string) {
return svcAndMethod[:pos], svcAndMethod[pos+1:] return svcAndMethod[:pos], svcAndMethod[pos+1:]
} }
// MetadataToString returns a string representation of the given metadata, for
// displaying to users.
func MetadataToString(md metadata.MD) string { func MetadataToString(md metadata.MD) string {
if len(md) == 0 { if len(md) == 0 {
return "(empty)" return "(empty)"
@ -640,6 +682,9 @@ func MetadataToString(md metadata.MD) string {
for _, v := range vs { for _, v := range vs {
b.WriteString(k) b.WriteString(k)
b.WriteString(": ") b.WriteString(": ")
if strings.HasSuffix(k, "-bin") {
v = base64.StdEncoding.EncodeToString([]byte(v))
}
b.WriteString(v) b.WriteString(v)
b.WriteString("\n") b.WriteString("\n")
} }
@ -647,11 +692,15 @@ func MetadataToString(md metadata.MD) string {
return b.String() return b.String()
} }
// GetDescriptorText returns a string representation of the given descriptor.
func GetDescriptorText(dsc desc.Descriptor, descSource DescriptorSource) (string, error) { func GetDescriptorText(dsc desc.Descriptor, descSource DescriptorSource) (string, error) {
dscProto := EnsureExtensions(descSource, dsc.AsProto()) dscProto := EnsureExtensions(descSource, dsc.AsProto())
return (&jsonpb.Marshaler{Indent: " "}).MarshalToString(dscProto) return (&jsonpb.Marshaler{Indent: " "}).MarshalToString(dscProto)
} }
// EnsureExtensions uses the given descriptor source to download extensions for
// the given message. It returns a copy of the given message, but as a dynamic
// message that knows about all extensions known to the given descriptor source.
func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message { func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
// load any server extensions so we can properly describe custom options // load any server extensions so we can properly describe custom options
dsc, err := desc.LoadMessageDescriptorForMessage(msg) dsc, err := desc.LoadMessageDescriptorForMessage(msg)

View File

@ -108,9 +108,22 @@ func TestBrokenTLS_ClientPlainText(t *testing.T) {
t.Fatalf("failed to create server creds: %v", err) t.Fatalf("failed to create server creds: %v", err)
} }
// client connection succeeds since client is not waiting for TLS handshake // client connection (usually) succeeds since client is not waiting for TLS handshake
e, err := createTestServerAndClient(serverCreds, nil) e, err := createTestServerAndClient(serverCreds, nil)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "deadline exceeded") {
// It is possible that connection never becomes healthy:
// 1) grpc connects successfully
// 2) grpc client tries to send HTTP/2 preface and settings frame
// 3) server, expecting handshake, closes the connection
// 4) in the client, the write fails, so the connection never
// becomes ready
// More often than not, the connection becomes ready (presumably
// the write to the socket succeeds before the server closes the
// connection). But when it does not, it is possible to observe
// timeouts when setting up the connection.
return
}
t.Fatalf("failed to setup server and client: %v", err) t.Fatalf("failed to setup server and client: %v", err)
} }
defer e.Close() defer e.Close()
@ -126,8 +139,7 @@ func TestBrokenTLS_ClientPlainText(t *testing.T) {
if !strings.Contains(err.Error(), "transport is closing") && if !strings.Contains(err.Error(), "transport is closing") &&
!strings.Contains(err.Error(), "connection is unavailable") && !strings.Contains(err.Error(), "connection is unavailable") &&
!strings.Contains(err.Error(), "use of closed network connection") && !strings.Contains(err.Error(), "use of closed network connection") &&
!strings.Contains(err.Error(), "all SubConns are in TransientFailure") && !strings.Contains(err.Error(), "all SubConns are in TransientFailure") {
!strings.Contains(err.Error(), "deadline exceeded") {
t.Fatalf("expecting transport failure, got: %v", err) t.Fatalf("expecting transport failure, got: %v", err)
} }