Merge pull request #513 from iamqizhao/master
Simplify compression API
diff --git a/call.go b/call.go
index f29396a..d4ae68b 100644
--- a/call.go
+++ b/call.go
@@ -57,7 +57,7 @@
}
p := &parser{s: stream}
for {
- if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
if err == io.EOF {
break
}
@@ -133,11 +133,7 @@
}
var (
lastErr error // record the error that happened
- cp Compressor
)
- if cc.dopts.cg != nil {
- cp = cc.dopts.cg()
- }
for {
var (
err error
@@ -152,8 +148,8 @@
Host: cc.authority,
Method: method,
}
- if cp != nil {
- callHdr.SendCompress = cp.Type()
+ if cc.dopts.cp != nil {
+ callHdr.SendCompress = cc.dopts.cp.Type()
}
t, err = cc.dopts.picker.Pick(ctx)
if err != nil {
@@ -166,7 +162,7 @@
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
- stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts)
+ stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
diff --git a/clientconn.go b/clientconn.go
index 038ed88..28e74da 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -73,8 +73,8 @@
// values passed to Dial.
type dialOptions struct {
codec Codec
- cg CompressorGenerator
- dg DecompressorGenerator
+ cp Compressor
+ dc Decompressor
picker Picker
block bool
insecure bool
@@ -93,17 +93,17 @@
-// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message
-// compressor.
-func WithCompressor(f CompressorGenerator) DialOption {
+// WithCompressor returns a DialOption which sets the Compressor used to compress outgoing
+// messages. The same Compressor instance is shared by every RPC on the ClientConn.
+func WithCompressor(cp Compressor) DialOption {
return func(o *dialOptions) {
- o.cg = f
+ o.cp = cp
}
}
-// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating
-// message decompressor.
-func WithDecompressor(f DecompressorGenerator) DialOption {
+// WithDecompressor returns a DialOption which sets the Decompressor used to decompress
+// incoming messages. The same Decompressor instance is shared by every RPC on the ClientConn.
+func WithDecompressor(dc Decompressor) DialOption {
return func(o *dialOptions) {
- o.dg = f
+ o.dc = dc
}
}
diff --git a/rpc_util.go b/rpc_util.go
index 427b49e..e98ddbc 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -135,12 +135,6 @@
return "gzip"
}
-// CompressorGenerator defines the function generating a Compressor.
-type CompressorGenerator func() Compressor
-
-// DecompressorGenerator defines the function generating a Decompressor.
-type DecompressorGenerator func() Decompressor
-
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
failFast bool
@@ -290,15 +284,11 @@
return nil
}
-func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error {
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
pf, d, err := p.recvMsg()
if err != nil {
return err
}
- var dc Decompressor
- if pf == compressionMade && dg != nil {
- dc = dg()
- }
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err
}
diff --git a/server.go b/server.go
index dd86427..904a66a 100644
--- a/server.go
+++ b/server.go
@@ -93,8 +93,8 @@
type options struct {
creds credentials.Credentials
codec Codec
- cg CompressorGenerator
- dg DecompressorGenerator
+ cp Compressor
+ dc Decompressor
maxConcurrentStreams uint32
}
@@ -108,15 +108,17 @@
}
}
-func CompressON(f CompressorGenerator) ServerOption {
+// RPCCompressor returns a ServerOption that sets the Compressor used to compress outgoing messages.
+func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
- o.cg = f
+ o.cp = cp
}
}
-func DecompressON(f DecompressorGenerator) ServerOption {
+// RPCDecompressor returns a ServerOption that sets the Decompressor used to decompress incoming messages.
+func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
- o.dg = f
+ o.dc = dc
}
}
@@ -362,11 +362,7 @@
return err
}
- var dc Decompressor
- if pf == compressionMade && s.opts.dg != nil {
- dc = s.opts.dg()
- }
- if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil {
+ if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
switch err := err.(type) {
case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
@@ -385,7 +381,7 @@
df := func(v interface{}) error {
if pf == compressionMade {
var err error
- req, err = dc.Do(bytes.NewReader(req))
+ req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil {
if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
@@ -427,12 +423,10 @@
Last: true,
Delay: false,
}
- var cp Compressor
- if s.opts.cg != nil {
- cp = s.opts.cg()
- stream.SetSendCompress(cp.Type())
+ if s.opts.cp != nil {
+ stream.SetSendCompress(s.opts.cp.Type())
}
- if err := s.sendResponse(t, stream, reply, cp, opts); err != nil {
+ if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
switch err := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
@@ -453,21 +447,19 @@
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
- var cp Compressor
- if s.opts.cg != nil {
- cp = s.opts.cg()
- stream.SetSendCompress(cp.Type())
+ if s.opts.cp != nil {
+ stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
t: t,
s: stream,
p: &parser{s: stream},
codec: s.opts.codec,
- cp: cp,
- dg: s.opts.dg,
+ cp: s.opts.cp,
+ dc: s.opts.dc,
trInfo: trInfo,
}
- if cp != nil {
+ if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
}
if trInfo != nil {
diff --git a/stream.go b/stream.go
index e649c4c..4974d8a 100644
--- a/stream.go
+++ b/stream.go
@@ -105,28 +105,21 @@
if err != nil {
return nil, toRPCErr(err)
}
- var cp Compressor
- if cc.dopts.cg != nil {
- cp = cc.dopts.cg()
- }
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
Flush: desc.ServerStreams&&desc.ClientStreams,
}
- if cp != nil {
- callHdr.SendCompress = cp.Type()
- }
cs := &clientStream{
desc: desc,
codec: cc.dopts.codec,
- cp: cp,
- dg: cc.dopts.dg,
+ cp: cc.dopts.cp,
+ dc: cc.dopts.dc,
tracing: EnableTracing,
}
- if cp != nil {
- callHdr.SendCompress = cp.Type()
+ if cc.dopts.cp != nil {
+ callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer)
}
if cs.tracing {
@@ -170,7 +166,7 @@
codec Codec
cp Compressor
cbuf *bytes.Buffer
- dg DecompressorGenerator
+ dc Decompressor
tracing bool // set to EnableTracing when the clientStream is created.
@@ -229,7 +225,7 @@
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
- err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -248,7 +244,7 @@
return
}
// Special handling for client streaming rpc.
- err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -334,7 +330,7 @@
p *parser
codec Codec
cp Compressor
- dg DecompressorGenerator
+ dc Decompressor
cbuf *bytes.Buffer
statusCode codes.Code
statusDesc string
@@ -402,5 +398,5 @@
ss.mu.Unlock()
}
}()
- return recv(ss.p, ss.codec, ss.s, ss.dg, m)
+ return recv(ss.p, ss.codec, ss.s, ss.dc, m)
}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 82f8373..d0e4ea2 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -327,8 +327,8 @@
return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
}
-func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) {
- sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)}
+func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
+ sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
la := ":0"
switch e.network {
case "unix":
@@ -367,16 +367,16 @@
return
}
-func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) {
+func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) {
var derr error
if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
- cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
+ cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc))
} else {
- cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
+ cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cp), grpc.WithDecompressor(dc))
}
if derr != nil {
t.Fatalf("Dial(%q) = %v", addr, derr)
@@ -1151,7 +1151,7 @@
func testCompressServerHasNoSupport(t *testing.T, e env) {
s, addr := serverSetUp(t, true, nil, math.MaxUint32, nil, nil, e)
- cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e)
+ cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), nil, "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
@@ -1203,8 +1203,8 @@
}
func testCompressOK(t *testing.T, e env) {
- s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e)
- cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e)
+ s, addr := serverSetUp(t, true, nil, math.MaxUint32, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), e)
+ cc := clientSetUp(t, addr, grpc.NewGZIPCompressor(), grpc.NewGZIPDecompressor(), "", e)
// Unary call
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)