blob: 40df6ed4a22200f1dca405c2e77d498aec8a90dc [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jsonrpc2_test
import (
"context"
"encoding/json"
"fmt"
"path"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/exp/event/eventtest"
"golang.org/x/exp/jsonrpc2"
"golang.org/x/exp/jsonrpc2/internal/stack/stacktest"
errors "golang.org/x/xerrors"
)
// callTests is the scenario table driven by testConnection. Every top-level
// entry runs as its own subtest over a newly dialed client connection;
// sequence entries run their steps in order against that one connection.
var callTests = []invoker{
	call{method: "no_args", params: nil, expect: true},
	call{method: "one_string", params: "fish", expect: "got:fish"},
	call{method: "one_number", params: 10, expect: "got:10"},
	call{method: "join", params: []string{"a", "b", "c"}, expect: "a/b/c"},
	sequence{name: "notify", tests: []invoker{
		notify{method: "set", params: 3},
		notify{method: "add", params: 5},
		call{method: "get", params: nil, expect: 8},
	}},
	sequence{name: "preempt", tests: []invoker{
		async{name: "a", method: "wait", params: "a"},
		notify{method: "unblock", params: "a"},
		collect{name: "a", expect: true, fails: false},
	}},
	sequence{name: "basic cancel", tests: []invoker{
		async{name: "b", method: "wait", params: "b"},
		cancel{name: "b"},
		collect{name: "b", expect: nil, fails: true},
	}},
	sequence{name: "queue", tests: []invoker{
		async{name: "a", method: "wait", params: "a"},
		notify{method: "set", params: 1},
		notify{method: "add", params: 2},
		notify{method: "add", params: 3},
		notify{method: "add", params: 4},
		// The handler is still blocked in "wait", so none of the adds have run.
		call{method: "peek", params: nil, expect: 0},
		notify{method: "unblock", params: "a"},
		collect{name: "a", expect: true, fails: false},
		// Once unblocked, every queued add has been applied.
		call{method: "get", params: nil, expect: 10},
	}},
	sequence{name: "fork", tests: []invoker{
		async{name: "a", method: "fork", params: "a"},
		notify{method: "set", params: 1},
		notify{method: "add", params: 2},
		notify{method: "add", params: 3},
		notify{method: "add", params: 4},
		// "fork" replies asynchronously, so it does not hold up the adds.
		call{method: "get", params: nil, expect: 10},
		notify{method: "unblock", params: "a"},
		collect{name: "a", expect: true, fails: false},
	}},
	callErr{method: "error", params: func() {}, expectErr: "marshaling call parameters: json: unsupported type"},
}
type (
	// binder is passed to both jsonrpc2.Serve and jsonrpc2.Dial. It carries
	// the framer under test and, for client connections only, a runTest
	// callback that drives a scenario against the bound handler (the server
	// is bound with a nil runTest).
	binder struct {
		framer  jsonrpc2.Framer
		runTest func(*handler)
	}

	// handler holds the per-connection state shared by Preempt and Handle.
	// waitersBox is a 1-buffered channel carrying the waiters map: receiving
	// the map takes exclusive use of it and sending it back releases it.
	// calls records outstanding async calls by name so later collect and
	// cancel steps can find them.
	handler struct {
		conn        *jsonrpc2.Connection
		accumulator int
		waitersBox  chan map[string]chan struct{}
		calls       map[string]*jsonrpc2.AsyncCall
	}

	// invoker is one step, or an ordered group of steps, of a connection
	// scenario. Top-level entries of callTests use Name as the subtest name;
	// Invoke performs the step using the client-side handler h.
	invoker interface {
		Name() string
		Invoke(t *testing.T, ctx context.Context, h *handler)
	}
)
type (
	// notify sends method as a notification carrying params.
	notify struct {
		method string
		params interface{}
	}

	// call invokes method with params, waits for the reply, and compares it
	// with expect.
	call struct {
		method string
		params interface{}
		expect interface{}
	}

	// callErr invokes method with params and requires the call to fail with
	// an error whose text contains expectErr.
	callErr struct {
		method    string
		params    interface{}
		expectErr string
	}

	// async starts method with params without waiting and records the
	// pending call under name.
	async struct {
		name, method string
		params       interface{}
	}

	// collect waits for the async call recorded under name and checks its
	// result against expect; fails states whether an error is required.
	collect struct {
		name   string
		expect interface{}
		fails  bool
	}

	// cancel asks the peer to cancel the async call recorded under name.
	cancel struct {
		name string
	}

	// sequence runs tests in order against one connection.
	sequence struct {
		name  string
		tests []invoker
	}

	// echo is a call that is relayed through the peer's "echo" method.
	echo call

	// cancelParams is the payload of the "cancel" notification.
	cancelParams struct{ ID int64 }
)
func TestConnectionRaw(t *testing.T) {
testConnection(t, jsonrpc2.RawFramer())
}
func TestConnectionHeader(t *testing.T) {
testConnection(t, jsonrpc2.HeaderFramer())
}
// testConnection runs each entry of callTests as a subtest using framer.
// A single server is started on a jsonrpc2.NetPipe listener. Each subtest
// dials a fresh client whose handler runs the scenario steps and closes the
// client connection when done; the subtest then waits on the client.
func testConnection(t *testing.T, framer jsonrpc2.Framer) {
	stacktest.NoLeak(t)
	ctx := eventtest.NewContext(context.Background(), t)
	listener, err := jsonrpc2.NetPipe(ctx)
	if err != nil {
		t.Fatal(err)
	}
	server, err := jsonrpc2.Serve(ctx, listener, binder{framer, nil})
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		listener.Close()
		server.Wait()
	}()
	for _, test := range callTests {
		// The runTest callback below runs on a goroutine started by Bind, so
		// give it its own copy of the loop variable (required before Go 1.22).
		test := test
		t.Run(test.Name(), func(t *testing.T) {
			client, err := jsonrpc2.Dial(ctx,
				listener.Dialer(), binder{framer, func(h *handler) {
					defer h.conn.Close()
					ctx := eventtest.NewContext(ctx, t)
					test.Invoke(t, ctx, h)
					// callTests stores call values, not pointers: asserting on
					// *call never matched, which silently skipped the echo pass.
					if c, ok := test.(call); ok {
						// Also run every simple call test in echo mode, where
						// the server makes the same call back to this client.
						echo(c).Invoke(t, ctx, h)
					}
				}})
			if err != nil {
				t.Fatal(err)
			}
			client.Wait()
		})
	}
}
// Name returns the notification method, which identifies the step.
func (test notify) Name() string {
	return test.method
}

// Invoke sends test.method as a notification carrying test.params and fails
// the test if sending reports an error.
func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
	err := h.conn.Notify(ctx, test.method, test.params)
	if err == nil {
		return
	}
	t.Fatalf("%v:Notify failed: %v", test.method, err)
}
// Name returns the called method, which identifies the step.
func (test call) Name() string {
	return test.method
}

// Invoke calls test.method with test.params, waits for the reply, and checks
// the decoded result against test.expect.
func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
	got := newResults(test.expect)
	pending := h.conn.Call(ctx, test.method, test.params)
	if err := pending.Await(ctx, got); err != nil {
		t.Fatalf("%v:Call failed: %v", test.method, err)
	}
	verifyResults(t, test.method, got, test.expect)
}
// Name returns the method that is expected to fail.
func (test callErr) Name() string {
	return test.method
}

// Invoke calls test.method with test.params and requires the call to fail
// with an error whose text contains test.expectErr.
func (test callErr) Invoke(t *testing.T, ctx context.Context, h *handler) {
	var results interface{}
	err := h.conn.Call(ctx, test.method, test.params).Await(ctx, &results)
	if err == nil {
		t.Fatalf("%v:Call succeeded (%v) but should have failed with error containing %q", test.method, results, test.expectErr)
	}
	if msg := err.Error(); !strings.Contains(msg, test.expectErr) {
		t.Fatalf("%v:Call failed but with unexpected error: %q does not contain %q", test.method, msg, test.expectErr)
	}
}
// Invoke calls the peer's "echo" method with [test.method, test.params],
// asking it to make that call back to this side, and checks the relayed
// result against test.expect.
func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
	got := newResults(test.expect)
	args := []interface{}{test.method, test.params}
	if err := h.conn.Call(ctx, "echo", args).Await(ctx, got); err != nil {
		t.Fatalf("%v:Echo failed: %v", test.method, err)
	}
	verifyResults(t, test.method, got, test.expect)
}
// Name returns the key under which the pending call is recorded.
func (test async) Name() string {
	return test.name
}

// Invoke starts test.method without waiting for it and records the pending
// call under test.name so later collect or cancel steps can find it.
func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
	pending := h.conn.Call(ctx, test.method, test.params)
	h.calls[test.name] = pending
}
// Name returns the key of the async call being collected.
func (test collect) Name() string { return test.name }

// Invoke waits for the async call recorded under test.name. When test.fails
// is set the call must return an error, otherwise it must succeed. In both
// cases the decoded result is compared against test.expect.
func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
	o, ok := h.calls[test.name]
	if !ok {
		// Without this check a typo in the scenario table would hand a nil
		// *AsyncCall to Await instead of producing a readable failure.
		t.Fatalf("%v:Collect of unknown async call", test.name)
	}
	results := newResults(test.expect)
	err := o.Await(ctx, results)
	switch {
	case test.fails && err == nil:
		t.Fatalf("%v:Collect was supposed to fail", test.name)
	case !test.fails && err != nil:
		t.Fatalf("%v:Collect failed: %v", test.name, err)
	}
	verifyResults(t, test.name, results, test.expect)
}
// Name returns the key of the async call being cancelled.
func (test cancel) Name() string { return test.name }

// Invoke sends a "cancel" notification carrying the int64 ID of the async
// call recorded under test.name. The receiving side's Preempt cancels it.
func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
	o, ok := h.calls[test.name]
	if !ok {
		t.Fatalf("%v:Cancel of unknown async call", test.name)
	}
	// cancelParams only carries int64 IDs; report anything else instead of
	// panicking on an unchecked type assertion.
	id, ok := o.ID().Raw().(int64)
	if !ok {
		t.Fatalf("%v:Cancel needs an int64 call ID, got %T", test.name, o.ID().Raw())
	}
	if err := h.conn.Notify(ctx, "cancel", &cancelParams{id}); err != nil {
		t.Fatalf("%v:Cancel failed: %v", test.name, err)
	}
}
// Name returns the name of the whole sequence.
func (test sequence) Name() string {
	return test.name
}

// Invoke runs each child step in order against the same handler.
func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
	for i := range test.tests {
		test.tests[i].Invoke(t, ctx, h)
	}
}
// newResults returns a fresh, zero-valued destination for decoding a result
// that should match expect:
//
//   - nil when expect is nil (no result is expected);
//   - for a []interface{}, a slice holding a new pointer for each element's
//     dynamic type (nil when the slice is empty);
//   - otherwise, a pointer to a new zero value of expect's dynamic type.
//
// NOTE(review): the []interface{} form is not itself a pointer, so it looks
// unsuitable as a JSON decoding target; no current scenario uses it, so
// confirm before relying on it.
func newResults(expect interface{}) interface{} {
	if expect == nil {
		return nil
	}
	elems, isList := expect.([]interface{})
	if !isList {
		return reflect.New(reflect.TypeOf(expect)).Interface()
	}
	var out []interface{}
	for i := range elems {
		out = append(out, reflect.New(reflect.TypeOf(elems[i])).Interface())
	}
	return out
}
// verifyResults reports, via t.Errorf, any difference between results and
// expect for the given method. results is the destination produced by
// newResults(expect) after decoding:
//
//   - if expect is nil, results must be nil too;
//   - otherwise results is dereferenced (when it is a pointer) and compared
//     with expect using reflect.DeepEqual.
//
// t is a testing.TB so any test or benchmark can use it; *testing.T callers
// are unaffected.
func verifyResults(t testing.TB, method string, results interface{}, expect interface{}) {
	t.Helper()
	if expect == nil {
		if results != nil {
			// Report the unexpected results themselves; expect is always nil here.
			t.Errorf("%v:Got results %+v where none expected", method, results)
		}
		return
	}
	if results == nil {
		// reflect.ValueOf(nil) is the zero Value, whose Interface method panics.
		t.Errorf("%v:Got no results, expect %+v", method, expect)
		return
	}
	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
	if !reflect.DeepEqual(val, expect) {
		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
	}
}
// Bind creates a fresh handler for conn and uses it as both the Preempter and
// the Handler, with b.framer for framing. When b.runTest is set (client side),
// it is started on its own goroutine with the new handler.
func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) (jsonrpc2.ConnectionOptions, error) {
	// The 1-buffered box starts out holding an empty waiters map; handler.waiter
	// takes the map out and puts it back to serialize access.
	box := make(chan map[string]chan struct{}, 1)
	box <- make(map[string]chan struct{})
	h := &handler{
		conn:       conn,
		waitersBox: box,
		calls:      make(map[string]*jsonrpc2.AsyncCall),
	}
	if run := b.runTest; run != nil {
		go run(h)
	}
	opts := jsonrpc2.ConnectionOptions{
		Handler:   h,
		Preempter: h,
		Framer:    b.framer,
	}
	return opts, nil
}
// waiter returns the channel registered under name, creating and registering
// a new one on first use. The waiters map is taken out of h.waitersBox for
// the duration of the lookup and always returned to it afterwards.
func (h *handler) waiter(name string) chan struct{} {
	waiters := <-h.waitersBox
	defer func() { h.waitersBox <- waiters }()
	if ch, ok := waiters[name]; ok {
		return ch
	}
	ch := make(chan struct{})
	waiters[name] = ch
	return ch
}
// Preempt serves the control methods that must be answered even while Handle
// is blocked (as the preempt and queue scenarios rely on):
//
//   - unblock: closes the named waiter, releasing a pending wait or fork.
//   - peek: returns the accumulator as it is now, rejecting any params.
//   - cancel: cancels the call whose int64 ID is given in cancelParams.
//
// Every other method yields jsonrpc2.ErrNotHandled. Params decoding failures
// are wrapped in jsonrpc2.ErrParse.
func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
	decode := func(dst interface{}) error {
		if err := json.Unmarshal(req.Params, dst); err != nil {
			return errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
		}
		return nil
	}
	switch req.Method {
	case "unblock":
		var name string
		if err := decode(&name); err != nil {
			return nil, err
		}
		close(h.waiter(name))
	case "peek":
		if len(req.Params) != 0 {
			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
		}
		return h.accumulator, nil
	case "cancel":
		var params cancelParams
		if err := decode(&params); err != nil {
			return nil, err
		}
		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
	default:
		return nil, jsonrpc2.ErrNotHandled
	}
	return nil, nil
}
// Handle serves the ordinary (non-preempted) test methods:
//
//   - no_args: rejects any params and returns true.
//   - one_string, one_number: decode one value and return "got:<value>".
//   - set, add: overwrite or increment the accumulator; no result.
//   - get: returns the accumulator, rejecting any params.
//   - join: decodes a []string and returns path.Join of the elements.
//   - echo: decodes [method, params], makes that call back over h.conn and
//     returns the peer's reply.
//   - wait: blocks until the named waiter is closed, ctx is done, or one
//     second passes.
//   - fork: like wait, but replies later from a goroutine via h.conn.Respond
//     and returns jsonrpc2.ErrAsyncResponse immediately.
//
// Any other method yields jsonrpc2.ErrNotHandled. Params decoding failures
// are wrapped in jsonrpc2.ErrParse; malformed shapes in
// jsonrpc2.ErrInvalidParams.
func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
	// parse decodes the request params into dst, wrapping failures in ErrParse.
	parse := func(dst interface{}) error {
		if err := json.Unmarshal(req.Params, dst); err != nil {
			return errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
		}
		return nil
	}
	switch req.Method {
	case "no_args":
		if len(req.Params) > 0 {
			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
		}
		return true, nil
	case "one_string":
		var v string
		if err := parse(&v); err != nil {
			return nil, err
		}
		return "got:" + v, nil
	case "one_number":
		var v int
		if err := parse(&v); err != nil {
			return nil, err
		}
		return fmt.Sprintf("got:%d", v), nil
	case "set":
		var v int
		if err := parse(&v); err != nil {
			return nil, err
		}
		h.accumulator = v
		return nil, nil
	case "add":
		var v int
		if err := parse(&v); err != nil {
			return nil, err
		}
		h.accumulator += v
		return nil, nil
	case "get":
		if len(req.Params) > 0 {
			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
		}
		return h.accumulator, nil
	case "join":
		var v []string
		if err := parse(&v); err != nil {
			return nil, err
		}
		return path.Join(v...), nil
	case "echo":
		var v []interface{}
		if err := parse(&v); err != nil {
			return nil, err
		}
		// Check the shape before indexing so malformed params produce an
		// error reply rather than a panic (index out of range or a failed
		// type assertion) inside the handler.
		if len(v) < 2 {
			return nil, errors.Errorf("%w: expected [method, params], got %d element(s)", jsonrpc2.ErrInvalidParams, len(v))
		}
		method, ok := v[0].(string)
		if !ok {
			return nil, errors.Errorf("%w: expected string method, got %T", jsonrpc2.ErrInvalidParams, v[0])
		}
		var result interface{}
		err := h.conn.Call(ctx, method, v[1]).Await(ctx, &result)
		return result, err
	case "wait":
		var name string
		if err := parse(&name); err != nil {
			return nil, err
		}
		select {
		case <-h.waiter(name):
			return true, nil
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(time.Second):
			return nil, errors.Errorf("wait for %q timed out", name)
		}
	case "fork":
		var name string
		if err := parse(&name); err != nil {
			return nil, err
		}
		waitFor := h.waiter(name)
		go func() {
			// Respond's error is not checked: this goroutine has no
			// *testing.T or caller to report it to.
			select {
			case <-waitFor:
				h.conn.Respond(req.ID, true, nil)
			case <-ctx.Done():
				h.conn.Respond(req.ID, nil, ctx.Err())
			case <-time.After(time.Second):
				h.conn.Respond(req.ID, nil, errors.Errorf("wait for %q timed out", name))
			}
		}()
		return nil, jsonrpc2.ErrAsyncResponse
	default:
		return nil, jsonrpc2.ErrNotHandled
	}
}