internal/gen/bitfield: added package to simplify bit fiddling

Change-Id: Ie54d050135a39f7fc3df419fcbdc042363b7e216
Reviewed-on: https://go-review.googlesource.com/96735
Run-TryBot: Marcel van Lohuizen <mpvl@golang.org>
Reviewed-by: Ross Light <light@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/internal/gen/bitfield/bitfield.go b/internal/gen/bitfield/bitfield.go
new file mode 100644
index 0000000..a8d0a48
--- /dev/null
+++ b/internal/gen/bitfield/bitfield.go
@@ -0,0 +1,226 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package bitfield converts annotated structs into integer values.
+//
+// Any field that is marked with a bitfield tag is compacted. The tag value has
+// two parts. The part before the comma determines the method name for a
+// generated type. If left blank the name of the field is used.
+// The part after the comma determines the number of bits to use for the
+// representation.
+package bitfield
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"reflect"
+	"strconv"
+	"strings"
+)
+
+// Config determines settings for packing and generation. If a Config is used,
+// the same Config should be used for packing and generation.
+type Config struct {
+	// NumBits fixes the maximum allowed bits for the integer representation.
+	// If NumBits is not 8, 16, 32, or 64, the actual underlying integer size
+	// will be the next largest available.
+	NumBits uint
+
+	// If Package is set, code generation will write a package clause.
+	Package string
+
+	// TypeName is the name for the generated type. By default it is the name
+	// of the type of the value passed to Gen.
+	TypeName string
+}
+
+var nullConfig = &Config{}
+
+// Pack packs annotated bit ranges of struct x in an integer.
+//
+// Only fields that have a "bitfield" tag are compacted.
+func Pack(x interface{}, c *Config) (packed uint64, err error) {
+	packed, _, err = pack(x, c)
+	return
+}
+
+func pack(x interface{}, c *Config) (packed uint64, nBit uint, err error) {
+	if c == nil {
+		c = nullConfig
+	}
+	nBits := c.NumBits
+	v := reflect.ValueOf(x)
+	v = reflect.Indirect(v)
+	t := v.Type()
+	pos := 64 - nBits
+	if nBits == 0 {
+		pos = 0
+	}
+	for i := 0; i < v.NumField(); i++ {
+		v := v.Field(i)
+		field := t.Field(i)
+		f, err := parseField(field)
+
+		if err != nil {
+			return 0, 0, err
+		}
+		if f.nBits == 0 {
+			continue
+		}
+		value := uint64(0)
+		switch v.Kind() {
+		case reflect.Bool:
+			if v.Bool() {
+				value = 1
+			}
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+			value = v.Uint()
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			x := v.Int()
+			if x < 0 {
+				return 0, 0, fmt.Errorf("bitfield: negative value for field %q not allowed", field.Name)
+			}
+			value = uint64(x)
+		}
+		if value > (1<<f.nBits)-1 {
+			return 0, 0, fmt.Errorf("bitfield: value %#x of field %q does not fit in %d bits", value, field.Name, f.nBits)
+		}
+		shift := 64 - pos - f.nBits
+		if pos += f.nBits; pos > 64 {
+			return 0, 0, fmt.Errorf("bitfield: no more bits left for field %q", field.Name)
+		}
+		packed |= value << shift
+	}
+	if nBits == 0 {
+		nBits = posToBits(pos)
+		packed >>= (64 - nBits)
+	}
+	return packed, nBits, nil
+}
+
+type field struct {
+	name  string
+	value uint64
+	nBits uint
+}
+
+// parseField parses a tag of the form [<name>][:<nBits>][,<pos>[..<end>]]
+func parseField(field reflect.StructField) (f field, err error) {
+	s, ok := field.Tag.Lookup("bitfield")
+	if !ok {
+		return f, nil
+	}
+	switch field.Type.Kind() {
+	case reflect.Bool:
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+	default:
+		return f, fmt.Errorf("bitfield: field %q is not an integer or bool type", field.Name)
+	}
+	bits := s
+	f.name = ""
+
+	if i := strings.IndexByte(s, ','); i >= 0 {
+		bits = s[:i]
+		f.name = s[i+1:]
+	}
+	if bits != "" {
+		nBits, err := strconv.ParseUint(bits, 10, 8)
+		if err != nil {
+			return f, fmt.Errorf("bitfield: invalid bit size for field %q: %v", field.Name, err)
+		}
+		f.nBits = uint(nBits)
+	}
+	if f.nBits == 0 {
+		if field.Type.Kind() == reflect.Bool {
+			f.nBits = 1
+		} else {
+			f.nBits = uint(field.Type.Bits())
+		}
+	}
+	if f.name == "" {
+		f.name = field.Name
+	}
+	return f, err
+}
+
+func posToBits(pos uint) (bits uint) {
+	switch {
+	case pos <= 8:
+		bits = 8
+	case pos <= 16:
+		bits = 16
+	case pos <= 32:
+		bits = 32
+	case pos <= 64:
+		bits = 64
+	default:
+		panic("unreachable")
+	}
+	return bits
+}
+
+// Gen generates code for unpacking integers created with Pack.
+func Gen(w io.Writer, x interface{}, c *Config) error {
+	if c == nil {
+		c = nullConfig
+	}
+	_, nBits, err := pack(x, c)
+	if err != nil {
+		return err
+	}
+
+	t := reflect.TypeOf(x)
+	if t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	if c.TypeName == "" {
+		c.TypeName = t.Name()
+	}
+	firstChar := []rune(c.TypeName)[0]
+
+	buf := &bytes.Buffer{}
+
+	print := func(w io.Writer, format string, args ...interface{}) {
+		if _, e := fmt.Fprintf(w, format+"\n", args...); e != nil && err == nil {
+			err = fmt.Errorf("bitfield: write failed: %v", err)
+		}
+	}
+
+	pos := uint(0)
+	for i := 0; i < t.NumField(); i++ {
+		field := t.Field(i)
+		f, _ := parseField(field)
+		if f.nBits == 0 {
+			continue
+		}
+		shift := nBits - pos - f.nBits
+		pos += f.nBits
+
+		retType := field.Type.Name()
+		print(buf, "\nfunc (%c %s) %s() %s {", firstChar, c.TypeName, f.name, retType)
+		if field.Type.Kind() == reflect.Bool {
+			print(buf, "\tconst bit = 1 << %d", shift)
+			print(buf, "\treturn %c&bit == bit", firstChar)
+		} else {
+			print(buf, "\treturn %s((%c >> %d) & %#x)", retType, firstChar, shift, (1<<f.nBits)-1)
+		}
+		print(buf, "}")
+	}
+
+	if c.Package != "" {
+		print(w, "// Code generated by golang.org/x/text/internal/gen/bitfield. DO NOT EDIT.\n")
+		print(w, "package %s\n", c.Package)
+	}
+
+	bits := posToBits(pos)
+
+	print(w, "type %s uint%d", c.TypeName, bits)
+
+	if _, err := io.Copy(w, buf); err != nil {
+		return fmt.Errorf("bitfield: write failed: %v", err)
+	}
+	return nil
+}
diff --git a/internal/gen/bitfield/bitfield_test.go b/internal/gen/bitfield/bitfield_test.go
new file mode 100644
index 0000000..789f86d
--- /dev/null
+++ b/internal/gen/bitfield/bitfield_test.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bitfield
+
+import (
+	"bytes"
+	"fmt"
+	"io/ioutil"
+	"testing"
+)
+
+type myUint8 uint8
+
+type test1 struct { // 28 bits
+	foo  uint16 `bitfield:",fob"`
+	Bar  int8   `bitfield:"5,baz"`
+	Foo  uint64
+	bar  myUint8 `bitfield:"3"`
+	Bool bool    `bitfield:""`
+	Baz  int8    `bitfield:"3"`
+}
+
+type test2 struct {
+	larger1 uint16 `bitfield:"32"`
+	larger2 uint16 `bitfield:"32"`
+}
+
+type tooManyBits struct {
+	u1 uint16 `bitfield:"12"`
+	u2 uint16 `bitfield:"12"`
+	u3 uint16 `bitfield:"12"`
+	u4 uint16 `bitfield:"12"`
+	u5 uint16 `bitfield:"12"`
+	u6 uint16 `bitfield:"12"`
+}
+
+type just64 struct {
+	foo uint64 `bitfield:""`
+}
+
+type toUint8 struct {
+	foo bool `bitfield:""`
+}
+
+type toUint16 struct {
+	foo int `bitfield:"9"`
+}
+
+type faultySize struct {
+	foo uint64 `bitfield:"a"`
+}
+
+type faultyType struct {
+	foo *int `bitfield:"5"`
+}
+
+var (
+	maxed = test1{
+		foo:  0xffff,
+		Bar:  0x1f,
+		Foo:  0xffff,
+		bar:  0x7,
+		Bool: true,
+		Baz:  0x7,
+	}
+	alternate1 = test1{
+		foo: 0xffff,
+		bar: 0x7,
+		Baz: 0x7,
+	}
+	alternate2 = test1{
+		Bar:  0x1f,
+		Bool: true,
+	}
+	overflow = test1{
+		Bar: 0x3f,
+	}
+	negative = test1{
+		Bar: -1,
+	}
+)
+
+func TestPack(t *testing.T) {
+	testCases := []struct {
+		desc  string
+		x     interface{}
+		nBits uint
+		out   uint64
+		ok    bool
+	}{
+		{"maxed out fields", maxed, 0, 0xfffffff0, true},
+		{"maxed using less bits", maxed, 28, 0x0fffffff, true},
+
+		{"alternate1", alternate1, 0, 0xffff0770, true},
+		{"alternate2", alternate2, 0, 0x0000f880, true},
+
+		{"just64", &just64{0x0f0f0f0f}, 00, 0xf0f0f0f, true},
+		{"just64", &just64{0x0f0f0f0f}, 64, 0xf0f0f0f, true},
+		{"just64", &just64{0xffffFFFF}, 64, 0xffffffff, true},
+		{"to uint8", &toUint8{true}, 0, 0x80, true},
+		{"to uint16", &toUint16{1}, 0, 0x0080, true},
+		// errors
+		{"overflow", overflow, 0, 0, false},
+		{"too many bits", &tooManyBits{}, 0, 0, false},
+		{"fault size", &faultySize{}, 0, 0, false},
+		{"fault type", &faultyType{}, 0, 0, false},
+		{"negative", negative, 0, 0, false},
+		{"not enough bits", maxed, 27, 0, false},
+	}
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("%T/%s", tc.x, tc.desc), func(t *testing.T) {
+			v, err := Pack(tc.x, &Config{NumBits: tc.nBits})
+			if ok := err == nil; v != tc.out || ok != tc.ok {
+				t.Errorf("got %#x, %v; want %#x, %v (%v)", v, ok, tc.out, tc.ok, err)
+			}
+		})
+	}
+}
+
+func TestRoundtrip(t *testing.T) {
+	testCases := []struct {
+		x test1
+	}{
+		{maxed},
+		{alternate1},
+		{alternate2},
+	}
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			v, err := Pack(tc.x, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			want := tc.x
+			want.Foo = 0 // not stored
+			x := myInt(v)
+			got := test1{
+				foo:  x.fob(),
+				Bar:  x.baz(),
+				bar:  x.bar(),
+				Bool: x.Bool(),
+				Baz:  x.Baz(),
+			}
+			if got != want {
+				t.Errorf("\ngot  %#v\nwant %#v (%#x)", got, want, v)
+			}
+		})
+	}
+}
+
+func TestGen(t *testing.T) {
+	testCases := []struct {
+		desc   string
+		x      interface{}
+		config *Config
+		ok     bool
+		out    string
+	}{{
+		desc: "test1",
+		x:    &test1{},
+		ok:   true,
+		out:  test1Gen,
+	}, {
+		desc:   "test1 with options",
+		x:      &test1{},
+		config: &Config{Package: "bitfield", TypeName: "myInt"},
+		ok:     true,
+		out:    mustRead("gen1_test.go"),
+	}, {
+		desc:   "test1 with alternative bits",
+		x:      &test1{},
+		config: &Config{NumBits: 28, Package: "bitfield", TypeName: "myInt2"},
+		ok:     true,
+		out:    mustRead("gen2_test.go"),
+	}, {
+		desc:   "failure",
+		x:      &test1{},
+		config: &Config{NumBits: 27}, // Too few bits.
+		ok:     false,
+		out:    "",
+	}}
+
+	for _, tc := range testCases {
+		t.Run(tc.desc, func(t *testing.T) {
+			w := &bytes.Buffer{}
+			err := Gen(w, tc.x, tc.config)
+			if ok := err == nil; ok != tc.ok {
+				t.Fatalf("got %v; want %v (%v)", ok, tc.ok, err)
+			}
+			got := w.String()
+			if got != tc.out {
+				t.Errorf("got:\n%s\nwant:\n%s", got, tc.out)
+			}
+		})
+	}
+}
+
+const test1Gen = `type test1 uint32
+
+func (t test1) fob() uint16 {
+	return uint16((t >> 16) & 0xffff)
+}
+
+func (t test1) baz() int8 {
+	return int8((t >> 11) & 0x1f)
+}
+
+func (t test1) bar() myUint8 {
+	return myUint8((t >> 8) & 0x7)
+}
+
+func (t test1) Bool() bool {
+	const bit = 1 << 7
+	return t&bit == bit
+}
+
+func (t test1) Baz() int8 {
+	return int8((t >> 4) & 0x7)
+}
+`
+
+func mustRead(filename string) string {
+	b, err := ioutil.ReadFile(filename)
+	if err != nil {
+		panic(err)
+	}
+	return string(b)
+}
diff --git a/internal/gen/bitfield/gen1_test.go b/internal/gen/bitfield/gen1_test.go
new file mode 100644
index 0000000..2844b9d
--- /dev/null
+++ b/internal/gen/bitfield/gen1_test.go
@@ -0,0 +1,26 @@
+// Code generated by golang.org/x/text/internal/gen/bitfield. DO NOT EDIT.
+
+package bitfield
+
+type myInt uint32
+
+func (m myInt) fob() uint16 {
+	return uint16((m >> 16) & 0xffff)
+}
+
+func (m myInt) baz() int8 {
+	return int8((m >> 11) & 0x1f)
+}
+
+func (m myInt) bar() myUint8 {
+	return myUint8((m >> 8) & 0x7)
+}
+
+func (m myInt) Bool() bool {
+	const bit = 1 << 7
+	return m&bit == bit
+}
+
+func (m myInt) Baz() int8 {
+	return int8((m >> 4) & 0x7)
+}
diff --git a/internal/gen/bitfield/gen2_test.go b/internal/gen/bitfield/gen2_test.go
new file mode 100644
index 0000000..ad5f72a
--- /dev/null
+++ b/internal/gen/bitfield/gen2_test.go
@@ -0,0 +1,26 @@
+// Code generated by golang.org/x/text/internal/gen/bitfield. DO NOT EDIT.
+
+package bitfield
+
+type myInt2 uint32
+
+func (m myInt2) fob() uint16 {
+	return uint16((m >> 12) & 0xffff)
+}
+
+func (m myInt2) baz() int8 {
+	return int8((m >> 7) & 0x1f)
+}
+
+func (m myInt2) bar() myUint8 {
+	return myUint8((m >> 4) & 0x7)
+}
+
+func (m myInt2) Bool() bool {
+	const bit = 1 << 3
+	return m&bit == bit
+}
+
+func (m myInt2) Baz() int8 {
+	return int8((m >> 0) & 0x7)
+}