math/big: permit internal nat.scan to accept decimal point

This will simplify parsing of rational and (eventually) floating point numbers.

Also streamlined the inner loop. As a result, scan runs faster for all but
short (<= 10 digit) numbers. For short numbers it is < 10% slower (the cause
is known and could be addressed in a future CL).

Minor unrelated cleanups. Added additional tests.

benchmark                     old ns/op     new ns/op     delta
BenchmarkScanPi               134465        125122        -6.95%
BenchmarkScan10Base2          493           540           +9.53%
BenchmarkScan100Base2         3608          3244          -10.09%
BenchmarkScan1000Base2        35376         32377         -8.48%
BenchmarkScan10000Base2       481504        450028        -6.54%
BenchmarkScan100000Base2      17936774      17648080      -1.61%
BenchmarkScan10Base8          258           280           +8.53%
BenchmarkScan100Base8         1389          1323          -4.75%
BenchmarkScan1000Base8        14221         13036         -8.33%
BenchmarkScan10000Base8       271298        258993        -4.54%
BenchmarkScan100000Base8      15715465      15672580      -0.27%
BenchmarkScan10Base10         249           268           +7.63%
BenchmarkScan100Base10        1324          1220          -7.85%
BenchmarkScan1000Base10       13398         12234         -8.69%
BenchmarkScan10000Base10      259157        249342        -3.79%
BenchmarkScan100000Base10     15670613      15582409      -0.56%
BenchmarkScan10Base16         231           251           +8.66%
BenchmarkScan100Base16        1093          1065          -2.56%
BenchmarkScan1000Base16       12687         12196         -3.87%
BenchmarkScan10000Base16      282349        271443        -3.86%
BenchmarkScan100000Base16     16742669      16552917      -1.13%

Change-Id: I4b9b078792788aef872b307399f00ffd34903127
Reviewed-on: https://go-review.googlesource.com/2960
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/src/math/big/int.go b/src/math/big/int.go
index 62a07d6..98f0b24 100644
--- a/src/math/big/int.go
+++ b/src/math/big/int.go
@@ -477,7 +477,7 @@
 	}
 
 	// determine mantissa
-	z.abs, base, err = z.abs.scan(r, base)
+	z.abs, base, _, err = z.abs.scan(r, base)
 	if err != nil {
 		return nil, base, err
 	}
diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go
index a698e2d..d373e84 100644
--- a/src/math/big/int_test.go
+++ b/src/math/big/int_test.go
@@ -621,7 +621,7 @@
 	}
 }
 
-var bitTests = []nat{
+var bitsTests = []nat{
 	nil,
 	{0},
 	{1},
@@ -639,7 +639,7 @@
 }
 
 func TestBits(t *testing.T) {
-	for _, test := range bitTests {
+	for _, test := range bitsTests {
 		var z Int
 		z.neg = true
 		got := z.SetBits(test)
diff --git a/src/math/big/nat.go b/src/math/big/nat.go
index bf00b88..4d65c5f 100644
--- a/src/math/big/nat.go
+++ b/src/math/big/nat.go
@@ -68,7 +68,7 @@
 
 func (z nat) make(n int) nat {
 	if n <= cap(z) {
-		return z[0:n] // reuse z
+		return z[:n] // reuse z
 	}
 	// Choosing a good value for e has significant performance impact
 	// because it increases the chance that a value can be reused.
@@ -78,7 +78,7 @@
 
 func (z nat) setWord(x Word) nat {
 	if x == 0 {
-		return z.make(0)
+		return z[:0]
 	}
 	z = z.make(1)
 	z[0] = x
@@ -122,7 +122,7 @@
 		return z.add(y, x)
 	case m == 0:
 		// n == 0 because m >= n; result is 0
-		return z.make(0)
+		return z[:0]
 	case n == 0:
 		// result is x
 		return z.set(x)
@@ -148,7 +148,7 @@
 		panic("underflow")
 	case m == 0:
 		// n == 0 because m >= n; result is 0
-		return z.make(0)
+		return z[:0]
 	case n == 0:
 		// result is x
 		return z.set(x)
@@ -384,7 +384,7 @@
 	case m < n:
 		return z.mul(y, x)
 	case m == 0 || n == 0:
-		return z.make(0)
+		return z[:0]
 	case n == 1:
 		return z.mulAddWW(x, y[0], 0)
 	}
@@ -488,7 +488,7 @@
 		q = z.set(x) // result is x
 		return
 	case m == 0:
-		q = z.make(0) // result is 0
+		q = z[:0] // result is 0
 		return
 	}
 	// m > 0
@@ -504,7 +504,7 @@
 	}
 
 	if u.cmp(v) < 0 {
-		q = z.make(0)
+		q = z[:0]
 		r = z2.set(u)
 		return
 	}
@@ -543,7 +543,7 @@
 		u = nil // u is an alias for uIn or v - cannot reuse
 	}
 	u = u.make(len(uIn) + 1)
-	u.clear()
+	u.clear() // TODO(gri) no need to clear if we allocated a new u
 
 	// D1.
 	shift := leadingZeros(v[n-1])
@@ -607,51 +607,82 @@
 }
 
 // MaxBase is the largest number base accepted for string conversions.
-const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1
+const MaxBase = 'z' - 'a' + 10 + 1
 
-func hexValue(ch rune) Word {
-	d := int(MaxBase + 1) // invalid base
-	switch {
-	case '0' <= ch && ch <= '9':
-		d = int(ch - '0')
-	case 'a' <= ch && ch <= 'z':
-		d = int(ch - 'a' + 10)
-	case 'A' <= ch && ch <= 'Z':
-		d = int(ch - 'A' + 10)
+// maxPow returns (b**n, n) such that b**n is the largest power of b with
+// b**n <= _M; b must be >= 2. For instance maxPow(10) == (1e19, 19) for 19
+// decimal digits in a 64bit Word: at most n digits in base b fit into a Word.
+// TODO(gri) replace this with a table, generated at build time.
+func maxPow(b Word) (p Word, n int) {
+	p, n = b, 1 // assuming b <= _M
+	for max := _M / b; p <= max; {
+		// p == b**n && p <= max
+		p *= b
+		n++
 	}
-	return Word(d)
+	// p == b**n && p <= _M
+	return
 }
 
-// scan sets z to the natural number corresponding to the longest possible prefix
-// read from r representing an unsigned integer in a given conversion base.
-// It returns z, the actual conversion base used, and an error, if any. In the
-// error case, the value of z is undefined. The syntax follows the syntax of
-// unsigned integer literals in Go.
+// pow returns x**n for n > 0, and 1 otherwise.
+func pow(x Word, n int) (p Word) {
+	p = 1
+	for n > 0 {
+		if n&1 != 0 {
+			p *= x
+		}
+		x *= x
+		n >>= 1
+	}
+	return
+}
+
+// scan scans the number corresponding to the longest possible prefix
+// from r representing an unsigned number in a given conversion base.
+// It returns the corresponding natural number res, the actual base b,
+// a digit count, and an error err, if any.
 //
-// The base argument must be 0 or a value from 2 through MaxBase. If the base
-// is 0, the string prefix determines the actual conversion base. A prefix of
-// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
-// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
+//	number = [ prefix ] digits | digits "." [ digits ] | "." digits .
+//	prefix = "0" [ "x" | "X" | "b" | "B" ] .
+//	digits = digit { digit } .
+//	digit  = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
 //
-func (z nat) scan(r io.RuneScanner, base int) (nat, int, error) {
-	// reject invalid bases
-	if base < 0 || base == 1 || MaxBase < base {
-		return z, 0, errors.New("invalid number base")
+// The base argument must be a value between 0 and MaxBase (inclusive).
+// For base 0, the number prefix determines the actual base: A prefix of
+// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and
+// a ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base
+// is 10 and no prefix is permitted.
+//
+// Base argument 1 selects actual base 10 but also enables scanning a number
+// with a decimal point.
+//
+// A result digit count > 0 corresponds to the number of (non-prefix) digits
+// parsed. A digit count <= 0 indicates the presence of a decimal point (for
+// base == 1, only), and the number of fractional digits is -count. In this
+// case, the value of the scanned number is res * 10**count.
+//
+func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error) {
+	// reject illegal bases
+	if base < 0 || base > MaxBase {
+		err = errors.New("illegal number base")
+		return
 	}
 
 	// one char look-ahead
 	ch, _, err := r.ReadRune()
 	if err != nil {
-		return z, 0, err
+		return
 	}
 
-	// determine base if necessary
-	b := Word(base)
-	if base == 0 {
+	// determine actual base
+	switch base {
+	case 0:
+		// actual base is 10 unless there's a base prefix
 		b = 10
 		if ch == '0' {
 			switch ch, _, err = r.ReadRune(); err {
 			case nil:
+				// possibly one of 0x, 0X, 0b, 0B
 				b = 8
 				switch ch {
 				case 'x', 'X':
@@ -661,62 +692,120 @@
 				}
 				if b == 2 || b == 16 {
 					if ch, _, err = r.ReadRune(); err != nil {
-						return z, 0, err
+						// io.EOF is also an error in this case
+						return
 					}
 				}
 			case io.EOF:
-				return z.make(0), 10, nil
+				// input is "0"
+				res = z[:0]
+				count = 1
+				err = nil
+				return
 			default:
-				return z, 10, err
+				// read error
+				return
 			}
 		}
+	case 1:
+		// actual base is 10 and decimal point is permitted
+		b = 10
+	default:
+		b = base
 	}
 
 	// convert string
-	// - group as many digits d as possible together into a "super-digit" dd with "super-base" bb
-	// - only when bb does not fit into a word anymore, do a full number mulAddWW using bb and dd
-	z = z.make(0)
-	bb := Word(1)
-	dd := Word(0)
-	for max := _M / b; ; {
-		d := hexValue(ch)
-		if d >= b {
+	// Algorithm: Collect digits in groups of at most n digits in di
+	// and then use mulAddWW for every such group to add them to the
+	// result.
+	z = z[:0]
+	b1 := Word(b)
+	bn, n := maxPow(b1) // at most n digits in base b1 fit into Word
+	di := Word(0)       // 0 <= di < b1**i < bn
+	i := 0              // 0 <= i < n
+	dp := -1            // position of decimal point
+	for {
+		if base == 1 && ch == '.' {
+			base = 10 // no 2nd decimal point permitted
+			dp = count
+			// advance
+			if ch, _, err = r.ReadRune(); err != nil {
+				if err == io.EOF {
+					err = nil
+					break
+				}
+				return
+			}
+		}
+
+		// convert rune into digit value d1
+		var d1 Word
+		switch {
+		case '0' <= ch && ch <= '9':
+			d1 = Word(ch - '0')
+		case 'a' <= ch && ch <= 'z':
+			d1 = Word(ch - 'a' + 10)
+		case 'A' <= ch && ch <= 'Z':
+			d1 = Word(ch - 'A' + 10)
+		default:
+			d1 = MaxBase + 1
+		}
+		if d1 >= b1 {
 			r.UnreadRune() // ch does not belong to number anymore
 			break
 		}
+		count++
 
-		if bb <= max {
-			bb *= b
-			dd = dd*b + d
-		} else {
-			// bb * b would overflow
-			z = z.mulAddWW(z, bb, dd)
-			bb = b
-			dd = d
+		// collect d1 in di
+		di = di*b1 + d1
+		i++
+
+		// if di is "full", add it to the result
+		if i == n {
+			z = z.mulAddWW(z, bn, di)
+			di = 0
+			i = 0
 		}
 
+		// advance
 		if ch, _, err = r.ReadRune(); err != nil {
-			if err != io.EOF {
-				return z, int(b), err
+			if err == io.EOF {
+				err = nil
+				break
 			}
-			break
+			return
 		}
 	}
 
-	switch {
-	case bb > 1:
-		// there was at least one mantissa digit
-		z = z.mulAddWW(z, bb, dd)
-	case base == 0 && b == 8:
-		// there was only the octal prefix 0 (possibly followed by digits > 7);
-		// return base 10, not 8
-		return z, 10, nil
-	case base != 0 || b != 8:
-		// there was neither a mantissa digit nor the octal prefix 0
-		return z, int(b), errors.New("syntax error scanning number")
+	if count == 0 {
+		// no digits found
+		switch {
+		case base == 0 && b == 8:
+			// there was only the octal prefix 0 (possibly followed by digits > 7);
+			// count as one digit and return base 10, not 8
+			count = 1
+			b = 10
+		case base != 0 || b != 8:
+			// there was neither a mantissa digit nor the octal prefix 0
+			err = errors.New("syntax error scanning number")
+		}
+		return
+	}
+	// count > 0
+
+	// add remaining digits to result
+	if i > 0 {
+		z = z.mulAddWW(z, pow(b1, i), di)
+	}
+	res = z.norm()
+
+	// adjust for fraction, if any
+	if dp >= 0 {
+		// 0 <= dp <= count > 0
+		count = dp - count
 	}
 
-	return z.norm(), int(b), nil
+	return
 }
 
 // Character sets for string conversion.
@@ -799,14 +888,7 @@
 		}
 
 	} else {
-		// determine "big base"; i.e., the largest possible value bb
-		// that is a power of base b and still fits into a Word
-		// (as in 10^19 for 19 decimal digits in a 64bit Word)
-		bb := b      // big base is b**ndigits
-		ndigits := 1 // number of base b digits
-		for max := Word(_M / b); bb <= max; bb *= b {
-			ndigits++ // maximize ndigits where bb = b**ndigits, bb <= _M
-		}
+		bb, ndigits := maxPow(Word(b))
 
 		// construct table of successive squares of bb*leafSize to use in subdivisions
 		// result (table != nil) <=> (len(x) > leafSize > 0)
@@ -1047,7 +1129,7 @@
 func (z nat) shl(x nat, s uint) nat {
 	m := len(x)
 	if m == 0 {
-		return z.make(0)
+		return z[:0]
 	}
 	// m > 0
 
@@ -1064,7 +1146,7 @@
 	m := len(x)
 	n := m - int(s/_W)
 	if n <= 0 {
-		return z.make(0)
+		return z[:0]
 	}
 	// n > 0
 
@@ -1103,12 +1185,14 @@
 	panic("set bit is not 0 or 1")
 }
 
-func (z nat) bit(i uint) uint {
-	j := int(i / _W)
-	if j >= len(z) {
+// bit returns the value of the i'th bit, with lsb == bit 0.
+func (x nat) bit(i uint) uint {
+	j := i / _W
+	if j >= uint(len(x)) {
 		return 0
 	}
-	return uint(z[j] >> (i % _W) & 1)
+	// 0 <= j < len(x)
+	return uint(x[j] >> (i % _W) & 1)
 }
 
 func (z nat) and(x, y nat) nat {
diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go
index acd265b..d24ce60 100644
--- a/src/math/big/nat_test.go
+++ b/src/math/big/nat_test.go
@@ -88,7 +88,7 @@
 }
 
 func natFromString(s string) nat {
-	x, _, err := nat(nil).scan(strings.NewReader(s), 0)
+	x, _, _, err := nat(nil).scan(strings.NewReader(s), 0)
 	if err != nil {
 		panic(err)
 	}
@@ -271,7 +271,7 @@
 			t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
 		}
 
-		x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
+		x, b, _, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
 		if x.cmp(a.x) != 0 {
 			t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
 		}
@@ -285,16 +285,16 @@
 }
 
 var natScanTests = []struct {
-	s    string // string to be scanned
-	base int    // input base
-	x    nat    // expected nat
-	b    int    // expected base
-	ok   bool   // expected success
-	next rune   // next character (or 0, if at EOF)
+	s     string // string to be scanned
+	base  int    // input base
+	x     nat    // expected nat
+	b     int    // expected base
+	count int    // expected digit count
+	ok    bool   // expected success
+	next  rune   // next character (or 0, if at EOF)
 }{
 	// error: illegal base
 	{base: -1},
-	{base: 1},
 	{base: 37},
 
 	// error: no mantissa
@@ -306,31 +306,46 @@
 	{s: "0x"},
 	{s: "345", base: 2},
 
+	// error: incorrect use of decimal point
+	{s: ".0"},
+	{s: ".0", base: 10},
+	{s: ".", base: 1},
+	{s: "0x.0"},
+
 	// no errors
-	{"0", 0, nil, 10, true, 0},
-	{"0", 10, nil, 10, true, 0},
-	{"0", 36, nil, 36, true, 0},
-	{"1", 0, nat{1}, 10, true, 0},
-	{"1", 10, nat{1}, 10, true, 0},
-	{"0 ", 0, nil, 10, true, ' '},
-	{"08", 0, nil, 10, true, '8'},
-	{"018", 0, nat{1}, 8, true, '8'},
-	{"0b1", 0, nat{1}, 2, true, 0},
-	{"0b11000101", 0, nat{0xc5}, 2, true, 0},
-	{"03271", 0, nat{03271}, 8, true, 0},
-	{"10ab", 0, nat{10}, 10, true, 'a'},
-	{"1234567890", 0, nat{1234567890}, 10, true, 0},
-	{"xyz", 36, nat{(33*36+34)*36 + 35}, 36, true, 0},
-	{"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, true, '?'},
-	{"0x", 16, nil, 16, true, 'x'},
-	{"0xdeadbeef", 0, nat{0xdeadbeef}, 16, true, 0},
-	{"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, true, 0},
+	{"0", 0, nil, 10, 1, true, 0},
+	{"0", 10, nil, 10, 1, true, 0},
+	{"0", 36, nil, 36, 1, true, 0},
+	{"1", 0, nat{1}, 10, 1, true, 0},
+	{"1", 10, nat{1}, 10, 1, true, 0},
+	{"0 ", 0, nil, 10, 1, true, ' '},
+	{"08", 0, nil, 10, 1, true, '8'},
+	{"08", 10, nat{8}, 10, 2, true, 0},
+	{"018", 0, nat{1}, 8, 1, true, '8'},
+	{"0b1", 0, nat{1}, 2, 1, true, 0},
+	{"0b11000101", 0, nat{0xc5}, 2, 8, true, 0},
+	{"03271", 0, nat{03271}, 8, 4, true, 0},
+	{"10ab", 0, nat{10}, 10, 2, true, 'a'},
+	{"1234567890", 0, nat{1234567890}, 10, 10, true, 0},
+	{"xyz", 36, nat{(33*36+34)*36 + 35}, 36, 3, true, 0},
+	{"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, 3, true, '?'},
+	{"0x", 16, nil, 16, 1, true, 'x'},
+	{"0xdeadbeef", 0, nat{0xdeadbeef}, 16, 8, true, 0},
+	{"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, 8, true, 0},
+
+	// no errors, decimal point
+	{"0.", 0, nil, 10, 1, true, '.'},
+	{"0.", 1, nil, 10, 0, true, 0},
+	{"0.1.2", 1, nat{1}, 10, -1, true, '.'},
+	{".000", 1, nil, 10, -3, true, 0},
+	{"12.3", 1, nat{123}, 10, -1, true, 0},
+	{"012.345", 1, nat{12345}, 10, -3, true, 0},
 }
 
 func TestScanBase(t *testing.T) {
 	for _, a := range natScanTests {
 		r := strings.NewReader(a.s)
-		x, b, err := nat(nil).scan(r, a.base)
+		x, b, count, err := nat(nil).scan(r, a.base)
 		if err == nil && !a.ok {
 			t.Errorf("scan%+v\n\texpected error", a)
 		}
@@ -346,6 +361,9 @@
 		if b != a.b {
 			t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.base)
 		}
+		if count != a.count {
+			t.Errorf("scan%+v\n\tgot count = %d; want %d", a, count, a.count)
+		}
 		next, _, err := r.ReadRune()
 		if err == io.EOF {
 			next = 0
@@ -413,7 +431,7 @@
 // Test case for BenchmarkScanPi.
 func TestScanPi(t *testing.T) {
 	var x nat
-	z, _, err := x.scan(strings.NewReader(pi), 10)
+	z, _, _, err := x.scan(strings.NewReader(pi), 10)
 	if err != nil {
 		t.Errorf("scanning pi: %s", err)
 	}
@@ -445,7 +463,7 @@
 
 func BenchmarkStringPiParallel(b *testing.B) {
 	var x nat
-	x, _, _ = x.scan(strings.NewReader(pi), 0)
+	x, _, _, _ = x.scan(strings.NewReader(pi), 0)
 	if x.decimalString() != pi {
 		panic("benchmark incorrect: conversion failed")
 	}
@@ -757,14 +775,13 @@
 
 func TestExpNN(t *testing.T) {
 	for i, test := range expNNTests {
-		x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
-		y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
-		out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
+		x := natFromString(test.x)
+		y := natFromString(test.y)
+		out := natFromString(test.out)
 
 		var m nat
-
 		if len(test.m) > 0 {
-			m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
+			m = natFromString(test.m)
 		}
 
 		z := nat(nil).expNN(x, y, m)
@@ -843,3 +860,37 @@
 		fibo(1e5)
 	}
 }
+
+var bitTests = []struct {
+	x    string
+	i    uint
+	want uint
+}{
+	{"0", 0, 0},
+	{"0", 1, 0},
+	{"0", 1000, 0},
+
+	{"0x1", 0, 1},
+	{"0x10", 0, 0},
+	{"0x10", 3, 0},
+	{"0x10", 4, 1},
+	{"0x10", 5, 0},
+
+	{"0x8000000000000000", 62, 0},
+	{"0x8000000000000000", 63, 1},
+	{"0x8000000000000000", 64, 0},
+
+	{"0x3" + strings.Repeat("0", 32), 127, 0},
+	{"0x3" + strings.Repeat("0", 32), 128, 1},
+	{"0x3" + strings.Repeat("0", 32), 129, 1},
+	{"0x3" + strings.Repeat("0", 32), 130, 0},
+}
+
+func TestBit(t *testing.T) {
+	for i, test := range bitTests {
+		x := natFromString(test.x)
+		if got := x.bit(test.i); got != test.want {
+			t.Errorf("#%d: %s.bit(%d) = %v; want %v", i, test.x, test.i, got, test.want)
+		}
+	}
+}
diff --git a/src/math/big/rat.go b/src/math/big/rat.go
index c5339fe..d5d0470 100644
--- a/src/math/big/rat.go
+++ b/src/math/big/rat.go
@@ -549,7 +549,7 @@
 		}
 		s = s[sep+1:]
 		var err error
-		if z.b.abs, _, err = z.b.abs.scan(strings.NewReader(s), 10); err != nil {
+		if z.b.abs, _, _, err = z.b.abs.scan(strings.NewReader(s), 10); err != nil {
 			return nil, false
 		}
 		if len(z.b.abs) == 0 {