// Copyright 2015 The Go Authors.  All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package binres

import (
	"fmt"
	"unicode/utf16"
)

const (
	SortedFlag uint32 = 1 << 0
	UTF8Flag          = 1 << 8
)

// PoolRef is the i'th string in a pool.
type PoolRef uint32

// Resolve returns the string entry of PoolRef in pl.
func (ref PoolRef) Resolve(pl *Pool) string {
	return pl.strings[ref]
}

// Pool is a container for string and style span collections.
//
// Pool has the following structure marshalled:
//
//		chunkHeader
//	 uint32 number of strings in this pool
//	 uint32 number of style spans in pool
//	 uint32 SortedFlag, UTF8Flag
//	 uint32 index of string data from header
//	 uint32 index of style data from header
//	 []uint32 string indices starting at zero
//	 []uint16 or []uint8 concatenation of string entries
//
// UTF-16 entries are as follows:
//
//	uint16 string length, exclusive
//	uint16 [optional] low word if high bit of length set
//	[n]byte data
//	uint16 0x0000 terminator
//
// UTF-8 entries are as follows:
//
//	uint8 character length, exclusive
//	uint8 [optional] low word if high bit of character length set
//	uint8 byte length, exclusive
//	uint8 [optional] low word if high bit of byte length set
//	[n]byte data
//	uint8 0x00 terminator
type Pool struct {
	chunkHeader
	strings []string
	styles  []*Span
	flags   uint32 // SortedFlag, UTF8Flag
}

// ref returns the PoolRef of s, inserting s if it doesn't exist.
func (pl *Pool) ref(s string) PoolRef {
	for i, x := range pl.strings {
		if s == x {
			return PoolRef(i)
		}
	}
	pl.strings = append(pl.strings, s)
	return PoolRef(len(pl.strings) - 1)
}

// RefByName returns the PoolRef of s, or error if not exists.
func (pl *Pool) RefByName(s string) (PoolRef, error) {
	for i, x := range pl.strings {
		if s == x {
			return PoolRef(i), nil
		}
	}
	return 0, fmt.Errorf("PoolRef by name %q does not exist", s)
}

func (pl *Pool) IsSorted() bool { return pl.flags&SortedFlag == SortedFlag }
func (pl *Pool) IsUTF8() bool   { return pl.flags&UTF8Flag == UTF8Flag }

func (pl *Pool) UnmarshalBinary(bin []byte) error {
	if err := (&pl.chunkHeader).UnmarshalBinary(bin); err != nil {
		return err
	}
	if pl.typ != ResStringPool {
		return fmt.Errorf("have type %s, want %s", pl.typ, ResStringPool)
	}

	pl.strings = make([]string, btou32(bin[8:]))
	pl.styles = make([]*Span, btou32(bin[12:]))
	pl.flags = btou32(bin[16:])

	offstrings := btou32(bin[20:])
	offstyle := btou32(bin[24:])
	hdrlen := 28

	if pl.IsUTF8() {
		for i := range pl.strings {
			offset := int(offstrings + btou32(bin[hdrlen+i*4:]))

			// if leading bit set for nchars and nbytes,
			// treat first byte as 7-bit high word and next as low word.
			nchars := int(bin[offset])
			offset++
			if nchars&(1<<7) != 0 {
				n0 := nchars ^ (1 << 7) // high word
				n1 := int(bin[offset])  // low word
				nchars = n0*(1<<8) + n1
				offset++
			}

			// TODO(d) At least one case in android.jar (api-10) resource table has only
			// highbit set, making 7-bit highword zero. The reason for this is currently
			// unknown but would make tests that unmarshal-marshal to match bytes impossible.
			// The values noted were high-word: 0 (after highbit unset), low-word: 141
			// The size may be treated as an int8 triggering the use of two bytes to store size
			// even though the implementation uses uint8.
			nbytes := int(bin[offset])
			offset++
			if nbytes&(1<<7) != 0 {
				n0 := nbytes ^ (1 << 7) // high word
				n1 := int(bin[offset])  // low word
				nbytes = n0*(1<<8) + n1
				offset++
			}

			data := bin[offset : offset+nbytes]
			if x := uint8(bin[offset+nbytes]); x != 0 {
				return fmt.Errorf("expected zero terminator, got 0x%02X for nchars=%v nbytes=%v data=%q", x, nchars, nbytes, data)
			}
			pl.strings[i] = string(data)
		}
	} else {
		for i := range pl.strings {
			offset := int(offstrings + btou32(bin[hdrlen+i*4:])) // read index of string

			// if leading bit set for nchars, treat first byte as 7-bit high word and next as low word.
			nchars := int(btou16(bin[offset:]))
			offset += 2
			if nchars&(1<<15) != 0 { // TODO(d) this is untested
				n0 := nchars ^ (1 << 15)        // high word
				n1 := int(btou16(bin[offset:])) // low word
				nchars = n0*(1<<16) + n1
				offset += 2
			}

			data := make([]uint16, nchars)
			for i := range data {
				data[i] = btou16(bin[offset+2*i:])
			}

			if x := btou16(bin[offset+nchars*2:]); x != 0 {
				return fmt.Errorf("expected zero terminator, got 0x%04X for nchars=%v data=%q", x, nchars, data)
			}
			pl.strings[i] = string(utf16.Decode(data))
		}
	}

	// TODO decode styles
	_ = offstyle

	return nil
}

func (pl *Pool) MarshalBinary() ([]byte, error) {
	if pl.IsUTF8() {
		return nil, fmt.Errorf("encode utf8 not supported")
	}

	var (
		hdrlen = 28
		// indices of string indices
		iis    = make([]uint32, len(pl.strings))
		iislen = len(iis) * 4
		// utf16 encoded strings concatenated together
		strs []uint16
	)
	for i, x := range pl.strings {
		if len(x)>>16 > 0 {
			panic(fmt.Errorf("string lengths over 1<<15 not yet supported, got len %d", len(x)))
		}
		p := utf16.Encode([]rune(x))
		if len(p) == 0 {
			strs = append(strs, 0x0000, 0x0000)
		} else {
			strs = append(strs, uint16(len(p))) // string length (implicitly includes zero terminator to follow)
			strs = append(strs, p...)
			strs = append(strs, 0) // zero terminated
		}
		// indices start at zero
		if i+1 != len(iis) {
			iis[i+1] = uint32(len(strs) * 2) // utf16 byte index
		}
	}

	// check strings is 4-byte aligned, pad with zeros if not.
	for x := (len(strs) * 2) % 4; x != 0; x -= 2 {
		strs = append(strs, 0x0000)
	}

	strslen := len(strs) * 2

	hdr := chunkHeader{
		typ:            ResStringPool,
		headerByteSize: 28,
		byteSize:       uint32(28 + iislen + strslen),
	}

	bin := make([]byte, hdr.byteSize)

	hdrbin, err := hdr.MarshalBinary()
	if err != nil {
		return nil, err
	}
	copy(bin, hdrbin)

	putu32(bin[8:], uint32(len(pl.strings)))
	putu32(bin[12:], uint32(len(pl.styles)))
	putu32(bin[16:], pl.flags)
	putu32(bin[20:], uint32(hdrlen+iislen))
	putu32(bin[24:], 0) // index of styles start, is 0 when styles length is 0

	buf := bin[28:]
	for _, x := range iis {
		putu32(buf, x)
		buf = buf[4:]
	}

	for _, x := range strs {
		putu16(buf, x)
		buf = buf[2:]
	}

	if len(buf) != 0 {
		panic(fmt.Errorf("failed to fill allocated buffer, %v bytes left over", len(buf)))
	}

	return bin, nil
}

type Span struct {
	name PoolRef

	firstChar, lastChar uint32
}

func (spn *Span) UnmarshalBinary(bin []byte) error {
	const end = 0xFFFFFFFF
	spn.name = PoolRef(btou32(bin))
	if spn.name == end {
		return nil
	}
	spn.firstChar = btou32(bin[4:])
	spn.lastChar = btou32(bin[8:])
	return nil
}

// Map contains a uint32 slice mapping strings in the string
// pool back to resource identifiers. The i'th element of the slice
// is also the same i'th element of the string pool.
type Map struct {
	chunkHeader
	rs []TableRef
}

func (m *Map) UnmarshalBinary(bin []byte) error {
	(&m.chunkHeader).UnmarshalBinary(bin)
	buf := bin[m.headerByteSize:m.byteSize]
	m.rs = make([]TableRef, len(buf)/4)
	for i := range m.rs {
		m.rs[i] = TableRef(btou32(buf[i*4:]))
	}
	return nil
}

func (m *Map) MarshalBinary() ([]byte, error) {
	m.typ = ResXMLResourceMap
	m.headerByteSize = 8
	m.byteSize = uint32(m.headerByteSize) + uint32(len(m.rs)*4)
	bin := make([]byte, m.byteSize)
	b, err := m.chunkHeader.MarshalBinary()
	if err != nil {
		return nil, err
	}
	copy(bin, b)
	for i, r := range m.rs {
		putu32(bin[8+i*4:], uint32(r))
	}
	return bin, nil
}
