| // Copyright 2018 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package jsonrpc2_test |
| |
| import ( |
| "context" |
| "encoding/json" |
| "flag" |
| "fmt" |
| "io" |
| "log" |
| "path" |
| "reflect" |
| "testing" |
| "time" |
| |
| "golang.org/x/tools/internal/jsonrpc2" |
| ) |
| |
// logRPC is the -logrpc command-line flag. It defaults to false; run copies
// its value into each handler's log field to switch on logging of requests
// and responses.
var logRPC = flag.Bool("logrpc", false, "Enable jsonrpc2 communication logging")
| |
// callTest describes a single round trip made by the call tests.
type callTest struct {
	// method is the name of the remote method to invoke.
	method string
	// params is the argument value passed to Call; nil means no parameters.
	params interface{}
	// expect is the value the decoded result is compared against.
	expect interface{}
}
| |
| var callTests = []callTest{ |
| {"no_args", nil, true}, |
| {"one_string", "fish", "got:fish"}, |
| {"one_number", 10, "got:10"}, |
| {"join", []string{"a", "b", "c"}, "a/b/c"}, |
| //TODO: expand the test cases |
| } |
| |
| func (test *callTest) newResults() interface{} { |
| switch e := test.expect.(type) { |
| case []interface{}: |
| var r []interface{} |
| for _, v := range e { |
| r = append(r, reflect.New(reflect.TypeOf(v)).Interface()) |
| } |
| return r |
| case nil: |
| return nil |
| default: |
| return reflect.New(reflect.TypeOf(test.expect)).Interface() |
| } |
| } |
| |
| func (test *callTest) verifyResults(t *testing.T, results interface{}) { |
| if results == nil { |
| return |
| } |
| val := reflect.Indirect(reflect.ValueOf(results)).Interface() |
| if !reflect.DeepEqual(val, test.expect) { |
| t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect) |
| } |
| } |
| |
| func TestPlainCall(t *testing.T) { |
| ctx := context.Background() |
| a, b := prepare(ctx, t, false) |
| for _, test := range callTests { |
| results := test.newResults() |
| if err := a.Call(ctx, test.method, test.params, results); err != nil { |
| t.Fatalf("%v:Call failed: %v", test.method, err) |
| } |
| test.verifyResults(t, results) |
| if err := b.Call(ctx, test.method, test.params, results); err != nil { |
| t.Fatalf("%v:Call failed: %v", test.method, err) |
| } |
| test.verifyResults(t, results) |
| } |
| } |
| |
| func TestHeaderCall(t *testing.T) { |
| ctx := context.Background() |
| a, b := prepare(ctx, t, true) |
| for _, test := range callTests { |
| results := test.newResults() |
| if err := a.Call(ctx, test.method, test.params, results); err != nil { |
| t.Fatalf("%v:Call failed: %v", test.method, err) |
| } |
| test.verifyResults(t, results) |
| if err := b.Call(ctx, test.method, test.params, results); err != nil { |
| t.Fatalf("%v:Call failed: %v", test.method, err) |
| } |
| test.verifyResults(t, results) |
| } |
| } |
| |
| func prepare(ctx context.Context, t *testing.T, withHeaders bool) (*jsonrpc2.Conn, *jsonrpc2.Conn) { |
| aR, bW := io.Pipe() |
| bR, aW := io.Pipe() |
| a := run(ctx, t, withHeaders, aR, aW) |
| b := run(ctx, t, withHeaders, bR, bW) |
| return a, b |
| } |
| |
| func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w io.WriteCloser) *jsonrpc2.Conn { |
| var stream jsonrpc2.Stream |
| if withHeaders { |
| stream = jsonrpc2.NewHeaderStream(r, w) |
| } else { |
| stream = jsonrpc2.NewStream(r, w) |
| } |
| conn := jsonrpc2.NewConn(stream) |
| conn.AddHandler(&handle{log: *logRPC}) |
| go func() { |
| defer func() { |
| r.Close() |
| w.Close() |
| }() |
| if err := conn.Run(ctx); err != nil { |
| t.Fatalf("Stream failed: %v", err) |
| } |
| }() |
| return conn |
| } |
| |
// handle is the handler installed on every test connection. It answers the
// methods listed in callTests and can log traffic.
type handle struct {
	// log enables logging of requests and responses in Request and Response.
	log bool
}
| |
| func (h *handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { |
| switch r.Method { |
| case "no_args": |
| if r.Params != nil { |
| r.Reply(ctx, nil, jsonrpc2.NewErrorf(jsonrpc2.CodeInvalidParams, "Expected no params")) |
| return true |
| } |
| r.Reply(ctx, true, nil) |
| case "one_string": |
| var v string |
| if err := json.Unmarshal(*r.Params, &v); err != nil { |
| r.Reply(ctx, nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())) |
| return true |
| } |
| r.Reply(ctx, "got:"+v, nil) |
| case "one_number": |
| var v int |
| if err := json.Unmarshal(*r.Params, &v); err != nil { |
| r.Reply(ctx, nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())) |
| return true |
| } |
| r.Reply(ctx, fmt.Sprintf("got:%d", v), nil) |
| case "join": |
| var v []string |
| if err := json.Unmarshal(*r.Params, &v); err != nil { |
| r.Reply(ctx, nil, jsonrpc2.NewErrorf(jsonrpc2.CodeParseError, "%v", err.Error())) |
| return true |
| } |
| r.Reply(ctx, path.Join(v...), nil) |
| default: |
| r.Reply(ctx, nil, jsonrpc2.NewErrorf(jsonrpc2.CodeMethodNotFound, "method %q not found", r.Method)) |
| } |
| return true |
| } |
| |
| func (h *handle) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool { |
| return false |
| } |
| |
| func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { |
| if h.log { |
| if r.ID != nil { |
| log.Printf("%v call [%v] %s %v", direction, r.ID, r.Method, r.Params) |
| } else { |
| log.Printf("%v notification %s %v", direction, r.Method, r.Params) |
| } |
| ctx = context.WithValue(ctx, "method", r.Method) |
| ctx = context.WithValue(ctx, "start", time.Now()) |
| } |
| return ctx |
| } |
| |
| func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { |
| if h.log { |
| method := ctx.Value("method") |
| elapsed := time.Since(ctx.Value("start").(time.Time)) |
| log.Printf("%v response in %v [%v] %s %v", direction, elapsed, r.ID, method, r.Result) |
| } |
| return ctx |
| } |
| |
| func (h *handle) Done(ctx context.Context, err error) { |
| } |
| |
| func (h *handle) Read(ctx context.Context, bytes int64) context.Context { |
| return ctx |
| } |
| |
| func (h *handle) Wrote(ctx context.Context, bytes int64) context.Context { |
| return ctx |
| } |
| |
| func (h *handle) Error(ctx context.Context, err error) { |
| log.Printf("%v", err) |
| } |