| /* |
| * Copyright 2016, Google Inc. |
| * All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are |
| * met: |
| * |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above |
| * copyright notice, this list of conditions and the following disclaimer |
| * in the documentation and/or other materials provided with the |
| * distribution. |
| * * Neither the name of Google Inc. nor the names of its |
| * contributors may be used to endorse or promote products derived from |
| * this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| */ |
| |
| package transport |
| |
| import ( |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "reflect" |
| "testing" |
| "time" |
| |
| "golang.org/x/net/context" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/metadata" |
| ) |
| |
| func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { |
| type testCase struct { |
| name string |
| req *http.Request |
| wantErr string |
| modrw func(http.ResponseWriter) http.ResponseWriter |
| check func(*serverHandlerTransport, *testCase) error |
| } |
| tests := []testCase{ |
| { |
| name: "http/1.1", |
| req: &http.Request{ |
| ProtoMajor: 1, |
| ProtoMinor: 1, |
| }, |
| wantErr: "gRPC requires HTTP/2", |
| }, |
| { |
| name: "bad method", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "GET", |
| Header: http.Header{}, |
| RequestURI: "/", |
| }, |
| wantErr: "invalid gRPC request method", |
| }, |
| { |
| name: "bad content type", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/foo"}, |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| wantErr: "invalid gRPC request content-type", |
| }, |
| { |
| name: "not flusher", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/grpc"}, |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| modrw: func(w http.ResponseWriter) http.ResponseWriter { |
| // Return w without its Flush method |
| type onlyCloseNotifier interface { |
| http.ResponseWriter |
| http.CloseNotifier |
| } |
| return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)} |
| }, |
| wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", |
| }, |
| { |
| name: "not closenotifier", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/grpc"}, |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| modrw: func(w http.ResponseWriter) http.ResponseWriter { |
| // Return w without its CloseNotify method |
| type onlyFlusher interface { |
| http.ResponseWriter |
| http.Flusher |
| } |
| return struct{ onlyFlusher }{w.(onlyFlusher)} |
| }, |
| wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier", |
| }, |
| { |
| name: "valid", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/grpc"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| check: func(t *serverHandlerTransport, tt *testCase) error { |
| if t.req != tt.req { |
| return fmt.Errorf("t.req = %p; want %p", t.req, tt.req) |
| } |
| if t.rw == nil { |
| return errors.New("t.rw = nil; want non-nil") |
| } |
| return nil |
| }, |
| }, |
| { |
| name: "with timeout", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": []string{"application/grpc"}, |
| "Grpc-Timeout": {"200m"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| check: func(t *serverHandlerTransport, tt *testCase) error { |
| if !t.timeoutSet { |
| return errors.New("timeout not set") |
| } |
| if want := 200 * time.Millisecond; t.timeout != want { |
| return fmt.Errorf("timeout = %v; want %v", t.timeout, want) |
| } |
| return nil |
| }, |
| }, |
| { |
| name: "with bad timeout", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": []string{"application/grpc"}, |
| "Grpc-Timeout": {"tomorrow"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| wantErr: `stream error: code = 13 desc = "malformed time-out: transport: timeout unit is not recognized: \"tomorrow\""`, |
| }, |
| { |
| name: "with metadata", |
| req: &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": []string{"application/grpc"}, |
| "meta-foo": {"foo-val"}, |
| "meta-bar": {"bar-val1", "bar-val2"}, |
| "user-agent": {"x/y a/b"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| }, |
| check: func(ht *serverHandlerTransport, tt *testCase) error { |
| want := metadata.MD{ |
| "meta-bar": {"bar-val1", "bar-val2"}, |
| "user-agent": {"x/y"}, |
| "meta-foo": {"foo-val"}, |
| } |
| if !reflect.DeepEqual(ht.headerMD, want) { |
| return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want) |
| } |
| return nil |
| }, |
| }, |
| } |
| |
| for _, tt := range tests { |
| rw := newTestHandlerResponseWriter() |
| if tt.modrw != nil { |
| rw = tt.modrw(rw) |
| } |
| got, gotErr := NewServerHandlerTransport(rw, tt.req) |
| if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) { |
| t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr) |
| continue |
| } |
| if gotErr != nil { |
| continue |
| } |
| if tt.check != nil { |
| if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil { |
| t.Errorf("%s: %v", tt.name, err) |
| } |
| } |
| } |
| } |
| |
// testHandlerResponseWriter is the http.ResponseWriter used by these tests.
// Writes are captured by the embedded *httptest.ResponseRecorder; the type also
// supplies the Flush and CloseNotify methods that NewServerHandlerTransport
// requires of its ResponseWriter.
type testHandlerResponseWriter struct {
	*httptest.ResponseRecorder

	// closeNotify is the channel returned by CloseNotify. It is created with
	// a buffer of one by newTestHandlerResponseWriter.
	closeNotify chan bool
}

// CloseNotify returns the writer's close-notification channel.
func (rw testHandlerResponseWriter) CloseNotify() <-chan bool {
	return rw.closeNotify
}

// Flush does nothing. Because it is declared on the outer type, it shadows
// the embedded recorder's Flush, so the recorder's Flushed field is never set
// through this writer.
func (rw testHandlerResponseWriter) Flush() {}

// newTestHandlerResponseWriter builds a fresh testHandlerResponseWriter and
// returns it as a plain http.ResponseWriter, so callers can wrap it or assert
// it back to the concrete type.
func newTestHandlerResponseWriter() http.ResponseWriter {
	recorder := httptest.NewRecorder()
	notify := make(chan bool, 1)
	return testHandlerResponseWriter{ResponseRecorder: recorder, closeNotify: notify}
}
| |
| type handleStreamTest struct { |
| t *testing.T |
| bodyw *io.PipeWriter |
| req *http.Request |
| rw testHandlerResponseWriter |
| ht *serverHandlerTransport |
| } |
| |
| func newHandleStreamTest(t *testing.T) *handleStreamTest { |
| bodyr, bodyw := io.Pipe() |
| req := &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/grpc"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| Body: bodyr, |
| } |
| rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) |
| ht, err := NewServerHandlerTransport(rw, req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return &handleStreamTest{ |
| t: t, |
| bodyw: bodyw, |
| ht: ht.(*serverHandlerTransport), |
| rw: rw, |
| } |
| } |
| |
| func TestHandlerTransport_HandleStreams(t *testing.T) { |
| st := newHandleStreamTest(t) |
| st.ht.HandleStreams(func(s *Stream) { |
| if want := "/service/foo.bar"; s.method != want { |
| t.Errorf("stream method = %q; want %q", s.method, want) |
| } |
| st.bodyw.Close() // no body |
| st.ht.WriteStatus(s, codes.OK, "") |
| }) |
| wantHeader := http.Header{ |
| "Date": nil, |
| "Content-Type": {"application/grpc"}, |
| "Trailer": {"Grpc-Status", "Grpc-Message"}, |
| "Grpc-Status": {"0"}, |
| } |
| if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { |
| t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader) |
| } |
| } |
| |
| // Tests that codes.Unimplemented will close the body, per comment in handler_server.go. |
| func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) { |
| handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented") |
| } |
| |
| // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go. |
| func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { |
| handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg") |
| } |
| |
| func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { |
| st := newHandleStreamTest(t) |
| st.ht.HandleStreams(func(s *Stream) { |
| st.ht.WriteStatus(s, statusCode, msg) |
| }) |
| wantHeader := http.Header{ |
| "Date": nil, |
| "Content-Type": {"application/grpc"}, |
| "Trailer": {"Grpc-Status", "Grpc-Message"}, |
| "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, |
| "Grpc-Message": {msg}, |
| } |
| if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { |
| t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) |
| } |
| } |
| |
| func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { |
| bodyr, bodyw := io.Pipe() |
| req := &http.Request{ |
| ProtoMajor: 2, |
| Method: "POST", |
| Header: http.Header{ |
| "Content-Type": {"application/grpc"}, |
| "Grpc-Timeout": {"200m"}, |
| }, |
| URL: &url.URL{ |
| Path: "/service/foo.bar", |
| }, |
| RequestURI: "/service/foo.bar", |
| Body: bodyr, |
| } |
| rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) |
| ht, err := NewServerHandlerTransport(rw, req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| ht.HandleStreams(func(s *Stream) { |
| defer bodyw.Close() |
| select { |
| case <-s.ctx.Done(): |
| case <-time.After(5 * time.Second): |
| t.Errorf("timeout waiting for ctx.Done") |
| return |
| } |
| err := s.ctx.Err() |
| if err != context.DeadlineExceeded { |
| t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) |
| return |
| } |
| ht.WriteStatus(s, codes.DeadlineExceeded, "too slow") |
| }) |
| wantHeader := http.Header{ |
| "Date": nil, |
| "Content-Type": {"application/grpc"}, |
| "Trailer": {"Grpc-Status", "Grpc-Message"}, |
| "Grpc-Status": {"4"}, |
| "Grpc-Message": {"too slow"}, |
| } |
| if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { |
| t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader) |
| } |
| } |