io: add WriteString support to MultiWriter

Fixes #11805

Change-Id: I081e16b869dc706bd847ee645bb902bc671c123f
Reviewed-on: https://go-review.googlesource.com/12485
Reviewed-by: Josh Bleecher Snyder <josharian@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/io/multi.go b/src/io/multi.go
index e26cc53..16860aa 100644
--- a/src/io/multi.go
+++ b/src/io/multi.go
@@ -52,6 +52,30 @@
 	return len(p), nil
 }
 
+var _ stringWriter = (*multiWriter)(nil) // compile-time check: *multiWriter must satisfy stringWriter
+
+func (t *multiWriter) WriteString(s string) (n int, err error) {
+	var p []byte // byte form of s; built only when a writer lacks WriteString, then reused
+	for _, w := range t.writers {
+		if sw, ok := w.(stringWriter); ok { // prefer the writer's own WriteString when it has one
+			n, err = sw.WriteString(s)
+		} else {
+			if p == nil {
+				p = []byte(s)
+			}
+			n, err = w.Write(p)
+		}
+		if err != nil { // stop at the first failing writer; later writers are not called
+			return
+		}
+		if n != len(s) { // a short write without an error is reported as ErrShortWrite
+			err = ErrShortWrite
+			return
+		}
+	}
+	return len(s), nil
+}
+
 // MultiWriter creates a writer that duplicates its writes to all the
 // provided writers, similar to the Unix tee(1) command.
 func MultiWriter(writers ...Writer) Writer {
diff --git a/src/io/multi_test.go b/src/io/multi_test.go
index 56c6769..e80592e 100644
--- a/src/io/multi_test.go
+++ b/src/io/multi_test.go
@@ -62,8 +62,59 @@
 }
 
 func TestMultiWriter(t *testing.T) {
-	sha1 := sha1.New()
 	sink := new(bytes.Buffer)
+	// Wrap sink so that only Writer and fmt.Stringer are exposed, hiding bytes.Buffer's WriteString method:
+	testMultiWriter(t, struct {
+		Writer
+		fmt.Stringer
+	}{sink, sink})
+}
+
+func TestMultiWriter_String(t *testing.T) {
+	testMultiWriter(t, new(bytes.Buffer)) // unwrapped *bytes.Buffer, so its WriteString method stays visible
+}
+
+// Test that a multiWriter.WriteString call results in exactly 1 allocation,
+// even if multiple targets don't support WriteString.
+func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) {
+	var sink1, sink2 bytes.Buffer
+	type simpleWriter struct { // hide bytes.Buffer's WriteString
+		Writer
+	}
+	mw := MultiWriter(simpleWriter{&sink1}, simpleWriter{&sink2})
+	allocs := int(testing.AllocsPerRun(1000, func() {
+		WriteString(mw, "foo")
+	}))
+	if allocs != 1 {
+		t.Errorf("num allocations = %d; want 1", allocs)
+	}
+}
+
+type writeStringChecker struct{ called bool } // called is set once WriteString has been invoked
+
+func (c *writeStringChecker) WriteString(s string) (n int, err error) {
+	c.called = true
+	return len(s), nil
+}
+
+func (c *writeStringChecker) Write(p []byte) (n int, err error) {
+	return len(p), nil // reports success without touching c.called
+}
+
+func TestMultiWriter_StringCheckCall(t *testing.T) {
+	var c writeStringChecker
+	mw := MultiWriter(&c)
+	WriteString(mw, "foo") // should reach c.WriteString, the only method that sets c.called
+	if !c.called {
+		t.Error("did not see WriteString call to writeStringChecker")
+	}
+}
+
+func testMultiWriter(t *testing.T, sink interface {
+	Writer
+	fmt.Stringer
+}) {
+	sha1 := sha1.New()
 	mw := MultiWriter(sha1, sink)
 
 	sourceString := "My input text."