Merge pull request #500 from iamqizhao/master
Support connection level compression
diff --git a/call.go b/call.go
index 89d2782..f29396a 100644
--- a/call.go
+++ b/call.go
@@ -34,6 +34,7 @@
package grpc
import (
+ "bytes"
"io"
"time"
@@ -47,7 +48,7 @@
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
-func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
+func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
c.headerMD, err = stream.Header()
@@ -56,7 +57,7 @@
}
p := &parser{s: stream}
for {
- if err = recv(p, codec, reply); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil {
if err == io.EOF {
break
}
@@ -68,7 +69,7 @@
}
// sendRequest writes out various information of an RPC such as Context and Message.
-func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
+func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
@@ -80,8 +81,11 @@
}
}
}()
- // TODO(zhaoq): Support compression.
- outBuf, err := encode(codec, args, compressionNone)
+ var cbuf *bytes.Buffer
+ if compressor != nil {
+ cbuf = new(bytes.Buffer)
+ }
+ outBuf, err := encode(codec, args, compressor, cbuf)
if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@@ -129,7 +133,11 @@
}
var (
lastErr error // record the error that happened
+ cp Compressor
)
+ if cc.dopts.cg != nil {
+ cp = cc.dopts.cg()
+ }
for {
var (
err error
@@ -144,6 +152,9 @@
Host: cc.authority,
Method: method,
}
+ if cp != nil {
+ callHdr.SendCompress = cp.Type()
+ }
t, err = cc.dopts.picker.Pick(ctx)
if err != nil {
if lastErr != nil {
@@ -155,7 +166,7 @@
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
- stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
+ stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
@@ -167,7 +178,7 @@
return toRPCErr(err)
}
// Receive the response
- lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
+ lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok {
continue
}
diff --git a/call_test.go b/call_test.go
index 48d25e5..22e42c2 100644
--- a/call_test.go
+++ b/call_test.go
@@ -98,7 +98,7 @@
}
}
// send a response back to end the stream.
- reply, err := encode(testCodec{}, &expectedResponse, compressionNone)
+ reply, err := encode(testCodec{}, &expectedResponse, nil, nil)
if err != nil {
t.Fatalf("Failed to encode the response: %v", err)
}
diff --git a/clientconn.go b/clientconn.go
index 9c2e983..038ed88 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -73,6 +73,8 @@
// values passed to Dial.
type dialOptions struct {
codec Codec
+ cg CompressorGenerator
+ dg DecompressorGenerator
picker Picker
block bool
insecure bool
@@ -89,6 +91,22 @@
}
}
+// WithCompressor returns a DialOption which sets the CompressorGenerator used to
+// compress outgoing request messages. The generator is called to obtain one
+// Compressor per RPC (or per stream), and that Compressor's Type() is announced
+// to the server as the grpc-encoding of the request. A nil generator leaves
+// outgoing messages uncompressed.
+func WithCompressor(f CompressorGenerator) DialOption {
+ return func(o *dialOptions) {
+ o.cg = f
+ }
+}
+
+// WithDecompressor returns a DialOption which sets the DecompressorGenerator used
+// for incoming response messages that are flagged as compressed. The generated
+// Decompressor's Type() must equal the grpc-encoding announced by the server;
+// otherwise (or when no generator is set) receiving a compressed message fails
+// with codes.InvalidArgument.
+func WithDecompressor(f DecompressorGenerator) DialOption {
+ return func(o *dialOptions) {
+ o.dg = f
+ }
+}
+
// WithPicker returns a DialOption which sets a picker for connection selection.
func WithPicker(p Picker) DialOption {
return func(o *dialOptions) {
diff --git a/rpc_util.go b/rpc_util.go
index e6b2236..427b49e 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -34,9 +34,12 @@
package grpc
import (
+ "bytes"
+ "compress/gzip"
"encoding/binary"
"fmt"
"io"
+ "io/ioutil"
"math"
"math/rand"
"os"
@@ -75,6 +78,69 @@
return "proto"
}
+// Compressor defines the interface gRPC uses to compress a message.
+// Each Do call compresses one whole serialized message.
+type Compressor interface {
+ // Do compresses p and writes the compressed bytes to w.
+ Do(w io.Writer, p []byte) error
+ // Type returns the compression algorithm the Compressor uses. The value is
+ // sent to the peer as the grpc-encoding header, so it must match the
+ // peer's Decompressor.Type().
+ Type() string
+}
+
+// NewGZIPCompressor creates a Compressor based on GZIP.
+func NewGZIPCompressor() Compressor {
+ return &gzipCompressor{}
+}
+
+// gzipCompressor is a stateless Compressor; every Do call uses a fresh gzip.Writer.
+type gzipCompressor struct {
+}
+
+// Do gzip-compresses p into w. Close must be called to flush buffered data and
+// write the gzip footer; without it w would hold a truncated stream.
+func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
+ z := gzip.NewWriter(w)
+ if _, err := z.Write(p); err != nil {
+ return err
+ }
+ return z.Close()
+}
+
+// Type reports "gzip".
+func (c *gzipCompressor) Type() string {
+ return "gzip"
+}
+
+// Decompressor defines the interface gRPC uses to decompress a message.
+type Decompressor interface {
+ // Do reads all of the compressed data from r and returns it decompressed.
+ Do(r io.Reader) ([]byte, error)
+ // Type returns the compression algorithm the Decompressor uses. It is
+ // compared against the grpc-encoding announced by the peer.
+ Type() string
+}
+
+// gzipDecompressor is a stateless Decompressor; every Do call uses a fresh gzip.Reader.
+type gzipDecompressor struct {
+}
+
+// NewGZIPDecompressor creates a Decompressor based on GZIP.
+func NewGZIPDecompressor() Decompressor {
+ return &gzipDecompressor{}
+}
+
+// Do reads a gzip stream from r to EOF and returns the decompressed bytes.
+func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
+ z, err := gzip.NewReader(r)
+ if err != nil {
+ return nil, err
+ }
+ defer z.Close()
+ return ioutil.ReadAll(z)
+}
+
+// Type reports "gzip".
+func (d *gzipDecompressor) Type() string {
+ return "gzip"
+}
+
+// CompressorGenerator defines the function generating a Compressor. It is
+// called to obtain a Compressor for each RPC or stream.
+type CompressorGenerator func() Compressor
+
+// DecompressorGenerator defines the function generating a Decompressor. It is
+// called to obtain a Decompressor for each compressed message that is received.
+type DecompressorGenerator func() Decompressor
+
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
failFast bool
@@ -126,8 +192,7 @@
const (
compressionNone payloadFormat = iota // no compression
- compressionFlate
- // More formats
+ compressionMade
)
// parser reads complelete gRPC messages from the underlying reader.
@@ -166,7 +231,7 @@
// encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length.
-func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
+func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
var b []byte
var length uint
if msg != nil {
@@ -176,6 +241,12 @@
if err != nil {
return nil, err
}
+ if cp != nil {
+ // Callers may pass a nil cbuf; allocate one instead of handing a nil
+ // *bytes.Buffer to the Compressor, which would panic when written to.
+ if cbuf == nil {
+ cbuf = new(bytes.Buffer)
+ }
+ if err := cp.Do(cbuf, b); err != nil {
+ return nil, err
+ }
+ b = cbuf.Bytes()
+ }
length = uint(len(b))
}
if length > math.MaxUint32 {
@@ -190,7 +261,11 @@
var buf = make([]byte, payloadLen+sizeLen+len(b))
// Write payload format
- buf[0] = byte(pf)
+ if cp == nil {
+ buf[0] = byte(compressionNone)
+ } else {
+ buf[0] = byte(compressionMade)
+ }
// Write length of b into buf
binary.BigEndian.PutUint32(buf[1:], uint32(length))
// Copy encoded msg to buf
@@ -199,22 +274,42 @@
return buf, nil
}
-func recv(p *parser, c Codec, m interface{}) error {
+// checkRecvPayload validates the payload-format byte pf of a received message
+// against the grpc-encoding announced by the peer (recvCompress) and the locally
+// available Decompressor dc. Uncompressed payloads are always accepted. A
+// compressed payload is rejected with codes.InvalidArgument when the peer sent
+// no grpc-encoding, or when dc is nil or handles a different algorithm. Any
+// other format byte is rejected as unexpected.
+func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
+ switch pf {
+ case compressionNone:
+ case compressionMade:
+ if recvCompress == "" {
+ return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
+ }
+ if dc == nil || recvCompress != dc.Type() {
+ return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
+ }
+ default:
+ return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
+ }
+ return nil
+}
+
+func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error {
pf, d, err := p.recvMsg()
if err != nil {
return err
}
- switch pf {
- case compressionNone:
- if err := c.Unmarshal(d, m); err != nil {
- if rErr, ok := err.(rpcError); ok {
- return rErr
- } else {
- return Errorf(codes.Internal, "grpc: %v", err)
- }
+ var dc Decompressor
+ if pf == compressionMade && dg != nil {
+ dc = dg()
+ }
+ if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
+ return err
+ }
+ if pf == compressionMade {
+ d, err = dc.Do(bytes.NewReader(d))
+ if err != nil {
+ return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
- default:
- return Errorf(codes.Internal, "gprc: compression is not supported yet.")
+ }
+ if err := c.Unmarshal(d, m); err != nil {
+ return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 2673cd0..3f3749a 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -106,16 +106,40 @@
for _, test := range []struct {
// input
msg proto.Message
- pt payloadFormat
+ cp Compressor
// outputs
b []byte
err error
}{
- {nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil},
+ {nil, nil, []byte{0, 0, 0, 0, 0}, nil},
} {
- b, err := encode(protoCodec{}, test.msg, test.pt)
+ b, err := encode(protoCodec{}, test.msg, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) {
- t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err)
+ t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
+ }
+ }
+}
+
+func TestCompress(t *testing.T) {
+ for _, test := range []struct {
+ // input
+ data []byte
+ cp Compressor
+ dc Decompressor
+ // outputs
+ err error
+ }{
+ {make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil},
+ } {
+ b := new(bytes.Buffer)
+ if err := test.cp.Do(b, test.data); err != test.err {
+ t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err)
+ }
+ if b.Len() >= len(test.data) {
+ t.Fatalf("The compressor fails to compress data.")
+ }
+ if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) {
+ t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data)
}
}
}
@@ -158,12 +182,12 @@
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
- encoded, _ := encode(protoCodec{}, msg, compressionNone)
+ encoded, _ := encode(protoCodec{}, msg, nil, nil)
encodedSz := int64(len(encoded))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
- encode(protoCodec{}, msg, compressionNone)
+ encode(protoCodec{}, msg, nil, nil)
}
b.SetBytes(encodedSz)
}
diff --git a/server.go b/server.go
index 655e7d8..a7ff16c 100644
--- a/server.go
+++ b/server.go
@@ -34,6 +34,7 @@
package grpc
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -92,6 +93,8 @@
type options struct {
creds credentials.Credentials
codec Codec
+ cg CompressorGenerator
+ dg DecompressorGenerator
maxConcurrentStreams uint32
}
@@ -105,6 +108,18 @@
}
}
+// CompressON returns a ServerOption that sets the CompressorGenerator the server
+// uses to compress outgoing response messages. A Compressor is generated per RPC
+// and its Type() is announced to the client via the grpc-encoding header. A nil
+// generator leaves responses uncompressed.
+func CompressON(f CompressorGenerator) ServerOption {
+ return func(o *options) {
+ o.cg = f
+ }
+}
+
+// DecompressON returns a ServerOption that sets the DecompressorGenerator the
+// server uses for incoming request messages flagged as compressed. A compressed
+// request whose grpc-encoding does not match the generated Decompressor's Type(),
+// or that arrives while no generator is set, is rejected with
+// codes.InvalidArgument.
+func DecompressON(f DecompressorGenerator) ServerOption {
+ return func(o *options) {
+ o.dg = f
+ }
+}
+
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@@ -287,8 +302,12 @@
}
}
-func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
- p, err := encode(s.opts.codec, msg, pf)
+func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
+ var cbuf *bytes.Buffer
+ if cp != nil {
+ cbuf = new(bytes.Buffer)
+ }
+ p, err := encode(s.opts.codec, msg, cp, cbuf)
if err != nil {
// This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program
@@ -327,84 +346,124 @@
// Nothing to do here.
case transport.StreamError:
if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
}
return err
}
- switch pf {
- case compressionNone:
- statusCode := codes.OK
- statusDesc := ""
- df := func(v interface{}) error {
- if err := s.opts.codec.Unmarshal(req, v); err != nil {
- return err
+
+ var dc Decompressor
+ if pf == compressionMade && s.opts.dg != nil {
+ dc = s.opts.dg()
+ }
+ if err := checkRecvPayload(pf, stream.RecvCompress(), dc); err != nil {
+ switch err := err.(type) {
+ case transport.StreamError:
+ if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
- if trInfo != nil {
- trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
- }
- return nil
- }
- reply, appErr := md.Handler(srv.server, stream.Context(), df)
- if appErr != nil {
- if err, ok := appErr.(rpcError); ok {
- statusCode = err.code
- statusDesc = err.desc
- } else {
- statusCode = convertCode(appErr)
- statusDesc = appErr.Error()
- }
- if trInfo != nil && statusCode != codes.OK {
- trInfo.tr.LazyLog(stringer(statusDesc), true)
- trInfo.tr.SetError()
+ default:
+ if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
- if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
- grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
+ }
+ return err
+ }
+ statusCode := codes.OK
+ statusDesc := ""
+ df := func(v interface{}) error {
+ if pf == compressionMade {
+ var err error
+ req, err = dc.Do(bytes.NewReader(req))
+ if err != nil {
+ // Do not write the status here: this error is returned from md.Handler
+ // as appErr (assuming the handler returns the decoder's error as-is),
+ // and the appErr handling below writes the status exactly once.
+ // Wrapping it with Errorf yields an rpcError, so the code stays Internal.
+ err = Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
return err
}
- return nil
}
- if trInfo != nil {
- trInfo.tr.LazyLog(stringer("OK"), false)
- }
- opts := &transport.Options{
- Last: true,
- Delay: false,
- }
- if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil {
- switch err := err.(type) {
- case transport.ConnectionError:
- // Nothing to do here.
- case transport.StreamError:
- statusCode = err.Code
- statusDesc = err.Desc
- default:
- statusCode = codes.Unknown
- statusDesc = err.Error()
- }
+ if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
if trInfo != nil {
- trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
+ trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
- return t.WriteStatus(stream, statusCode, statusDesc)
- default:
- panic(fmt.Sprintf("payload format to be supported: %d", pf))
+ return nil
}
+ reply, appErr := md.Handler(srv.server, stream.Context(), df)
+ if appErr != nil {
+ if err, ok := appErr.(rpcError); ok {
+ statusCode = err.code
+ statusDesc = err.desc
+ } else {
+ statusCode = convertCode(appErr)
+ statusDesc = appErr.Error()
+ }
+ if trInfo != nil && statusCode != codes.OK {
+ trInfo.tr.LazyLog(stringer(statusDesc), true)
+ trInfo.tr.SetError()
+ }
+ if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
+ grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
+ return err
+ }
+ return nil
+ }
+ if trInfo != nil {
+ trInfo.tr.LazyLog(stringer("OK"), false)
+ }
+ opts := &transport.Options{
+ Last: true,
+ Delay: false,
+ }
+ var cp Compressor
+ if s.opts.cg != nil {
+ cp = s.opts.cg()
+ stream.SetSendCompress(cp.Type())
+ }
+ if err := s.sendResponse(t, stream, reply, cp, opts); err != nil {
+ switch err := err.(type) {
+ case transport.ConnectionError:
+ // Nothing to do here.
+ case transport.StreamError:
+ statusCode = err.Code
+ statusDesc = err.Desc
+ default:
+ statusCode = codes.Unknown
+ statusDesc = err.Error()
+ }
+ return err
+ }
+ if trInfo != nil {
+ trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
+ }
+ return t.WriteStatus(stream, statusCode, statusDesc)
}
}
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
+ var cp Compressor
+ if s.opts.cg != nil {
+ cp = s.opts.cg()
+ stream.SetSendCompress(cp.Type())
+ }
ss := &serverStream{
t: t,
s: stream,
p: &parser{s: stream},
codec: s.opts.codec,
+ cp: cp,
+ dg: s.opts.dg,
trInfo: trInfo,
}
+ if cp != nil {
+ ss.cbuf = new(bytes.Buffer)
+ }
if trInfo != nil {
trInfo.tr.LazyLog(&trInfo.firstLine, false)
defer func() {
@@ -422,6 +481,9 @@
if err, ok := appErr.(rpcError); ok {
ss.statusCode = err.code
ss.statusDesc = err.desc
+ } else if err, ok := appErr.(transport.StreamError); ok {
+ ss.statusCode = err.Code
+ ss.statusDesc = err.Desc
} else {
ss.statusCode = convertCode(appErr)
ss.statusDesc = appErr.Error()
diff --git a/stream.go b/stream.go
index d8bdc16..63f934d 100644
--- a/stream.go
+++ b/stream.go
@@ -34,6 +34,7 @@
package grpc
import (
+ "bytes"
"errors"
"io"
"sync"
@@ -104,16 +105,29 @@
if err != nil {
return nil, toRPCErr(err)
}
+ var cp Compressor
+ if cc.dopts.cg != nil {
+ cp = cc.dopts.cg()
+ }
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
}
+ if cp != nil {
+ callHdr.SendCompress = cp.Type()
+ }
cs := &clientStream{
desc: desc,
codec: cc.dopts.codec,
+ cp: cp,
+ dg: cc.dopts.dg,
tracing: EnableTracing,
}
+ if cp != nil {
+ // callHdr.SendCompress was already set from cp above; only the reusable
+ // compression buffer needs to be allocated here.
+ cs.cbuf = new(bytes.Buffer)
+ }
if cs.tracing {
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
cs.trInfo.firstLine.client = true
@@ -153,6 +167,9 @@
p *parser
desc *StreamDesc
codec Codec
+ cp Compressor
+ cbuf *bytes.Buffer
+ dg DecompressorGenerator
tracing bool // set to EnableTracing when the clientStream is created.
@@ -198,7 +215,12 @@
}
err = toRPCErr(err)
}()
- out, err := encode(cs.codec, m, compressionNone)
+ out, err := encode(cs.codec, m, cs.cp, cs.cbuf)
+ defer func() {
+ if cs.cbuf != nil {
+ cs.cbuf.Reset()
+ }
+ }()
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@@ -206,7 +228,7 @@
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
- err = recv(cs.p, cs.codec, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -225,7 +247,7 @@
return
}
// Special handling for client streaming rpc.
- err = recv(cs.p, cs.codec, m)
+ err = recv(cs.p, cs.codec, cs.s, cs.dg, m)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -310,6 +332,9 @@
s *transport.Stream
p *parser
codec Codec
+ cp Compressor
+ dg DecompressorGenerator
+ cbuf *bytes.Buffer
statusCode codes.Code
statusDesc string
trInfo *traceInfo
@@ -348,7 +373,12 @@
ss.mu.Unlock()
}
}()
- out, err := encode(ss.codec, m, compressionNone)
+ out, err := encode(ss.codec, m, ss.cp, ss.cbuf)
+ defer func() {
+ if ss.cbuf != nil {
+ ss.cbuf.Reset()
+ }
+ }()
if err != nil {
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
return err
@@ -371,5 +401,5 @@
ss.mu.Unlock()
}
}()
- return recv(ss.p, ss.codec, m)
+ return recv(ss.p, ss.codec, ss.s, ss.dg, m)
}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 93944f8..22ca877 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -143,7 +143,6 @@
if err != nil {
return nil, err
}
-
return &testpb.SimpleResponse{
Payload: payload,
}, nil
@@ -328,8 +327,8 @@
return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}}
}
-func setUp(t *testing.T, hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc.Server, cc *grpc.ClientConn) {
- sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)}
+func serverSetUp(t *testing.T, hs *health.HealthServer, maxStream uint32, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, e env) (s *grpc.Server, addr string) {
+ sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.CompressON(cg), grpc.DecompressON(dg)}
la := ":0"
switch e.network {
case "unix":
@@ -353,7 +352,7 @@
}
testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
go s.Serve(lis)
- addr := la
+ addr = la
switch e.network {
case "unix":
default:
@@ -363,17 +362,22 @@
}
addr = "localhost:" + port
}
+ return
+}
+
+func clientSetUp(t *testing.T, addr string, cg grpc.CompressorGenerator, dg grpc.DecompressorGenerator, ua string, e env) (cc *grpc.ClientConn) {
+ var derr error
if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
- cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua))
+ cc, derr = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
} else {
- cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua))
+ cc, derr = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua), grpc.WithCompressor(cg), grpc.WithDecompressor(dg))
}
- if err != nil {
- t.Fatalf("Dial(%q) = %v", addr, err)
+ if derr != nil {
+ t.Fatalf("Dial(%q) = %v", addr, derr)
}
return
}
@@ -390,7 +394,8 @@
}
func testTimeoutOnDeadServer(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@@ -443,7 +448,8 @@
func testHealthCheckOnSuccess(t *testing.T, e env) {
hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.Health", 1)
- s, cc := setUp(t, hs, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.Health"); err != nil {
t.Fatalf("Health/Check(_, _) = _, %v, want _, <nil>", err)
@@ -459,7 +465,8 @@
func testHealthCheckOnFailure(t *testing.T, e env) {
hs := health.NewHealthServer()
hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1)
- s, cc := setUp(t, hs, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded)
@@ -473,7 +480,8 @@
}
func testHealthCheckOff(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented)
@@ -488,7 +496,8 @@
func testHealthCheckServingStatus(t *testing.T, e env) {
hs := health.NewHealthServer()
- s, cc := setUp(t, hs, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, hs, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
defer tearDown(s, cc)
out, err := healthCheck(1*time.Second, cc, "")
if err != nil {
@@ -526,7 +535,8 @@
}
func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, testAppUA, e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, testAppUA, e)
// Wait until cc is connected.
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
@@ -553,7 +563,7 @@
t.Fatalf("header[\"ua\"] = %q, %t, want %q, true", v, ok, testAppUA)
}
tearDown(s, cc)
- ctx, _ = context.WithTimeout(context.Background(), 5 * time.Second)
+ ctx, _ = context.WithTimeout(context.Background(), 5*time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
}
@@ -569,7 +579,8 @@
}
func testFailedEmptyUnary(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@@ -585,7 +596,8 @@
}
func testLargeUnary(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
@@ -619,7 +631,8 @@
}
func testMetadataUnaryRPC(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@@ -684,7 +697,8 @@
// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy
// and error-prone paths.
func testRetry(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
var wg sync.WaitGroup
@@ -714,7 +728,8 @@
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func testRPCTimeout(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@@ -746,7 +761,8 @@
}
func testCancel(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@@ -778,7 +794,8 @@
func testCancelNoIO(t *testing.T, e env) {
// Only allows 1 live stream per server transport.
- s, cc := setUp(t, nil, 1, "", e)
+ s, addr := serverSetUp(t, nil, 1, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx, cancel := context.WithCancel(context.Background())
@@ -806,12 +823,12 @@
go func() {
defer close(ch)
// This should be blocked until the 1st is canceled.
- ctx, _ := context.WithTimeout(context.Background(), 2 * time.Second)
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
}()
- cancel();
+ cancel()
<-ch
}
@@ -829,7 +846,8 @@
}
func testPingPong(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.FullDuplexCall(context.Background())
@@ -886,7 +904,8 @@
}
func testMetadataStreamingRPC(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@@ -952,7 +971,8 @@
}
func testServerStreaming(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@@ -1004,7 +1024,8 @@
}
func testFailedServerStreaming(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@@ -1034,7 +1055,8 @@
}
func testClientStreaming(t *testing.T, e env) {
- s, cc := setUp(t, nil, math.MaxUint32, "", e)
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.StreamingInputCall(context.Background())
@@ -1074,7 +1096,8 @@
func testExceedMaxStreamsLimit(t *testing.T, e env) {
// Only allows 1 live stream per server transport.
- s, cc := setUp(t, nil, 1, "", e)
+ s, addr := serverSetUp(t, nil, 1, nil, nil, e)
+ cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
_, err := tc.StreamingInputCall(context.Background())
@@ -1095,3 +1118,109 @@
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
}
}
+
+func TestCompressServerHasNoSupport(t *testing.T) {
+ for _, e := range listTestEnv() {
+ testCompressServerHasNoSupport(t, e)
+ }
+}
+
+func testCompressServerHasNoSupport(t *testing.T, e env) {
+ s, addr := serverSetUp(t, nil, math.MaxUint32, nil, nil, e)
+ cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, nil, "", e)
+ // Unary call
+ tc := testpb.NewTestServiceClient(cc)
+ defer tearDown(s, cc)
+ argSize := 271828
+ respSize := 314159
+ payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(int32(respSize)),
+ Payload: payload,
+ }
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %d", err, codes.InvalidArgument)
+ }
+ // Streaming RPC
+ stream, err := tc.FullDuplexCall(context.Background())
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: proto.Int32(31415),
+ },
+ }
+ payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
+ if err != nil {
+ t.Fatal(err)
+ }
+ sreq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseParameters: respParam,
+ Payload: payload,
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = %v, want error code %d", stream, err, codes.InvalidArgument)
+ }
+}
+
+func TestCompressOK(t *testing.T) {
+ for _, e := range listTestEnv() {
+ testCompressOK(t, e)
+ }
+}
+
+func testCompressOK(t *testing.T, e env) {
+ s, addr := serverSetUp(t, nil, math.MaxUint32, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, e)
+ cc := clientSetUp(t, addr, grpc.NewGZIPCompressor, grpc.NewGZIPDecompressor, "", e)
+ // Unary call
+ tc := testpb.NewTestServiceClient(cc)
+ defer tearDown(s, cc)
+ argSize := 271828
+ respSize := 314159
+ payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(int32(respSize)),
+ Payload: payload,
+ }
+ if _, err := tc.UnaryCall(context.Background(), req); err != nil {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
+ }
+ // Streaming RPC
+ stream, err := tc.FullDuplexCall(context.Background())
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: proto.Int32(31415),
+ },
+ }
+ payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
+ if err != nil {
+ t.Fatal(err)
+ }
+ sreq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseParameters: respParam,
+ Payload: payload,
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err != nil {
+ t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
+ }
+}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 9eae37d..7006cd8 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -210,6 +210,7 @@
s := &Stream{
id: t.nextID,
method: callHdr.Method,
+ sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
fc: fc,
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
@@ -322,6 +323,9 @@
t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
+ if callHdr.SendCompress != "" {
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
+ }
if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
}
@@ -694,8 +698,10 @@
if !endHeaders {
return s
}
-
s.mu.Lock()
+ if !endStream {
+ s.recvCompress = hDec.state.encoding
+ }
if !s.headerDone {
if !endStream && len(hDec.state.mdata) > 0 {
s.header = hDec.state.mdata
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 98088d9..cce2e12 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -164,6 +164,7 @@
if !endHeaders {
return s
}
+ s.recvCompress = hDec.state.encoding
if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
} else {
@@ -190,6 +191,7 @@
ctx: s.ctx,
recv: s.buf,
}
+ s.recvCompress = hDec.state.encoding
s.method = hDec.state.method
t.mu.Lock()
if t.state != reachable {
@@ -446,6 +448,9 @@
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+ if s.sendCompress != "" {
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+ }
for k, v := range md {
for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
@@ -520,6 +525,9 @@
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+ if s.sendCompress != "" {
+ t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+ }
p := http2.HeadersFrameParam{
StreamID: s.id,
BlockFragment: t.hBuf.Bytes(),
diff --git a/transport/http_util.go b/transport/http_util.go
index fec4e47..f9d9fdf 100644
--- a/transport/http_util.go
+++ b/transport/http_util.go
@@ -89,6 +89,7 @@
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
+ encoding string
// statusCode caches the stream status received from the trailer
// the server sent. Client side only.
statusCode codes.Code
@@ -145,6 +146,8 @@
d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header")
return
}
+ case "grpc-encoding":
+ d.state.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {
diff --git a/transport/transport.go b/transport/transport.go
index e1e7f57..7956479 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -170,11 +170,13 @@
ctx context.Context
cancel context.CancelFunc
// method records the associated RPC method of the stream.
- method string
- buf *recvBuffer
- dec io.Reader
- fc *inFlow
- recvQuota uint32
+ method string
+ recvCompress string
+ sendCompress string
+ buf *recvBuffer
+ dec io.Reader
+ fc *inFlow
+ recvQuota uint32
// The accumulated inbound quota pending for window update.
updateQuota uint32
// The handler to control the window update procedure for both this
@@ -201,6 +203,17 @@
statusDesc string
}
+// RecvCompress reports the name of the compression algorithm the peer used
+// for inbound messages on s, or "" when they arrive uncompressed.
+func (s *Stream) RecvCompress() string {
+	return s.recvCompress
+}
+
+// SetSendCompress records name as the compression algorithm for outbound messages on s.
+func (s *Stream) SetSendCompress(name string) {
+	s.sendCompress = name
+}
+
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired.
@@ -348,8 +361,14 @@
// CallHdr carries the information of a particular RPC.
type CallHdr struct {
- Host string // peer host
- Method string // the operation to perform on the specified host
+ // Host specifies the peer host.
+ Host string
+ // Method specifies the operation to perform.
+ Method string
+ // RecvCompress specifies the compression algorithm applied to inbound messages.
+ RecvCompress string
+ // SendCompress names the outbound compression; if non-empty it is sent as grpc-encoding.
+ SendCompress string
}
// ClientTransport is the common interface for all gRPC client side transport