Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type TransportOptions struct {
RoutingKey string `long:"rk" description:"The routing key overrides the service name traffic group for proxies."`
RoutingDelegate string `long:"rd" description:"The routing delegate overrides the routing key traffic group for proxies."`
ShardKey string `long:"sk" description:"The shard key is a transport header that clues where to send a request within a clustered traffic group."`
RPCEncoding string `long:"rpc-encoding" description:"Override the rpc-encoding header/metadata value for gRPC and HTTP transports. This does not re-encode the request body and is intended for development."`
Jaeger bool `long:"jaeger" description:"Use the Jaeger tracing client to send Uber style traces and baggage headers"`
TransportHeaders map[string]string `short:"T" long:"topt" description:"Transport options for TChannel, protocol headers for HTTP"`
HTTPMethod string `long:"http-method" description:"The HTTP method to use"`
Expand Down
117 changes: 117 additions & 0 deletions rpc_encoding_flag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package main

import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"

"github.com/opentracing/opentracing-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yarpc/yab/encoding"
yabtransport "github.com/yarpc/yab/transport"

apitransport "go.uber.org/yarpc/api/transport"
yarpcgrpc "go.uber.org/yarpc/transport/grpc"
)

type unaryHandlerFunc func(context.Context, *apitransport.Request, apitransport.ResponseWriter) error

func (f unaryHandlerFunc) Handle(ctx context.Context, req *apitransport.Request, resw apitransport.ResponseWriter) error {
return f(ctx, req, resw)
}

type encodingCaptureRouter struct {
expectedService string
expectedProcedure string
capturedEncoding chan string
}

func (r *encodingCaptureRouter) Procedures() []apitransport.Procedure {
return nil
}

func (r *encodingCaptureRouter) Choose(_ context.Context, req *apitransport.Request) (apitransport.HandlerSpec, error) {
if req.Service == r.expectedService && req.Procedure == r.expectedProcedure {
select {
case r.capturedEncoding <- string(req.Encoding):
default:
}
return apitransport.NewUnaryHandlerSpec(unaryHandlerFunc(func(_ context.Context, _ *apitransport.Request, resw apitransport.ResponseWriter) error {
_, _ = resw.Write([]byte("ok"))
return nil
})), nil
}
return apitransport.HandlerSpec{}, apitransport.UnrecognizedProcedureError(req)
}

func TestRPCEncodingFlagOverridesHTTPHeader(t *testing.T) {
const want = "dev-override"

var got string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = r.Header.Get("RPC-Encoding")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer srv.Close()

opts := TransportOptions{
ServiceName: "svc",
CallerName: "caller",
Peers: []string{srv.URL},
HTTPMethod: "POST",
RPCEncoding: want,
}

tp, err := getTransport(opts, resolvedProtocolEncoding{protocol: yabtransport.HTTP, enc: encoding.JSON}, opentracing.NoopTracer{})
require.NoError(t, err)

_, err = tp.Call(context.Background(), &yabtransport.Request{Method: "Foo::Bar", Body: []byte("hello")})
require.NoError(t, err)
assert.Equal(t, want, got)
}

func TestRPCEncodingFlagOverridesGRPCHeader(t *testing.T) {
const want = "dev-override"

gt := yarpcgrpc.NewTransport(yarpcgrpc.Tracer(opentracing.NoopTracer{}))
require.NoError(t, gt.Start())
defer func() { _ = gt.Stop() }()

lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() { _ = lis.Close() }()

router := &encodingCaptureRouter{
expectedService: "svc",
expectedProcedure: "svc::echo",
capturedEncoding: make(chan string, 1),
}

inbound := gt.NewInbound(lis)
inbound.SetRouter(router)
require.NoError(t, inbound.Start())
defer func() { _ = inbound.Stop() }()

opts := TransportOptions{
ServiceName: "svc",
CallerName: "caller",
Peers: []string{lis.Addr().String()},
RPCEncoding: want,
}
tp, err := getTransport(opts, resolvedProtocolEncoding{protocol: yabtransport.GRPC, enc: encoding.Protobuf}, opentracing.NoopTracer{})
require.NoError(t, err)

_, err = tp.Call(context.Background(), &yabtransport.Request{TargetService: "svc", Method: "svc::echo", Body: []byte("hello")})
require.NoError(t, err)

select {
case got := <-router.capturedEncoding:
assert.Equal(t, want, got)
default:
t.Fatal("did not capture inbound encoding")
}
}
12 changes: 10 additions & 2 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,33 @@ func getTransport(opts TransportOptions, resolved resolvedProtocolEncoding, trac
}

if resolved.protocol == transport.GRPC {
grpcEncoding := resolved.enc.String()
if opts.RPCEncoding != "" {
grpcEncoding = opts.RPCEncoding
}
return transport.NewGRPC(transport.GRPCOptions{
Addresses: getHosts(opts.Peers),
Tracer: tracer,
Caller: opts.CallerName,
Encoding: resolved.enc.String(),
Encoding: grpcEncoding,
RoutingKey: opts.RoutingKey,
RoutingDelegate: opts.RoutingDelegate,
MaxResponseSize: opts.GRPCMaxResponseSize,
})
}

httpEncoding := resolved.enc.String()
if opts.RPCEncoding != "" {
httpEncoding = opts.RPCEncoding
}
hopts := transport.HTTPOptions{
Method: opts.HTTPMethod,
SourceService: opts.CallerName,
TargetService: opts.ServiceName,
RoutingDelegate: opts.RoutingDelegate,
RoutingKey: opts.RoutingKey,
ShardKey: opts.ShardKey,
Encoding: resolved.enc.String(),
Encoding: httpEncoding,
URLs: opts.Peers,
Tracer: tracer,
UseHTTP2: opts.UseHTTP2,
Expand Down