diff --git a/options.go b/options.go index 9a5be9a..56af078 100644 --- a/options.go +++ b/options.go @@ -90,6 +90,7 @@ type TransportOptions struct { RoutingKey string `long:"rk" description:"The routing key overrides the service name traffic group for proxies."` RoutingDelegate string `long:"rd" description:"The routing delegate overrides the routing key traffic group for proxies."` ShardKey string `long:"sk" description:"The shard key is a transport header that clues where to send a request within a clustered traffic group."` + RPCEncoding string `long:"en" description:"Override the rpc-encoding header/metadata value for gRPC and HTTP transports. This does not re-encode the request body and is intended for development."` Jaeger bool `long:"jaeger" description:"Use the Jaeger tracing client to send Uber style traces and baggage headers"` TransportHeaders map[string]string `short:"T" long:"topt" description:"Transport options for TChannel, protocol headers for HTTP"` HTTPMethod string `long:"http-method" description:"The HTTP method to use"` diff --git a/rpc_encoding_flag_test.go b/rpc_encoding_flag_test.go new file mode 100644 index 0000000..70282dc --- /dev/null +++ b/rpc_encoding_flag_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "context" + "net" + "testing" + + "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yarpc/yab/encoding" + yabtransport "github.com/yarpc/yab/transport" + + apitransport "go.uber.org/yarpc/api/transport" + yarpcgrpc "go.uber.org/yarpc/transport/grpc" +) + +type unaryHandlerFunc func(context.Context, *apitransport.Request, apitransport.ResponseWriter) error + +func (f unaryHandlerFunc) Handle(ctx context.Context, req *apitransport.Request, resw apitransport.ResponseWriter) error { + return f(ctx, req, resw) +} + +type encodingCaptureRouter struct { + expectedService string + expectedProcedure string + capturedEncoding chan string +} + +func (r *encodingCaptureRouter) Procedures() []apitransport.Procedure { + return nil +} + +func (r *encodingCaptureRouter) Choose(_ context.Context, req *apitransport.Request) (apitransport.HandlerSpec, error) { + if req.Service == r.expectedService && req.Procedure == r.expectedProcedure { + select { + case r.capturedEncoding <- string(req.Encoding): + default: + } + return apitransport.NewUnaryHandlerSpec(unaryHandlerFunc(func(_ context.Context, _ *apitransport.Request, resw apitransport.ResponseWriter) error { + _, _ = resw.Write([]byte("ok")) + return nil + })), nil + } + return apitransport.HandlerSpec{}, apitransport.UnrecognizedProcedureError(req) +} + +func TestRPCEncodingFlagOverridesGRPCHeader(t *testing.T) { + const want = "dev-override" + + gt := yarpcgrpc.NewTransport(yarpcgrpc.Tracer(opentracing.NoopTracer{})) + require.NoError(t, gt.Start()) + defer func() { _ = gt.Stop() }() + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { _ = lis.Close() }() + + router := &encodingCaptureRouter{ + expectedService: "svc", + expectedProcedure: "svc::echo", + capturedEncoding: make(chan string, 1), + } + + inbound := gt.NewInbound(lis) + inbound.SetRouter(router) + require.NoError(t, inbound.Start()) + defer func() { _ = inbound.Stop() }() + + opts := TransportOptions{ + ServiceName: "svc", + CallerName: "caller", + Peers: []string{lis.Addr().String()}, + RPCEncoding: want, + } + tp, err := getTransport(opts, resolvedProtocolEncoding{protocol: yabtransport.GRPC, enc: encoding.Protobuf}, opentracing.NoopTracer{}) + require.NoError(t, err) + + _, err = tp.Call(context.Background(), &yabtransport.Request{TargetService: "svc", Method: "svc::echo", Body: []byte("hello")}) + require.NoError(t, err) + + select { + case got := <-router.capturedEncoding: + assert.Equal(t, want, got) + default: + t.Fatal("did not capture inbound encoding") + } +} diff --git a/transport.go b/transport.go index 7dfff2a..6bc5e73 100644 --- a/transport.go +++ b/transport.go @@ -164,11 +164,15 @@ func getTransport(opts TransportOptions, resolved resolvedProtocolEncoding, trac } if resolved.protocol == transport.GRPC { + grpcEncoding := resolved.enc.String() + if opts.RPCEncoding != "" { + grpcEncoding = opts.RPCEncoding + } return transport.NewGRPC(transport.GRPCOptions{ Addresses: getHosts(opts.Peers), Tracer: tracer, Caller: opts.CallerName, - Encoding: resolved.enc.String(), + Encoding: grpcEncoding, RoutingKey: opts.RoutingKey, RoutingDelegate: opts.RoutingDelegate, MaxResponseSize: opts.GRPCMaxResponseSize,