// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"

	. "github.com/mmcloughlin/avo/build"
	. "github.com/mmcloughlin/avo/gotypes"
	. "github.com/mmcloughlin/avo/operand"
	. "github.com/mmcloughlin/avo/reg"

	// Ensure "go mod tidy" doesn't remove the golang.org/x/crypto module
	// dependency, which is necessary to access the field.Element type.
	_ "golang.org/x/crypto/curve25519"
)

//go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field

// main declares the target package and build constraint, runs the feMul and
// feSquare generators, and then calls Generate to produce the output.
func main() {
	Package("golang.org/x/crypto/curve25519/internal/field")
	ConstraintExpr("amd64,gc,!purego")

	// The generators run in slice order: feMul first, then feSquare.
	for _, emit := range []func(){feMul, feSquare} {
		emit()
	}

	Generate()
}

// namedComponent pairs a Component with a short label. The label is what
// String returns, so a namedComponent can be formatted with verbs such as %s.
type namedComponent struct {
	Component
	name string
}

// String returns the label attached to the component.
func (nc namedComponent) String() string {
	return nc.name
}

// uint128 is a named two-register value: hi and lo are virtual
// general-purpose registers holding the upper and lower halves respectively.
type uint128 struct {
	name string
	hi   GPVirtual
	lo   GPVirtual
}

// String returns the name attached to the value.
func (u uint128) String() string {
	return u.name
}

// feSquare emits the assembly routine feSquare(out, a *Element), which sets
// out = a * a.
//
// The five limbs of a are combined into five two-register accumulators
// r0..r4, which are then reduced by two carry chains and stored into out.
// The order of the calls below is the order of the emitted instructions.
func feSquare() {
	TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
	Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
	Pragma("noescape")

	// The five limbs of the input, labelled for the generated comments.
	a := Dereference(Param("a"))
	l0 := namedComponent{a.Field("l0"), "l0"}
	l1 := namedComponent{a.Field("l1"), "l1"}
	l2 := namedComponent{a.Field("l2"), "l2"}
	l3 := namedComponent{a.Field("l3"), "l3"}
	l4 := namedComponent{a.Field("l4"), "l4"}

	// Symmetric cross terms are computed once and doubled, so the
	// coefficients below are 2 (plain doubling), 19, and 38 = 19×2.

	// r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, l0, l0)
	addMul64(r0, 38, l1, l4)
	addMul64(r0, 38, l2, l3)

	// r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 2, l0, l1)
	addMul64(r1, 38, l2, l4)
	addMul64(r1, 19, l3, l3)

	// r2 = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 2, l0, l2)
	addMul64(r2, 1, l1, l1)
	addMul64(r2, 38, l3, l4)

	// r3 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 2, l0, l3)
	addMul64(r3, 2, l1, l2)
	addMul64(r3, 19, l4, l4)

	// r4 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 2, l0, l4)
	addMul64(r4, 2, l1, l3)
	addMul64(r4, 1, l2, l2)

	// Each accumulator is split into a carry cN = rN >> 51 and its low
	// register rNlo. The low register is then masked to 51 bits and the
	// carry of the limb below is added; the carry out of the top limb (c4)
	// wraps around into limb 0 multiplied by 19.
	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	// The limbs now fit in single registers, so the carries are recomputed
	// with plain 64-bit shifts, reusing the c0..c4 registers. All five
	// carries are extracted before any limb is masked.
	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}

// feMul emits the assembly routine feMul(out, a, b *Element), which sets
// out = a * b.
//
// The limbs of a and b are combined into five two-register accumulators
// r0..r4, which are then reduced by two carry chains and stored into out.
// The order of the calls below is the order of the emitted instructions.
func feMul() {
	TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
	Doc("feMul sets out = a * b. It works like feMulGeneric.")
	Pragma("noescape")

	// The five limbs of each operand, labelled for the generated comments.
	a := Dereference(Param("a"))
	a0 := namedComponent{a.Field("l0"), "a0"}
	a1 := namedComponent{a.Field("l1"), "a1"}
	a2 := namedComponent{a.Field("l2"), "a2"}
	a3 := namedComponent{a.Field("l3"), "a3"}
	a4 := namedComponent{a.Field("l4"), "a4"}

	b := Dereference(Param("b"))
	b0 := namedComponent{b.Field("l0"), "b0"}
	b1 := namedComponent{b.Field("l1"), "b1"}
	b2 := namedComponent{b.Field("l2"), "b2"}
	b3 := namedComponent{b.Field("l3"), "b3"}
	b4 := namedComponent{b.Field("l4"), "b4"}

	// In each accumulator rN, products aI×bJ with I+J = N are added as-is,
	// and products with I+J = N+5 are added multiplied by 19.

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, a0, b0)
	addMul64(r0, 19, a1, b4)
	addMul64(r0, 19, a2, b3)
	addMul64(r0, 19, a3, b2)
	addMul64(r0, 19, a4, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 1, a0, b1)
	addMul64(r1, 1, a1, b0)
	addMul64(r1, 19, a2, b4)
	addMul64(r1, 19, a3, b3)
	addMul64(r1, 19, a4, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 1, a0, b2)
	addMul64(r2, 1, a1, b1)
	addMul64(r2, 1, a2, b0)
	addMul64(r2, 19, a3, b4)
	addMul64(r2, 19, a4, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 1, a0, b3)
	addMul64(r3, 1, a1, b2)
	addMul64(r3, 1, a2, b1)
	addMul64(r3, 1, a3, b0)
	addMul64(r3, 19, a4, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 1, a0, b4)
	addMul64(r4, 1, a1, b3)
	addMul64(r4, 1, a2, b2)
	addMul64(r4, 1, a3, b1)
	addMul64(r4, 1, a4, b0)

	// Each accumulator is split into a carry cN = rN >> 51 and its low
	// register rNlo. The low register is then masked to 51 bits and the
	// carry of the limb below is added; the carry out of the top limb (c4)
	// wraps around into limb 0 multiplied by 19.
	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	// The limbs now fit in single registers, so the carries are recomputed
	// with plain 64-bit shifts, reusing the c0..c4 registers. All five
	// carries are extracted before any limb is masked.
	Comment("Second reduction chain (carryPropagate)")
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}

// mul64 emits code that sets r to i * aX * bX. Only i = 1 and i = 2 are
// supported; any other value panics before anything is emitted.
//
// aX is loaded into RAX (and shifted left by one bit when i is 2), MULQ
// multiplies it by bX in memory, and the RDX:RAX result is copied into
// r.hi and r.lo.
func mul64(r uint128, i int, aX, bX namedComponent) {
	if i != 1 && i != 2 {
		panic("unsupported i value")
	}

	double := i == 2
	if double {
		Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
	} else {
		Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
	}

	Load(aX, RAX)
	if double {
		SHLQ(Imm(1), RAX)
	}

	// RDX, RAX = RAX * bX
	MULQ(mustAddr(bX))
	MOVQ(RAX, r.lo)
	MOVQ(RDX, r.hi)
}

// addMul64 emits code that sets r to r + i * aX * bX.
//
// For i = 1, aX is loaded straight into RAX. For any other i, aX is loaded
// into a fresh register and IMUL3Q writes i * aX into RAX. MULQ then
// multiplies RAX by bX in memory, and the RDX:RAX result is added to r with
// ADDQ on the low half and ADCQ on the high half.
func addMul64(r uint128, i uint64, aX, bX namedComponent) {
	if i == 1 {
		Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
		Load(aX, RAX)
	} else {
		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
		limb := Load(aX, GP64())
		IMUL3Q(Imm(i), limb, RAX)
	}

	// RDX, RAX = RAX * bX
	MULQ(mustAddr(bX))
	ADDQ(RAX, r.lo)
	ADCQ(RDX, r.hi)
}

// shiftRightBy51 emits code computing r >> 51 and returns the register that
// holds the result together with the register that holds r.lo.
//
// The shift is a three-operand SHLQ by 64-51 bits that moves the top bits of
// r.lo into r.hi, so the result lives in the register that was r.hi, while
// r.lo is left untouched.
//
// After this function is called, the uint128 may not be used anymore: both
// of its registers are set to nil.
func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
	high, low := r.hi, r.lo
	SHLQ(Imm(64-51), low, high)

	// Make sure the uint128 is unusable from here on.
	r.hi, r.lo = nil, nil

	return high, low
}

// maskAndAdd emits code that sets r = r&mask + c*i.
//
// When i is not 1, c itself is overwritten with c*i before being added.
func maskAndAdd(r, mask, c GPVirtual, i uint64) {
	ANDQ(mask, r)
	if i == 1 {
		ADDQ(c, r)
		return
	}
	IMUL3Q(Imm(i), c, c)
	ADDQ(c, r)
}

// mustAddr resolves c and returns the Addr of the result as an operand.
// It panics with the resolution error if c cannot be resolved.
func mustAddr(c Component) Op {
	resolved, err := c.Resolve()
	if err != nil {
		panic(err)
	}
	return resolved.Addr
}
