secure/precis: add test for multiple Transform calls.

Change-Id: Iea031eef1f29153bbe158a5287238b81fd3ee8af
Reviewed-on: https://go-review.googlesource.com/47871
Reviewed-by: Nigel Tao <nigeltao@golang.org>
Run-TryBot: Nigel Tao <nigeltao@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/secure/precis/enforce_test.go b/secure/precis/enforce_test.go
index 0cb5b60..d36a980 100644
--- a/secure/precis/enforce_test.go
+++ b/secure/precis/enforce_test.go
@@ -5,12 +5,14 @@
 package precis
 
 import (
+	"bytes"
 	"fmt"
 	"reflect"
 	"testing"
 
 	"golang.org/x/text/internal/testtext"
 	"golang.org/x/text/secure/bidirule"
+	"golang.org/x/text/transform"
 )
 
 type testCase struct {
@@ -320,3 +322,72 @@
 		t.Errorf("got %f allocs, want 0", n)
 	}
 }
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// TestTransformerShortBuffers tests that the precis.Transformer implements the
+// spirit, not just the letter (the method signatures), of the
+// transform.Transformer interface.
+//
+// In particular, it tests that, if one or both of the dst or src buffers are
+// short, so that multiple Transform calls are required to complete the overall
+// transformation, the end result is identical to one Transform call with
+// sufficiently long buffers.
+func TestTransformerShortBuffers(t *testing.T) {
+	srcUnit := []byte("a\u0300cce\u0301nts") // NFD normalization form.
+	wantUnit := []byte("àccénts")            // NFC normalization form.
+	src := bytes.Repeat(srcUnit, 16)
+	want := bytes.Repeat(wantUnit, 16)
+	const long = 4096
+	dst := make([]byte, long)
+
+	// 5, 7, 9, 11, 13, 16 and 17 are all pair-wise co-prime, which means that
+	// slicing the dst and src buffers into 5, 7, 13 and 17 byte chunks will
+	// fall at different places inside the repeated srcUnit's and wantUnit's.
+	if len(srcUnit) != 11 || len(wantUnit) != 9 || len(src) > long || len(want) > long {
+		t.Fatal("inconsistent lengths")
+	}
+
+	tr := NewFreeform().NewTransformer()
+	for _, deltaD := range []int{5, 7, 13, 17, long} {
+	loop:
+		for _, deltaS := range []int{5, 7, 13, 17, long} {
+			tr.Reset()
+			d0 := 0
+			s0 := 0
+			for {
+				d1 := min(len(dst), d0+deltaD)
+				s1 := min(len(src), s0+deltaS)
+				nDst, nSrc, err := tr.Transform(dst[d0:d1:d1], src[s0:s1:s1], s1 == len(src))
+				d0 += nDst
+				s0 += nSrc
+				if err == nil {
+					break
+				}
+				if err == transform.ErrShortDst || err == transform.ErrShortSrc {
+					continue
+				}
+				t.Errorf("deltaD=%d, deltaS=%d: %v", deltaD, deltaS, err)
+				continue loop
+			}
+			if s0 != len(src) {
+				t.Errorf("deltaD=%d, deltaS=%d: s0: got %d, want %d", deltaD, deltaS, s0, len(src))
+				continue
+			}
+			if d0 != len(want) {
+				t.Errorf("deltaD=%d, deltaS=%d: d0: got %d, want %d", deltaD, deltaS, d0, len(want))
+				continue
+			}
+			got := dst[:d0]
+			if !bytes.Equal(got, want) {
+				t.Errorf("deltaD=%d, deltaS=%d:\ngot  %q\nwant %q", deltaD, deltaS, got, want)
+				continue
+			}
+		}
+	}
+}