Merge pull request #464 from iamqizhao/master
i) revise the Picker API and the channel state API; ii) add unicastNamingPicker to support custom name discovery.
diff --git a/call_test.go b/call_test.go
index 53b4cfc..48d25e5 100644
--- a/call_test.go
+++ b/call_test.go
@@ -81,10 +81,11 @@
break
}
if err != nil {
- t.Fatalf("Failed to receive the message from the client.")
+ return
}
if pf != compressionNone {
- t.Fatalf("Received the mistaken message format %d, want %d", pf, compressionNone)
+ t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
+ return
}
var v string
codec := testCodec{}
diff --git a/clientconn.go b/clientconn.go
index 6da8867..bf66914 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -89,6 +89,12 @@
}
}
+// WithPicker returns a DialOption which sets the Picker used to choose the
+// address to connect to and the transport for each RPC. Without it, Dial uses
+// a default Picker that always connects to the Dial target.
+func WithPicker(p Picker) DialOption {
+	return func(o *dialOptions) {
+		o.picker = p
+	}
+}
+
// WithBlock returns a DialOption which makes caller of Dial blocks until the underlying
// connection is up. Without this, Dial returns immediately and connecting the server
// happens in background.
@@ -154,7 +160,9 @@
cc.dopts.codec = protoCodec{}
}
if cc.dopts.picker == nil {
- cc.dopts.picker = &unicastPicker{}
+ cc.dopts.picker = &unicastPicker{
+ target: target,
+ }
}
if err := cc.dopts.picker.Init(cc); err != nil {
return nil, err
@@ -209,15 +217,15 @@
// State returns the connectivity state of cc.
+// The error, if any, comes from the configured Picker's State; for example,
+// the Picker created by NewUnicastNamingPicker does not support State and
+// always returns an error.
// This is EXPERIMENTAL API.
-func (cc *ClientConn) State() ConnectivityState {
+func (cc *ClientConn) State() (ConnectivityState, error) {
return cc.dopts.picker.State()
}
-// WaitForStateChange blocks until the state changes to something other than the sourceState
-// or timeout fires on cc. It returns false if timeout fires, and true otherwise.
+// WaitForStateChange blocks until the state of cc changes to something other
+// than sourceState, or until ctx is done. It returns the new state, or the
+// error reported by the configured Picker: with the default Picker that is
+// ctx.Err() when ctx is done first, while the Picker created by
+// NewUnicastNamingPicker does not support this call and always fails.
// This is EXPERIMENTAL API.
-func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
- return cc.dopts.picker.WaitForStateChange(timeout, sourceState)
+func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
+ return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
}
// Close starts to tear down the ClientConn.
@@ -229,6 +237,7 @@
type Conn struct {
target string
dopts dialOptions
+ resetChan chan int
shutdownChan chan struct{}
events trace.EventLog
@@ -249,6 +258,7 @@
c := &Conn{
target: cc.target,
dopts: cc.dopts,
+ resetChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
}
if EnableTracing {
@@ -317,26 +327,20 @@
return cc.state
}
-// WaitForStateChange blocks until the state changes to something other than the sourceState
-// or timeout fires. It returns false if timeout fires and true otherwise.
-// TODO(zhaoq): Rewrite for complex Picker.
-func (cc *Conn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
- start := time.Now()
+// WaitForStateChange blocks until the state changes to something other than the sourceState.
+func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
cc.mu.Lock()
defer cc.mu.Unlock()
if sourceState != cc.state {
- return true
- }
- expired := timeout <= time.Since(start)
- if expired {
- return false
+ return cc.state, nil
}
done := make(chan struct{})
+ var err error
go func() {
select {
- case <-time.After(timeout - time.Since(start)):
+ case <-ctx.Done():
cc.mu.Lock()
- expired = true
+ err = ctx.Err()
cc.stateCV.Broadcast()
cc.mu.Unlock()
case <-done:
@@ -345,11 +349,20 @@
defer close(done)
for sourceState == cc.state {
cc.stateCV.Wait()
- if expired {
- return false
+ if err != nil {
+ return cc.state, err
}
}
- return true
+ return cc.state, nil
+}
+
+// NotifyReset asks transportMonitor to tear down the current transport and
+// reconnect, for example because a name resolution change is in flight.
+// It never blocks: if a reset is already pending, the call is a no-op.
+func (cc *Conn) NotifyReset() {
+ select {
+ case cc.resetChan <- 0:
+ default:
+ // resetChan has capacity 1, so a full channel means a reset is already pending.
+ }
+}
func (cc *Conn) resetTransport(closeTransport bool) error {
@@ -391,7 +404,11 @@
copts.Timeout = timeout
}
connectTime := time.Now()
- newTransport, err := transport.NewClientTransport(cc.target, &copts)
+ addr, err := cc.dopts.picker.PickAddr()
+ var newTransport transport.ClientTransport
+ if err == nil {
+ newTransport, err = transport.NewClientTransport(addr, &copts)
+ }
if err != nil {
cc.mu.Lock()
if cc.state == Shutdown {
@@ -422,7 +439,7 @@
closeTransport = false
time.Sleep(sleepTime)
retries++
- grpclog.Printf("grpc: ClientConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
+ grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
continue
}
cc.mu.Lock()
@@ -445,6 +462,27 @@
}
}
+// reconnect moves cc to TransientFailure and calls resetTransport(true) to
+// replace the current transport. It returns false, meaning transportMonitor
+// should exit, if cc has already been shut down or resetTransport fails.
+func (cc *Conn) reconnect() bool {
+ cc.mu.Lock()
+ if cc.state == Shutdown {
+ // cc.Close() has been invoked.
+ cc.mu.Unlock()
+ return false
+ }
+ cc.state = TransientFailure
+ // Wake any WaitForStateChange callers blocked on stateCV.
+ cc.stateCV.Broadcast()
+ cc.mu.Unlock()
+ if err := cc.resetTransport(true); err != nil {
+ // The ClientConn is closing.
+ cc.mu.Lock()
+ cc.printf("transport exiting: %v", err)
+ cc.mu.Unlock()
+ grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err)
+ return false
+ }
+ return true
+}
+
// Run in a goroutine to track the error in transport and create the
// new transport if an error happens. It returns when the channel is closing.
func (cc *Conn) transportMonitor() {
@@ -454,25 +492,19 @@
// the ClientConn is idle (i.e., no RPC in flight).
case <-cc.shutdownChan:
return
+ case <-cc.resetChan:
+ if !cc.reconnect() {
+ return
+ }
case <-cc.transport.Error():
- cc.mu.Lock()
- if cc.state == Shutdown {
- // cc.Close() has been invoked.
- cc.mu.Unlock()
+ if !cc.reconnect() {
return
}
- cc.state = TransientFailure
- cc.stateCV.Broadcast()
- cc.mu.Unlock()
- if err := cc.resetTransport(true); err != nil {
- // The ClientConn is closing.
- cc.mu.Lock()
- cc.printf("transport exiting: %v", err)
- cc.mu.Unlock()
- grpclog.Printf("grpc: ClientConn.transportMonitor exits due to: %v", err)
- return
+ // Tries to drain reset signal if there is any since it is out-dated.
+ select {
+ case <-cc.resetChan:
+ default:
}
- continue
}
}
}
diff --git a/picker.go b/picker.go
index bc48573..b83c859 100644
--- a/picker.go
+++ b/picker.go
@@ -34,9 +34,13 @@
package grpc
import (
- "time"
+ "container/list"
+ "fmt"
+ "sync"
"golang.org/x/net/context"
+ "google.golang.org/grpc/grpclog"
+ "google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
@@ -48,12 +52,14 @@
// Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC
// or some error happens.
Pick(ctx context.Context) (transport.ClientTransport, error)
+ // PickAddr picks a peer address for connecting. This will be called repeated for
+ // connecting/reconnecting.
+ PickAddr() (string, error)
// State returns the connectivity state of the underlying connections.
- State() ConnectivityState
+ State() (ConnectivityState, error)
// WaitForStateChange blocks until the state changes to something other than
- // the sourceState or timeout fires on cc. It returns false if timeout fires,
- // and true otherwise.
- WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool
+ // the sourceState. It returns the new state or error.
+ WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error)
// Close closes all the Conn's owned by this Picker.
Close() error
}
@@ -61,7 +67,8 @@
// unicastPicker is the default Picker which is used when there is no custom Picker
// specified by users. It always picks the same Conn.
type unicastPicker struct {
- conn *Conn
+ // target is the address given at construction (the Dial target); PickAddr always returns it.
+ target string
+ // conn is the single Conn that every Pick waits on.
+ conn *Conn
}
func (p *unicastPicker) Init(cc *ClientConn) error {
@@ -77,12 +84,16 @@
return p.conn.Wait(ctx)
}
-func (p *unicastPicker) State() ConnectivityState {
- return p.conn.State()
+// PickAddr always returns p.target; unicastPicker does no name resolution.
+func (p *unicastPicker) PickAddr() (string, error) {
+ return p.target, nil
}
-func (p *unicastPicker) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
- return p.conn.WaitForStateChange(timeout, sourceState)
+// State reports the state of the single Conn and never returns an error.
+func (p *unicastPicker) State() (ConnectivityState, error) {
+ return p.conn.State(), nil
+}
+
+// WaitForStateChange delegates to the single Conn's WaitForStateChange.
+func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
+ return p.conn.WaitForStateChange(ctx, sourceState)
}
func (p *unicastPicker) Close() error {
@@ -91,3 +102,142 @@
}
return nil
}
+
+// unicastNamingPicker is a Picker that resolves the Dial target with a
+// naming.Resolver and keeps a single Conn, which connects to the resolved
+// address chosen by PickAddr.
+type unicastNamingPicker struct {
+ cc *ClientConn
+ resolver naming.Resolver
+ watcher naming.Watcher
+ // mu guards addrs and pickedAddr.
+ mu sync.Mutex
+ // addrs holds one *addrInfo per address reported by watcher.
+ addrs *list.List
+ // pickedAddr is the element most recently returned by PickAddr; the next
+ // PickAddr call advances it, wrapping around to the front of addrs.
+ pickedAddr *list.Element
+ // conn is created on the first naming.Add update and shared by all RPCs.
+ // NOTE(review): conn is written by processUpdates without holding mu.
+ conn *Conn
+}
+
+// NewUnicastNamingPicker returns a Picker that resolves the Dial target with r
+// and connects to one of the resolved addresses at a time.
+func NewUnicastNamingPicker(r naming.Resolver) Picker {
+	p := &unicastNamingPicker{resolver: r}
+	p.addrs = list.New()
+	return p
+}
+
+// addrInfo is one element of unicastNamingPicker.addrs.
+type addrInfo struct {
+ // addr is the address reported by the naming.Watcher.
+ addr string
+ // Set to true if this addrInfo needs to be deleted in the next PickAddr() call.
+ deleting bool
+}
+
+// processUpdates calls Watcher.Next() once and applies the returned updates to
+// p.addrs. The first naming.Add also creates p.conn. It returns any error from
+// Watcher.Next() or from NewConn.
+func (p *unicastNamingPicker) processUpdates() error {
+	updates, err := p.watcher.Next()
+	if err != nil {
+		return err
+	}
+	for _, update := range updates {
+		switch update.Op {
+		case naming.Add:
+			p.mu.Lock()
+			p.addrs.PushBack(&addrInfo{
+				addr: update.Addr,
+			})
+			p.mu.Unlock()
+			// Initial connection setup. mu is released first because NewConn
+			// may reach PickAddr (via resetTransport), which locks mu.
+			if p.conn == nil {
+				conn, err := NewConn(p.cc)
+				if err != nil {
+					return err
+				}
+				p.conn = conn
+			}
+		case naming.Delete:
+			p.mu.Lock()
+			for e := p.addrs.Front(); e != nil; {
+				// Save the successor first: list.Remove clears e's links, so
+				// calling e.Next() after a removal would end the scan early.
+				next := e.Next()
+				if info := e.Value.(*addrInfo); info.addr == update.Addr {
+					if e == p.pickedAddr {
+						// Do not remove the element now if it is the current picked
+						// one. We leave the deletion to the next PickAddr() call.
+						info.deleting = true
+						// Notify Conn to close it. All the live RPCs on this
+						// connection will be aborted.
+						p.conn.NotifyReset()
+					} else {
+						p.addrs.Remove(e)
+					}
+				}
+				e = next
+			}
+			p.mu.Unlock()
+		default:
+			grpclog.Printf("Unknown update.Op %d", update.Op)
+		}
+	}
+	return nil
+}
+
+// monitor runs in its own goroutine and keeps applying name resolution updates
+// until processUpdates fails (presumably once the watcher has been closed).
+func (p *unicastNamingPicker) monitor() {
+	var err error
+	for err == nil {
+		err = p.processUpdates()
+	}
+}
+
+// Init resolves cc.target with p.resolver, applies the initial batch of
+// updates (which must create the Conn), and then starts monitor to follow
+// later changes. On failure the watcher is closed, since Dial does not call
+// Close on a Picker whose Init failed.
+func (p *unicastNamingPicker) Init(cc *ClientConn) error {
+	w, err := p.resolver.Resolve(cc.target)
+	if err != nil {
+		return err
+	}
+	p.watcher = w
+	p.cc = cc
+	// Get the initial name resolution.
+	if err := p.processUpdates(); err != nil {
+		w.Close()
+		return err
+	}
+	if p.conn == nil {
+		// Pick and Close dereference p.conn, so refuse to start without one.
+		w.Close()
+		return fmt.Errorf("there is no address resolved for %q", cc.target)
+	}
+	go p.monitor()
+	return nil
+}
+
+// Pick returns the transport of the single Conn, waiting via Conn.Wait(ctx).
+func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
+ return p.conn.Wait(ctx)
+}
+
+// PickAddr returns the next address to connect to. It walks addrs in order,
+// wrapping around at the end. An element marked deleting is removed once
+// PickAddr moves past it. It returns an error if addrs is empty.
+func (p *unicastNamingPicker) PickAddr() (string, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.pickedAddr == nil {
+ p.pickedAddr = p.addrs.Front()
+ } else {
+ pa := p.pickedAddr
+ p.pickedAddr = pa.Next()
+ // processUpdates only marks the in-use element; drop it now that
+ // pickedAddr no longer refers to it.
+ if pa.Value.(*addrInfo).deleting {
+ p.addrs.Remove(pa)
+ }
+ // Wrap around to the front of the list.
+ if p.pickedAddr == nil {
+ p.pickedAddr = p.addrs.Front()
+ }
+ }
+ if p.pickedAddr == nil {
+ return "", fmt.Errorf("there is no address available to pick")
+ }
+ return p.pickedAddr.Value.(*addrInfo).addr, nil
+}
+
+// State is not supported by unicastNamingPicker; it always returns an error.
+func (p *unicastNamingPicker) State() (ConnectivityState, error) {
+ return 0, fmt.Errorf("State() is not supported for unicastNamingPicker")
+}
+
+// WaitForStateChange is not supported by unicastNamingPicker; it returns an
+// error immediately without waiting.
+func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
+	return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPicker")
+}
+
+// Close closes the watcher and then the Conn, if one was created. It always
+// returns nil.
+func (p *unicastNamingPicker) Close() error {
+	p.watcher.Close()
+	// p.conn stays nil until a naming.Add update has been processed.
+	if p.conn != nil {
+		p.conn.Close()
+	}
+	return nil
+}
diff --git a/picker_test.go b/picker_test.go
new file mode 100644
index 0000000..efe0c23
--- /dev/null
+++ b/picker_test.go
@@ -0,0 +1,182 @@
+/*
+ *
+ * Copyright 2014, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ * * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package grpc
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "golang.org/x/net/context"
+ "google.golang.org/grpc/naming"
+)
+
+// testWatcher is a naming.Watcher whose updates are injected by the test.
+type testWatcher struct {
+ // the channel to receive name resolution updates
+ update chan *naming.Update
+ // the side channel that announces how many updates are in a batch
+ side chan int
+ // the channel to notify the update injector that reading the batch is done
+ readDone chan int
+}
+
+// Next blocks until a batch size arrives on w.side, reads that many updates
+// from w.update (dropping nil ones), and acknowledges on w.readDone. A zero
+// batch size, which is also what a closed w.side yields, is reported as an error.
+func (w *testWatcher) Next() ([]*naming.Update, error) {
+	n := <-w.side
+	if n == 0 {
+		return nil, fmt.Errorf("w.side is closed")
+	}
+	var updates []*naming.Update
+	for ; n > 0; n-- {
+		if u := <-w.update; u != nil {
+			updates = append(updates, u)
+		}
+	}
+	w.readDone <- 0
+	return updates, nil
+}
+
+// Close closes w.side so that a pending or future Next returns an error, which
+// lets the picker's monitor goroutine exit instead of blocking forever.
+// It must be called at most once.
+func (w *testWatcher) Close() {
+	close(w.side)
+}
+
+// inject announces the batch size on w.side, feeds the updates one by one into
+// w.update, and then blocks until Next acknowledges the batch on w.readDone.
+func (w *testWatcher) inject(updates []*naming.Update) {
+	w.side <- len(updates)
+	for i := range updates {
+		w.update <- updates[i]
+	}
+	<-w.readDone
+}
+
+// testNameResolver is a naming.Resolver whose watcher initially reports addr.
+// It keeps the watcher in w so tests can inject later updates.
+type testNameResolver struct {
+ w *testWatcher
+ addr string
+}
+
+// Resolve ignores target and returns a new testWatcher whose first Next call
+// yields a single naming.Add update for r.addr.
+func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
+ r.w = &testWatcher{
+ update: make(chan *naming.Update, 1),
+ side: make(chan int, 1),
+ readDone: make(chan int),
+ }
+ // Preload the initial batch; the buffered channels let these sends
+ // complete before anyone calls Next.
+ r.w.side <- 1
+ r.w.update <- &naming.Update{
+ Op: naming.Add,
+ Addr: r.addr,
+ }
+ // Next acknowledges every batch on the unbuffered readDone. No inject call
+ // waits for this first batch, so consume its acknowledgement here.
+ go func() {
+ <-r.w.readDone
+ }()
+ return r.w, nil
+}
+
+// startServers starts numServers test servers one after another, calling
+// s.wait with a 2s timeout on each. It returns them together with a
+// testNameResolver that initially resolves to the first server.
+func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*server, *testNameResolver) {
+	var servers []*server
+	for len(servers) < numServers {
+		srv := &server{readyChan: make(chan bool)}
+		servers = append(servers, srv)
+		go srv.start(t, port, maxStreams)
+		srv.wait(t, 2*time.Second)
+	}
+	// Point to server1.
+	first := "127.0.0.1:" + servers[0].port
+	return servers, &testNameResolver{addr: first}
+}
+
+// TestNameDiscovery checks that RPCs follow name resolution changes: a switch
+// from server#1 to server#2, then to server#3 once server#2 stops.
+func TestNameDiscovery(t *testing.T) {
+	// Start 3 servers on 3 ports.
+	servers, r := startServers(t, 3, 0, math.MaxUint32)
+	cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("Failed to create ClientConn: %v", err)
+	}
+	var reply string
+	if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
+		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply %q, want <nil>, reply %q", err, reply, expectedResponse)
+	}
+	// Inject name resolution change to point to the second server now.
+	var updates []*naming.Update
+	updates = append(updates, &naming.Update{
+		Op:   naming.Delete,
+		Addr: "127.0.0.1:" + servers[0].port,
+	})
+	updates = append(updates, &naming.Update{
+		Op:   naming.Add,
+		Addr: "127.0.0.1:" + servers[1].port,
+	})
+	r.w.inject(updates)
+	servers[0].stop()
+	if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
+		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply %q, want <nil>, reply %q", err, reply, expectedResponse)
+	}
+	// Add another server address (server#3) to name resolution.
+	updates = nil
+	updates = append(updates, &naming.Update{
+		Op:   naming.Add,
+		Addr: "127.0.0.1:" + servers[2].port,
+	})
+	r.w.inject(updates)
+	// Stop server#2. The library should direct to server#3 automatically.
+	servers[1].stop()
+	if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
+		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply %q, want <nil>, reply %q", err, reply, expectedResponse)
+	}
+	cc.Close()
+	servers[2].stop()
+}
+
+// TestEmptyAddrs checks that RPCs fail, rather than hang or panic, once name
+// resolution has removed every address.
+func TestEmptyAddrs(t *testing.T) {
+	servers, r := startServers(t, 1, 0, math.MaxUint32)
+	cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("Failed to create ClientConn: %v", err)
+	}
+	var reply string
+	if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
+		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply %q, want <nil>, reply %q", err, reply, expectedResponse)
+	}
+	// Inject name resolution change to remove the server address so that there is no address
+	// available after that.
+	var updates []*naming.Update
+	updates = append(updates, &naming.Update{
+		Op:   naming.Delete,
+		Addr: "127.0.0.1:" + servers[0].port,
+	})
+	r.w.inject(updates)
+	// Release the context's timer when the test returns.
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+	if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err == nil {
+		t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want non-<nil>", err)
+	}
+	cc.Close()
+	servers[0].stop()
+}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 2633617..be7691d 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -388,32 +388,35 @@
func testTimeoutOnDeadServer(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, "", e)
tc := testpb.NewTestServiceClient(cc)
- if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok)
+	// Each wait gets its own deadline; the deferred cancels release the
+	// contexts' timers when the test returns.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
}
- if ok := cc.WaitForStateChange(time.Second, grpc.Connecting); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok)
+	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
}
- if cc.State() != grpc.Ready {
- t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready)
+	if state, err := cc.State(); err != nil || state != grpc.Ready {
+		t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
}
- if ok := cc.WaitForStateChange(time.Millisecond, grpc.Ready); ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want false", grpc.Ready, ok)
+	// No state change is expected here, so keep the deadline as short as the
+	// previous time.Millisecond timeout instead of idling for a full second.
+	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != context.DeadlineExceeded {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, %v", grpc.Ready, err, context.DeadlineExceeded)
}
s.Stop()
// Set -1 as the timeout to make sure if transportMonitor gets error
// notification in time the failure path of the 1st invoke of
// ClientConn.wait hits the deadline exceeded error.
- ctx, _ := context.WithTimeout(context.Background(), -1)
+	ctx, cancel = context.WithTimeout(context.Background(), -1)
+	defer cancel()
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(%v, _) = _, error %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
}
- if ok := cc.WaitForStateChange(time.Second, grpc.Ready); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Ready, ok)
+	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
}
- state := cc.State()
- if state != grpc.Connecting && state != grpc.TransientFailure {
- t.Fatalf("cc.State() = %s, want %s or %s", state, grpc.Connecting, grpc.TransientFailure)
+	if state, err := cc.State(); err != nil || (state != grpc.Connecting && state != grpc.TransientFailure) {
+		t.Fatalf("cc.State() = %s, %v, want %s or %s, <nil>", state, err, grpc.Connecting, grpc.TransientFailure)
}
cc.Close()
}
@@ -521,17 +524,20 @@
func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
s, cc := setUp(t, nil, math.MaxUint32, testAppUA, e)
// Wait until cc is connected.
- if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok)
+	// The deferred cancels release each wait's timer when the test returns.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
}
- if ok := cc.WaitForStateChange(10*time.Second, grpc.Connecting); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok)
+	// Keep the 10s budget the connection attempt had before the API change.
+	ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
}
- if cc.State() != grpc.Ready {
- t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready)
+	if state, err := cc.State(); err != nil || state != grpc.Ready {
+		t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
}
- if ok := cc.WaitForStateChange(time.Second, grpc.Ready); ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want false", grpc.Ready, ok)
+	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != context.DeadlineExceeded {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, %v", grpc.Ready, err, context.DeadlineExceeded)
}
tc := testpb.NewTestServiceClient(cc)
var header metadata.MD
@@ -543,11 +549,12 @@
t.Fatalf("header[\"ua\"] = %q, %t, want %q, true", v, ok, testAppUA)
}
tearDown(s, cc)
- if ok := cc.WaitForStateChange(5*time.Second, grpc.Ready); !ok {
- t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Ready, ok)
+	// A distinct cancel name keeps this independent of earlier declarations.
+	ctx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancelShutdown()
+	if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
+		t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
}
- if cc.State() != grpc.Shutdown {
- t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Shutdown)
+	if state, err := cc.State(); err != nil || state != grpc.Shutdown {
+		t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Shutdown)
}
}