Merge pull request #514 from bradfitz/servehttp
Add a ServeHTTP method to *grpc.Server
diff --git a/rpc_util.go b/rpc_util.go
index e98ddbc..fadf339 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -273,7 +273,7 @@
case compressionNone:
case compressionMade:
if recvCompress == "" {
- return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
+ return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
diff --git a/server.go b/server.go
index bcd6196..eb56b34 100644
--- a/server.go
+++ b/server.go
@@ -39,6 +39,7 @@
"fmt"
"io"
"net"
+ "net/http"
"reflect"
"runtime"
"strings"
@@ -46,6 +47,7 @@
"time"
"golang.org/x/net/context"
+ "golang.org/x/net/http2"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -82,10 +84,11 @@
// Server is a gRPC server to serve RPC requests.
type Server struct {
- opts options
- mu sync.Mutex
+ opts options
+
+ mu sync.Mutex // guards following
lis map[net.Listener]bool
- conns map[transport.ServerTransport]bool
+ conns map[io.Closer]bool // transports and raw conns being served; nil means addConn refuses new ones
m map[string]*service // service name -> service info
events trace.EventLog
}
@@ -96,6 +99,7 @@
cp Compressor
dc Decompressor
maxConcurrentStreams uint32
+ useHandlerImpl bool // use http.Handler-based server
}
// A ServerOption sets options.
@@ -149,7 +153,7 @@
s := &Server{
lis: make(map[net.Listener]bool),
opts: opts,
- conns: make(map[transport.ServerTransport]bool),
+ conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
if EnableTracing {
@@ -216,9 +220,17 @@
ErrServerStopped = errors.New("grpc: the server has been stopped")
)
+// useTransportAuthenticator runs the server side of the security handshake
+// on rawConn when the configured credentials implement
+// credentials.TransportAuthenticator. With any other credentials (or none)
+// it hands rawConn back untouched, with a nil AuthInfo and no error.
+func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	if auth, isAuth := s.opts.creds.(credentials.TransportAuthenticator); isAuth {
+		return auth.ServerHandshake(rawConn)
+	}
+	return rawConn, nil, nil
+}
+
// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
-// read gRPC request and then call the registered handlers to reply to them.
+// read gRPC requests and then call the registered handlers to reply to them.
// Service returns when lis.Accept fails.
func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock()
@@ -235,39 +247,54 @@
delete(s.lis, lis)
s.mu.Unlock()
}()
+ listenerAddr := lis.Addr()
for {
- c, err := lis.Accept()
+ rawConn, err := lis.Accept()
if err != nil {
s.mu.Lock()
s.printf("done serving; Accept = %v", err)
s.mu.Unlock()
return err
}
- var authInfo credentials.AuthInfo
- if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
- var conn net.Conn
- conn, authInfo, err = creds.ServerHandshake(c)
- if err != nil {
- s.mu.Lock()
- s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err)
- s.mu.Unlock()
- grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
- continue
- }
- c = conn
- }
- s.mu.Lock()
- if s.conns == nil {
- s.mu.Unlock()
- c.Close()
- return nil
- }
- s.mu.Unlock()
-
- go s.serveNewHTTP2Transport(c, authInfo)
+ // Start a new goroutine to deal with rawConn
+ // so we don't stall this Accept loop goroutine.
+ go s.handleRawConn(listenerAddr, rawConn)
}
}
+// handleRawConn runs in a goroutine of its own for every connection the
+// Accept loop hands over, before any I/O has happened on it. It performs
+// the optional security handshake (closing rawConn if that fails), drops
+// the connection if s.conns is already nil, and otherwise serves it with
+// either the http.Handler-based implementation or the native HTTP/2
+// transport, depending on s.opts.useHandlerImpl.
+func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) {
+	conn, authInfo, err := s.useTransportAuthenticator(rawConn)
+	if err != nil {
+		s.mu.Lock()
+		s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
+		s.mu.Unlock()
+		grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
+		rawConn.Close()
+		return
+	}
+
+	s.mu.Lock()
+	closed := s.conns == nil
+	s.mu.Unlock()
+	if closed {
+		conn.Close()
+		return
+	}
+
+	if !s.opts.useHandlerImpl {
+		s.serveNewHTTP2Transport(conn, authInfo)
+		return
+	}
+	s.serveUsingHandler(listenerAddr, conn)
+}
+
+// serveNewHTTP2Transport sets up a new http/2 transport (using the
+// gRPC http2 server transport in transport/http2_server.go) and
+// serves streams on it.
+// This is run in its own goroutine (it does network I/O in
+// transport.NewServerTransport).
func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
if err != nil {
@@ -299,6 +326,59 @@
wg.Wait()
}
+var _ http.Handler = (*Server)(nil)
+
+// serveUsingHandler is called from handleRawConn when s is configured
+// to handle requests via the http.Handler interface. It sets up a
+// net/http.Server to handle the just-accepted conn. The http.Server
+// is configured to route all incoming requests (all HTTP/2 streams)
+// to ServeHTTP, which creates a new ServerTransport for each stream.
+// serveUsingHandler blocks until conn closes.
+//
+// This codepath is only used when Server.TestingUseHandlerImpl has
+// been configured. This lets the end2end tests exercise the ServeHTTP
+// method as one of the environment types.
+//
+// conn is the connection returned by the transport authenticator; per
+// TestingUseHandlerImpl, TLS credentials are required on this path.
+//
+// conn is registered with addConn for the duration of the call. If the
+// HTTP/2 server cannot be configured, the error is logged and conn is
+// closed; the process is not terminated.
+func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) {
+	if !s.addConn(conn) {
+		conn.Close()
+		return
+	}
+	defer s.removeConn(conn)
+	connDone := make(chan struct{})
+	hs := &http.Server{
+		Handler: s,
+		ConnState: func(c net.Conn, cs http.ConnState) {
+			if cs == http.StateClosed {
+				close(connDone)
+			}
+		},
+	}
+	if err := http2.ConfigureServer(hs, &http2.Server{
+		MaxConcurrentStreams: s.opts.maxConcurrentStreams,
+	}); err != nil {
+		// A configuration failure only affects this connection; log it and
+		// drop the conn rather than exiting the whole process.
+		grpclog.Printf("grpc: http2.ConfigureServer: %v", err)
+		conn.Close()
+		return
+	}
+	// Serve returns as soon as the single-conn listener is exhausted, while
+	// the conn itself keeps being served by net/http; wait for it to close.
+	hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn})
+	<-connDone
+}
+
+// ServeHTTP implements http.Handler, letting s serve gRPC requests from
+// inside a net/http server. Each HTTP request is wrapped in its own
+// ServerTransport carrying exactly one stream. That transport is registered
+// with s for the duration of the call, so it is closed together with the
+// server's other connections, and is served via serveStreams.
+//
+// r must have arrived over HTTP/2 as a gRPC POST, and w must support
+// http.Flusher and http.CloseNotifier. If the request cannot be turned into
+// a transport, it is answered with HTTP 500 and the error text as the body.
+// With net/http, HTTP/2 generally implies the connection uses TLS.
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	st, err := transport.NewServerHandlerTransport(w, r)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	if !s.addConn(st) {
+		st.Close()
+		return
+	}
+	defer s.removeConn(st)
+	s.serveStreams(st)
+}
+
// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
@@ -317,21 +397,21 @@
return trInfo
}
+// addConn records c in s.conns so that it is closed along with the server's
+// other connections. It reports false, recording nothing, when s.conns is
+// nil (the state the shutdown path leaves it in).
-func (s *Server) addConn(st transport.ServerTransport) bool {
+func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
return false
}
- s.conns[st] = true
+ s.conns[c] = true
return true
}
+// removeConn forgets c. It is a no-op once s.conns has been set to nil.
-func (s *Server) removeConn(st transport.ServerTransport) {
+func (s *Server) removeConn(c io.Closer) {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns != nil {
- delete(s.conns, st)
+ delete(s.conns, c)
}
}
@@ -606,12 +686,14 @@
cs := s.conns
s.conns = nil
s.mu.Unlock()
+
for lis := range listeners {
lis.Close()
}
for c := range cs {
c.Close()
}
+
s.mu.Lock()
if s.events != nil {
s.events.Finish()
@@ -621,16 +703,24 @@
}
-// TestingCloseConns closes all exiting transports but keeps s.lis accepting new
+// TestingCloseConns closes all existing transports but keeps s.lis accepting new
-// connections. This is for test only now.
+// connections.
+// This is only for tests and is subject to removal.
func (s *Server) TestingCloseConns() {
s.mu.Lock()
for c := range s.conns {
c.Close()
+ // Deleting the current key while ranging over a map is allowed in Go.
+ // The map itself is kept rather than replaced, so a nil s.conns stays nil.
+ delete(s.conns, c)
}
- s.conns = make(map[transport.ServerTransport]bool)
s.mu.Unlock()
}
+// TestingUseHandlerImpl enables the http.Handler-based server implementation.
+// It must be called before Serve and requires TLS credentials.
+// This is only for tests and is subject to removal.
+//
+// handleRawConn reads this flag without holding s.mu, so setting it while
+// Serve is running would race with that read.
+func (s *Server) TestingUseHandlerImpl() {
+ s.opts.useHandlerImpl = true
+}
+
// SendHeader sends header metadata. It may be called at most once from a unary
// RPC handler. The ctx is the RPC handler's Context or one derived from it.
func SendHeader(ctx context.Context, md metadata.MD) error {
@@ -661,3 +751,30 @@
}
return stream.SetTrailer(md)
}
+
+// singleConnListener is a net.Listener whose Accept hands out one
+// pre-established conn exactly once and reports io.EOF afterwards.
+// Close only forgets the pending conn and never closes it; the conn is
+// owned by whoever is serving it.
+type singleConnListener struct {
+	mu   sync.Mutex
+	addr net.Addr
+	conn net.Conn // pending conn; nil once handed out or after Close
+}
+
+// Addr returns the address the listener was constructed with.
+func (ln *singleConnListener) Addr() net.Addr { return ln.addr }
+
+// Close discards the pending conn, if any, without closing it.
+// It always returns nil.
+func (ln *singleConnListener) Close() error {
+	ln.take()
+	return nil
+}
+
+// Accept returns the pending conn on the first call, and io.EOF on every
+// later call or after Close.
+func (ln *singleConnListener) Accept() (net.Conn, error) {
+	if c := ln.take(); c != nil {
+		return c, nil
+	}
+	return nil, io.EOF
+}
+
+// take removes and returns the pending conn (nil if none) under ln.mu.
+func (ln *singleConnListener) take() net.Conn {
+	ln.mu.Lock()
+	defer ln.mu.Unlock()
+	c := ln.conn
+	ln.conn = nil
+	return c
+}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 946df32..0bb7289 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -329,10 +329,11 @@
}
+// env describes one configuration (network, dialer, security protocol and
+// server transport implementation) under which the end2end tests are run.
type env struct {
- name string
- network string // The type of network such as tcp, unix, etc.
- dialer func(addr string, timeout time.Duration) (net.Conn, error)
- security string // The security protocol such as TLS, SSH, etc.
+ name string // e.g. "tcp-clear"; the -only_env flag takes one of these names
+ network string // The type of network such as tcp, unix, etc.
+ dialer func(addr string, timeout time.Duration) (net.Conn, error)
+ security string // The security protocol such as TLS, SSH, etc.
+ httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
}
func (e env) runnable() bool {
@@ -347,10 +348,11 @@
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer}
unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"}
- allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv}
+ handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
+ allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
)
+// onlyEnv, when non-empty, names the single env whose tests should run.
-var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', or 'unix-tls' to only run the tests for that environment. Empty means all.")
+var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.")
func listTestEnv() (envs []env) {
if *onlyEnv != "" {
@@ -393,6 +395,9 @@
sopts = append(sopts, grpc.Creds(creds))
}
s = grpc.NewServer(sopts...)
+ if e.httpHandler {
+ s.TestingUseHandlerImpl()
+ }
if hs != nil {
healthpb.RegisterHealthServer(s, hs)
}
@@ -720,7 +725,7 @@
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
}
if !reflect.DeepEqual(trailer, testTrailerMetadata) {
- t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
+ t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata)
}
}
@@ -1030,11 +1035,13 @@
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
+ delete(headerMD, "trailer") // ignore if present
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
// test the cached value.
headerMD, err = stream.Header()
+ delete(headerMD, "trailer") // ignore if present
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
diff --git a/transport/handler_server.go b/transport/handler_server.go
new file mode 100644
index 0000000..5d1bffe
--- /dev/null
+++ b/transport/handler_server.go
@@ -0,0 +1,329 @@
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+// This file is the implementation of a gRPC server using HTTP/2 which
+// uses the standard Go http2 Server implementation (via the
+// http.Handler interface), rather than speaking low-level HTTP/2
+// frames itself. It is the implementation of *grpc.Server.ServeHTTP.
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/net/context"
+ "golang.org/x/net/http2"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/peer"
+)
+
+// NewServerHandlerTransport returns a ServerTransport handling gRPC
+// from inside an http.Handler. It requires that the http Server
+// supports HTTP/2.
+//
+// The request is validated before a transport is created. It must be an
+// HTTP/2 POST whose Content-Type is "application/grpc", optionally followed
+// by "+..." or ";...". w must implement both http.Flusher and
+// http.CloseNotifier. A malformed grpc-timeout header yields a StreamError
+// with codes.Internal.
+//
+// Every request header that is not reserved becomes stream metadata under
+// its lower-cased name. user-agent keeps only the text before its last
+// space and is skipped entirely when it contains no space.
+func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
+	if r.ProtoMajor != 2 {
+		return nil, errors.New("gRPC requires HTTP/2")
+	}
+	if r.Method != "POST" {
+		return nil, errors.New("invalid gRPC request method")
+	}
+	// Require "application/grpc" optionally followed by "+<subtype>" or
+	// ";<params>". A substring match would also accept unrelated values
+	// that merely mention it, e.g. "text/plain; note=application/grpc".
+	const grpcContentType = "application/grpc"
+	ct := r.Header.Get("Content-Type")
+	if !strings.HasPrefix(ct, grpcContentType) ||
+		(len(ct) > len(grpcContentType) && ct[len(grpcContentType)] != '+' && ct[len(grpcContentType)] != ';') {
+		return nil, errors.New("invalid gRPC request content-type")
+	}
+	if _, ok := w.(http.Flusher); !ok {
+		return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
+	}
+	if _, ok := w.(http.CloseNotifier); !ok {
+		return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier")
+	}
+
+	st := &serverHandlerTransport{
+		rw:          w,
+		req:         r,
+		closedCh:    make(chan struct{}),
+		wroteStatus: make(chan struct{}),
+	}
+
+	if v := r.Header.Get("grpc-timeout"); v != "" {
+		to, err := timeoutDecode(v)
+		if err != nil {
+			return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
+		}
+		st.timeoutSet = true
+		st.timeout = to
+	}
+
+	var metakv []string
+	for k, vv := range r.Header {
+		k = strings.ToLower(k)
+		if isReservedHeader(k) {
+			continue
+		}
+		for _, v := range vv {
+			if k == "user-agent" {
+				// user-agent is special. Copying logic of http_util.go:
+				// keep only the part before the last space. With no space
+				// there is no application user agent string being set.
+				i := strings.LastIndex(v, " ")
+				if i == -1 {
+					continue
+				}
+				v = v[:i]
+			}
+			metakv = append(metakv, k, v)
+		}
+	}
+	st.headerMD = metadata.Pairs(metakv...)
+
+	return st, nil
+}
+
+// serverHandlerTransport is an implementation of ServerTransport
+// which replies to exactly one gRPC request (exactly one HTTP request),
+// using the net/http.Handler interface. This http.Handler is guaranteed
+// at this point to be speaking over HTTP/2, so it's able to speak valid
+// gRPC.
+type serverHandlerTransport struct {
+	rw  http.ResponseWriter
+	req *http.Request
+
+	// timeoutSet reports whether the request carried a grpc-timeout
+	// header; timeout holds its decoded value and is only meaningful
+	// when timeoutSet is true.
+	timeoutSet bool
+	timeout    time.Duration
+
+	// didCommonHeaders is set once writeCommonHeaders has run.
+	didCommonHeaders bool
+
+	// headerMD is the metadata built from the request's non-reserved headers.
+	headerMD metadata.MD
+
+	closeOnce sync.Once
+	closedCh  chan struct{} // closed on Close
+
+	wroteStatus chan struct{} // closed on WriteStatus
+}
+
+// Close marks the transport closed; HandleStreams watches for this and
+// cancels the stream context. Only the first call has any effect, and the
+// result is always nil.
+func (ht *serverHandlerTransport) Close() error {
+	ht.closeOnce.Do(ht.closeCloseChanOnce)
+	return nil
+}
+
+// closeCloseChanOnce closes closedCh; it is only ever run through closeOnce.
+func (ht *serverHandlerTransport) closeCloseChanOnce() {
+	close(ht.closedCh)
+}
+
+// RemoteAddr reports the peer address that net/http recorded in
+// Request.RemoteAddr.
+func (ht *serverHandlerTransport) RemoteAddr() net.Addr {
+	peerAddr := ht.req.RemoteAddr
+	return strAddr(peerAddr)
+}
+
+// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
+// the empty string if unknown.
+type strAddr string
+
+// Network reports "" for an unknown (empty) address and "tcp" otherwise.
+func (a strAddr) Network() string {
+	if a == "" {
+		return ""
+	}
+	// Per the documentation on net/http.Request.RemoteAddr, when it is set
+	// it holds the IP:port of the peer (hence, TCP):
+	// https://golang.org/pkg/net/http/#Request
+	//
+	// To support Unix sockets later, grpc could adopt its own convention
+	// for RemoteAddr's format or, probably better, attach the address to
+	// the context and use that from serverHandlerTransport.RemoteAddr.
+	return "tcp"
+}
+
+// String returns the address text unchanged.
+func (a strAddr) String() string {
+	return string(a)
+}
+
+// WriteStatus finishes the response for s. It sets and flushes the common
+// headers, then records Grpc-Status (plus Grpc-Message when statusDesc is
+// non-empty) and s's trailer metadata on the header map so they go out as
+// trailers. Finally it closes wroteStatus to let HandleStreams return.
+// Calling it a second time would close wroteStatus again and panic.
+func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
+	ht.writeCommonHeaders(s)
+
+	// Flush now, in case no header or body has been sent yet (e.g. in
+	// end2end's TestNoService). This forces the headers out on their own,
+	// separate from the trailers set below.
+	ht.rw.(http.Flusher).Flush()
+
+	hdr := ht.rw.Header()
+	// Grpc-Status and Grpc-Message were predeclared as trailers by
+	// writeCommonHeaders.
+	hdr.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
+	if statusDesc != "" {
+		hdr.Set("Grpc-Message", statusDesc)
+	}
+
+	// Trailer metadata was not predeclared, so it goes through the http2
+	// ResponseWriter mechanism for undeclared trailers.
+	for key, values := range s.Trailer() {
+		for _, val := range values {
+			hdr.Add(http2.TrailerPrefix+key, val)
+		}
+	}
+
+	close(ht.wroteStatus)
+	return nil
+}
+
+// writeCommonHeaders fills in the response headers shared by every write
+// path (Write, WriteHeader, WriteStatus). Only the first call has any effect.
+func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
+	if ht.didCommonHeaders {
+		return
+	}
+	ht.didCommonHeaders = true
+
+	hdr := ht.rw.Header()
+	hdr["Date"] = nil // suppress Date to make tests happy; TODO: restore
+	hdr.Set("Content-Type", "application/grpc")
+
+	// Predeclare the trailers WriteStatus sets after the body. Declaring
+	// trailers up front is a SHOULD in the HTTP RFC and is how known
+	// trailers are added under the net/http.ResponseWriter contract; see
+	// https://golang.org/pkg/net/http/#ResponseWriter
+	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
+	for _, trailer := range []string{"Grpc-Status", "Grpc-Message"} {
+		hdr.Add("Trailer", trailer)
+	}
+
+	if enc := s.sendCompress; enc != "" {
+		hdr.Set("Grpc-Encoding", enc)
+	}
+}
+
+// Write sends data as part of the response body for s, setting the common
+// headers first if that has not happened yet. Unless opts.Delay is set, the
+// data is flushed to the client right away. If the underlying ResponseWriter
+// reports a write failure, nothing is flushed and the failure is returned
+// as a ConnectionError instead of being dropped.
+func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
+	ht.writeCommonHeaders(s)
+	if _, err := ht.rw.Write(data); err != nil {
+		// NOTE(review): assumes ConnectionErrorf (transport package) is the
+		// intended error type for a broken response stream; confirm.
+		return ConnectionErrorf("transport: %v", err)
+	}
+	if !opts.Delay {
+		ht.rw.(http.Flusher).Flush()
+	}
+	return nil
+}
+
+// WriteHeader sends the response headers for s: the common gRPC headers
+// plus every key/value pair in md. It then commits a 200 status and flushes
+// immediately. It always returns nil.
+func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
+	ht.writeCommonHeaders(s)
+
+	hdr := ht.rw.Header()
+	for key, values := range md {
+		for _, val := range values {
+			hdr.Add(key, val)
+		}
+	}
+
+	ht.rw.WriteHeader(http.StatusOK)
+	ht.rw.(http.Flusher).Flush()
+	return nil
+}
+
+// HandleStreams serves the single stream carried by this HTTP request.
+//
+// It builds the Stream: a context with the optional grpc-timeout, the
+// request metadata and peer information. It starts a goroutine that copies
+// the request body into the stream's receive buffer, hands the stream to
+// runStream, and blocks until the stream's context is done or WriteStatus
+// has been called.
+//
+// Before returning it closes the request body and waits for the
+// body-reading goroutine. The goroutine that cancels the stream context
+// also exits once the request is over (canceling the context), so neither
+// it nor the context's resources outlive the request.
+func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
+	// With this transport type there will be exactly 1 stream: this HTTP request.
+
+	var ctx context.Context
+	var cancel context.CancelFunc
+	if ht.timeoutSet {
+		ctx, cancel = context.WithTimeout(context.Background(), ht.timeout)
+	} else {
+		ctx, cancel = context.WithCancel(context.Background())
+	}
+
+	// requestOver is closed when either the request's context is done
+	// or the status has been written via WriteStatus.
+	requestOver := make(chan struct{})
+
+	// clientGone receives a single value if peer is gone, either
+	// because the underlying connection is dead or because the
+	// peer sends an http2 RST_STREAM.
+	clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
+	go func() {
+		// Without the requestOver case this goroutine would block forever
+		// (and keep the context alive) whenever the request finishes
+		// normally without Close being called or the client going away.
+		select {
+		case <-ht.closedCh:
+		case <-clientGone:
+		case <-requestOver:
+		}
+		cancel()
+	}()
+
+	req := ht.req
+
+	s := &Stream{
+		id:            0,            // irrelevant
+		windowHandler: func(int) {}, // nothing
+		cancel:        cancel,
+		buf:           newRecvBuffer(),
+		st:            ht,
+		method:        req.URL.Path,
+		recvCompress:  req.Header.Get("grpc-encoding"),
+	}
+	pr := &peer.Peer{
+		Addr: ht.RemoteAddr(),
+	}
+	if req.TLS != nil {
+		pr.AuthInfo = credentials.TLSInfo{*req.TLS}
+	}
+	ctx = metadata.NewContext(ctx, ht.headerMD)
+	ctx = peer.NewContext(ctx, pr)
+	s.ctx = newContextWithStream(ctx, s)
+	s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
+
+	// readerDone is closed when the Body.Read-ing goroutine exits.
+	readerDone := make(chan struct{})
+	go func() {
+		defer close(readerDone)
+		for {
+			buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
+			n, err := req.Body.Read(buf)
+			select {
+			case <-requestOver:
+				return
+			default:
+			}
+			if n > 0 {
+				s.buf.put(&recvMsg{data: buf[:n]})
+			}
+			if err != nil {
+				s.buf.put(&recvMsg{err: err})
+				return
+			}
+		}
+	}()
+
+	// runStream is provided by the *grpc.Server.serveStreams.
+	// It starts a goroutine handling s and exits immediately.
+	runStream(s)
+
+	// Wait for the stream to be done. It is considered done when
+	// either its context is done, or we've written its status.
+	select {
+	case <-ctx.Done():
+	case <-ht.wroteStatus:
+	}
+	close(requestOver)
+
+	// Closing the body unblocks a pending Read; then wait for the reading
+	// goroutine to finish.
+	req.Body.Close()
+	<-readerDone
+}
diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go
new file mode 100644
index 0000000..faa95af
--- /dev/null
+++ b/transport/handler_server_test.go
@@ -0,0 +1,386 @@
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "reflect"
+ "testing"
+ "time"
+
+ "golang.org/x/net/context"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+)
+
+// TestHandlerTransport_NewServerHandlerTransport feeds a table of requests to
+// NewServerHandlerTransport. Cases with a non-empty wantErr must be rejected
+// with exactly that error text. The other cases must succeed and pass their
+// optional check func. modrw, when set, wraps the ResponseWriter, e.g. to
+// hide the Flusher or CloseNotifier interface.
+func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
+	type testCase struct {
+		name    string
+		req     *http.Request
+		wantErr string
+		modrw   func(http.ResponseWriter) http.ResponseWriter
+		check   func(*serverHandlerTransport, *testCase) error
+	}
+	tests := []testCase{
+		{
+			name: "http/1.1",
+			req: &http.Request{
+				ProtoMajor: 1,
+				ProtoMinor: 1,
+			},
+			wantErr: "gRPC requires HTTP/2",
+		},
+		{
+			name: "bad method",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "GET",
+				Header:     http.Header{},
+				RequestURI: "/",
+			},
+			wantErr: "invalid gRPC request method",
+		},
+		{
+			name: "bad content type",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": {"application/foo"},
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			wantErr: "invalid gRPC request content-type",
+		},
+		{
+			name: "not flusher",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": {"application/grpc"},
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			modrw: func(w http.ResponseWriter) http.ResponseWriter {
+				// Return w without its Flush method
+				type onlyCloseNotifier interface {
+					http.ResponseWriter
+					http.CloseNotifier
+				}
+				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
+			},
+			wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
+		},
+		{
+			name: "not closenotifier",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": {"application/grpc"},
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			modrw: func(w http.ResponseWriter) http.ResponseWriter {
+				// Return w without its CloseNotify method
+				type onlyFlusher interface {
+					http.ResponseWriter
+					http.Flusher
+				}
+				return struct{ onlyFlusher }{w.(onlyFlusher)}
+			},
+			wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier",
+		},
+		{
+			name: "valid",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": {"application/grpc"},
+				},
+				URL: &url.URL{
+					Path: "/service/foo.bar",
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			check: func(t *serverHandlerTransport, tt *testCase) error {
+				if t.req != tt.req {
+					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
+				}
+				if t.rw == nil {
+					return errors.New("t.rw = nil; want non-nil")
+				}
+				return nil
+			},
+		},
+		{
+			name: "with timeout",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": []string{"application/grpc"},
+					"Grpc-Timeout": {"200m"},
+				},
+				URL: &url.URL{
+					Path: "/service/foo.bar",
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			check: func(t *serverHandlerTransport, tt *testCase) error {
+				if !t.timeoutSet {
+					return errors.New("timeout not set")
+				}
+				if want := 200 * time.Millisecond; t.timeout != want {
+					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
+				}
+				return nil
+			},
+		},
+		{
+			name: "with bad timeout",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": []string{"application/grpc"},
+					"Grpc-Timeout": {"tomorrow"},
+				},
+				URL: &url.URL{
+					Path: "/service/foo.bar",
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			wantErr: `stream error: code = 13 desc = "malformed time-out: transport: timeout unit is not recognized: \"tomorrow\""`,
+		},
+		{
+			name: "with metadata",
+			req: &http.Request{
+				ProtoMajor: 2,
+				Method:     "POST",
+				Header: http.Header{
+					"Content-Type": []string{"application/grpc"},
+					"meta-foo":     {"foo-val"},
+					"meta-bar":     {"bar-val1", "bar-val2"},
+					"user-agent":   {"x/y a/b"},
+				},
+				URL: &url.URL{
+					Path: "/service/foo.bar",
+				},
+				RequestURI: "/service/foo.bar",
+			},
+			check: func(ht *serverHandlerTransport, tt *testCase) error {
+				want := metadata.MD{
+					"meta-bar":   {"bar-val1", "bar-val2"},
+					"user-agent": {"x/y"},
+					"meta-foo":   {"foo-val"},
+				}
+				if !reflect.DeepEqual(ht.headerMD, want) {
+					return fmt.Errorf("metadata = %#v; want %#v", ht.headerMD, want)
+				}
+				return nil
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		rw := newTestHandlerResponseWriter()
+		if tt.modrw != nil {
+			rw = tt.modrw(rw)
+		}
+		got, gotErr := NewServerHandlerTransport(rw, tt.req)
+		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
+			t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
+			continue
+		}
+		if gotErr != nil {
+			continue
+		}
+		if tt.check != nil {
+			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
+				t.Errorf("%s: %v", tt.name, err)
+			}
+		}
+	}
+}
+
+// testHandlerResponseWriter is a ResponseWriter for tests. It is an
+// httptest.ResponseRecorder extended with a no-op Flush and a CloseNotify
+// channel that the test itself controls.
+type testHandlerResponseWriter struct {
+	*httptest.ResponseRecorder
+	closeNotify chan bool
+}
+
+// CloseNotify returns the test-controlled notification channel.
+func (w testHandlerResponseWriter) CloseNotify() <-chan bool {
+	return w.closeNotify
+}
+
+// Flush does nothing; the recorder already holds everything written.
+func (w testHandlerResponseWriter) Flush() {}
+
+// newTestHandlerResponseWriter returns a fresh testHandlerResponseWriter
+// whose CloseNotify channel can hold one value.
+func newTestHandlerResponseWriter() http.ResponseWriter {
+	var w testHandlerResponseWriter
+	w.ResponseRecorder = httptest.NewRecorder()
+	w.closeNotify = make(chan bool, 1)
+	return w
+}
+
+// handleStreamTest bundles a serverHandlerTransport built around a synthetic
+// HTTP/2 gRPC request with the pieces a test needs to drive it.
+type handleStreamTest struct {
+	t     *testing.T
+	bodyw *io.PipeWriter // write end of the request body
+	req   *http.Request  // request the transport was built from
+	rw    testHandlerResponseWriter
+	ht    *serverHandlerTransport
+}
+
+// newHandleStreamTest builds a POST to /service/foo.bar whose body is fed
+// through a pipe, and wraps it in a serverHandlerTransport. It fails t if the
+// transport cannot be created. All fields, including req, are populated.
+func newHandleStreamTest(t *testing.T) *handleStreamTest {
+	bodyr, bodyw := io.Pipe()
+	req := &http.Request{
+		ProtoMajor: 2,
+		Method:     "POST",
+		Header: http.Header{
+			"Content-Type": {"application/grpc"},
+		},
+		URL: &url.URL{
+			Path: "/service/foo.bar",
+		},
+		RequestURI: "/service/foo.bar",
+		Body:       bodyr,
+	}
+	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
+	ht, err := NewServerHandlerTransport(rw, req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return &handleStreamTest{
+		t:     t,
+		bodyw: bodyw,
+		req:   req,
+		ht:    ht.(*serverHandlerTransport),
+		rw:    rw,
+	}
+}
+
+// TestHandlerTransport_HandleStreams checks that a body-less stream finishing
+// with codes.OK reports the request path as its method and leaves the
+// expected header and trailer map behind.
+func TestHandlerTransport_HandleStreams(t *testing.T) {
+	const wantMethod = "/service/foo.bar"
+	st := newHandleStreamTest(t)
+	st.ht.HandleStreams(func(s *Stream) {
+		if s.method != wantMethod {
+			t.Errorf("stream method = %q; want %q", s.method, wantMethod)
+		}
+		st.bodyw.Close() // no body
+		st.ht.WriteStatus(s, codes.OK, "")
+	})
+	want := http.Header{
+		"Date":         nil,
+		"Content-Type": {"application/grpc"},
+		"Trailer":      {"Grpc-Status", "Grpc-Message"},
+		"Grpc-Status":  {"0"},
+	}
+	if got := st.rw.HeaderMap; !reflect.DeepEqual(got, want) {
+		t.Errorf("Header+Trailer Map: %#v; want %#v", got, want)
+	}
+}
+
+// TestHandlerTransport_HandleStreams_Unimplemented checks the headers and
+// trailers produced when a handler that never reads the request body
+// reports codes.Unimplemented.
+func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
+	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
+}
+
+// TestHandlerTransport_HandleStreams_InvalidArgument is the same check for
+// codes.InvalidArgument.
+func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
+	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
+}
+
+// handleStreamCloseBodyTest runs a stream whose handler immediately writes
+// statusCode and msg, without reading or closing the request body. It then
+// checks the recorded header and trailer map. HandleStreams must still
+// return, because it closes the request body itself before waiting for its
+// body-reading goroutine.
+func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
+	st := newHandleStreamTest(t)
+	st.ht.HandleStreams(func(s *Stream) {
+		st.ht.WriteStatus(s, statusCode, msg)
+	})
+	wantHeader := http.Header{
+		"Date":         nil,
+		"Content-Type": {"application/grpc"},
+		"Trailer":      {"Grpc-Status", "Grpc-Message"},
+		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
+		"Grpc-Message": {msg},
+	}
+	if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
+		t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
+	}
+}
+
+// TestHandlerTransport_HandleStreams_Timeout checks that a 200ms grpc-timeout
+// header bounds the stream context. The handler waits for the context to
+// end, expects context.DeadlineExceeded, and then writes a DeadlineExceeded
+// status whose headers and trailers are verified afterwards.
+func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
+	bodyr, bodyw := io.Pipe()
+	req := &http.Request{
+		ProtoMajor: 2,
+		Method:     "POST",
+		Header: http.Header{
+			"Content-Type": {"application/grpc"},
+			"Grpc-Timeout": {"200m"},
+		},
+		URL: &url.URL{
+			Path: "/service/foo.bar",
+		},
+		RequestURI: "/service/foo.bar",
+		Body:       bodyr,
+	}
+	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
+	ht, err := NewServerHandlerTransport(rw, req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ht.HandleStreams(func(s *Stream) {
+		defer bodyw.Close()
+		select {
+		case <-s.ctx.Done():
+			if ctxErr := s.ctx.Err(); ctxErr != context.DeadlineExceeded {
+				t.Errorf("ctx.Err = %v; want %v", ctxErr, context.DeadlineExceeded)
+				return
+			}
+			ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
+		case <-time.After(5 * time.Second):
+			t.Errorf("timeout waiting for ctx.Done")
+		}
+	})
+	want := http.Header{
+		"Date":         nil,
+		"Content-Type": {"application/grpc"},
+		"Trailer":      {"Grpc-Status", "Grpc-Message"},
+		"Grpc-Status":  {"4"},
+		"Grpc-Message": {"too slow"},
+	}
+	if got := rw.HeaderMap; !reflect.DeepEqual(got, want) {
+		t.Errorf("Header+Trailer Map: %#v; want %#v", got, want)
+	}
+}