blob: aef573aaf2a7746ca0e89f0adf9785b31793dd4a [file] [log] [blame]
package test
import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"strconv"
	"strings"
	"unicode/utf8"

	"github.com/chzyer/logex"
)
// Mutable package-wide state shared by the helpers.
var (
	// mainRoot caches the repository root (with a trailing path separator)
	// once getMainRoot has located it; it is empty until then.
	mainRoot string

	// RootPath is the default base directory for per-test scratch space.
	RootPath = os.TempDir()
)

// Failure sentinels and message fragments used by the assertion helpers.
var (
	// ErrNotExcept is the default failure for Should and the failure CheckError
	// reports when it receives a nil error.
	ErrNotExcept = logex.Define("result not expect")
	// ErrNotEqual is the default message prefix for equality failures.
	ErrNotEqual = logex.Define("result not equals")
	// ErrRequireNotEqual is the default message prefix for inequality failures.
	ErrRequireNotEqual = logex.Define("result require not equals")
	// StrNotSuchFile is not referenced in this file; presumably meant for
	// matching missing-file errors with CheckError — confirm against callers.
	StrNotSuchFile = "no such file or directory"
)
// init announces the scratch base directory once, when the package loads.
func init() {
	print("tmpdir: ", RootPath, "\n")
}
// testException is the panic payload built by Panic and recovered by New.
type testException struct {
	depth int    // extra stack frames New skips when locating the failing call
	info  string // rendered failure message
}
// getMainRoot returns the working directory, or its nearest ancestor, that
// contains a ".git" entry, with a trailing path separator appended. New uses
// it to trim absolute source paths so they are relative to the repository.
// A found root is cached in mainRoot. The function returns "" if the working
// directory cannot be determined or no such directory exists. Single-character
// paths such as "/" are not probed.
func getMainRoot() string {
	if mainRoot != "" {
		return mainRoot
	}
	cwd, err := os.Getwd()
	if err != nil {
		return ""
	}
	for len(cwd) > 1 {
		if _, err := os.Stat(filepath.Join(cwd, ".git")); err == nil {
			mainRoot = cwd + string(filepath.Separator)
			break
		}
		parent := filepath.Dir(cwd)
		if parent == cwd {
			// A volume root such as `C:\` is its own parent. Without this
			// check the loop never terminates when no .git is found.
			break
		}
		cwd = parent
	}
	return mainRoot
}
// skipSignal is the panic value used by Skip. A dedicated type is required
// because, since Go 1.21, panic(nil) reaches recover as a
// *runtime.PanicNilError instead of nil, which New would re-panic.
type skipSignal struct{}

// Skip ends the current test early without reporting a failure. It works by
// panicking, so the test must have deferred New to recover it.
func Skip() {
	panic(skipSignal{})
}

// Failer is the subset of *testing.T (also satisfied by *testing.B) that New
// needs to abort a test.
type Failer interface {
	FailNow()
}

// New recovers assertion failures raised by Panic. Use it as
// `defer test.New(t)` at the top of a test. recover only takes effect when
// New itself is the deferred call.
//
// On a *testException, New prints "file:line: info" for the failing call
// site and calls t.FailNow. The file is made relative to the repository root
// when that root is known. A Skip is swallowed, so the test ends without
// failing. Any other panic value is re-raised unchanged.
func New(t Failer) {
	err := recover()
	if err == nil {
		return
	}
	if _, ok := err.(skipSignal); ok {
		return
	}
	te, ok := err.(*testException)
	if !ok {
		panic(err)
	}
	// 5 skips this function, the panic machinery, Panic and the public helper
	// that called it; te.depth adds one per extra internal helper frame.
	// NOTE(review): the constant depends on the runtime's frame layout
	// between a deferred call and the panic site — verify on new Go releases.
	_, file, line, _ := runtime.Caller(5 + te.depth)
	if repo := getMainRoot(); strings.HasPrefix(file, repo) {
		file = file[len(repo):]
	}
	println(fmt.Sprintf("%s:%d: %s", file, line, te.info))
	t.FailNow()
}
// getErr returns the first caller-supplied error in e, or def when e is empty.
// It lets variadic `e ...error` parameters override a default failure message.
func getErr(def error, e []error) error {
	if len(e) > 0 {
		return e[0]
	}
	return def
}
// ReadAt reads exactly len(b) bytes from r at offset at. It fails the test on
// an error or a short read. The io.ReaderAt contract lets a read that fills b
// and ends at the end of the input report io.EOF, so that case is accepted
// rather than treated as an error.
func ReadAt(r io.ReaderAt, b []byte, at int64) {
	n, err := r.ReadAt(b, at)
	if err != nil && !(n == len(b) && logex.Equal(err, io.EOF)) {
		Panic(0, fmt.Errorf("ReadAt error: %v", err))
	}
	if n != len(b) {
		Panic(0, fmt.Errorf("ReadAt short read: %v, want: %v", n, len(b)))
	}
}
// ReadAndCheck reads len(b) bytes from r with a single Read call and fails the
// test unless they equal b. Failures are attributed to ReadAndCheck's caller.
func ReadAndCheck(r io.Reader, b []byte) {
	buf := make([]byte, len(b))
	readChecked(1, r, buf)
	equalBytes(1, buf, b)
}

// Read fills b with a single r.Read call. It fails the test if Read returns an
// error other than io.EOF or fewer than len(b) bytes. A short but otherwise
// legal Read counts as a failure.
func Read(r io.Reader, b []byte) {
	readChecked(1, r, b)
}

// readChecked is the shared body of Read and ReadAndCheck. depth is forwarded
// to Panic: it counts the helper frames between readChecked and the test's
// call site (1 for either public wrapper).
func readChecked(depth int, r io.Reader, b []byte) {
	n, err := r.Read(b)
	if err != nil && !logex.Equal(err, io.EOF) {
		Panic(depth, fmt.Errorf("Read error: %v", err))
	}
	if n != len(b) {
		Panic(depth, fmt.Errorf("Read: %v, want: %v", n, len(b)))
	}
}
// ReadStringAt reads len(s) bytes from r at offset off. It fails the test on a
// read error or if the bytes read differ from s. A full read that also reports
// io.EOF is accepted, because io.ReaderAt permits this when the read ends at
// the end of the input.
func ReadStringAt(r io.ReaderAt, off int64, s string) {
	buf := make([]byte, len(s))
	n, err := r.ReadAt(buf, off)
	if err != nil && !(n == len(buf) && logex.Equal(err, io.EOF)) {
		Panic(0, fmt.Errorf("ReadStringAt: %v", err))
	}
	buf = buf[:n]
	if string(buf) != s {
		Panic(0, fmt.Errorf(
			"ReadStringAt not match: %v, got: %v",
			strconv.Quote(s),
			strconv.Quote(string(buf)),
		))
	}
}
// ReadString performs a single r.Read of len(s) bytes. It fails the test on an
// error other than io.EOF, on a short read, or when the bytes differ from s.
func ReadString(r io.Reader, s string) {
	got := make([]byte, len(s))
	n, err := r.Read(got)
	switch {
	case err != nil && !logex.Equal(err, io.EOF):
		Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), err))
	case n != len(got):
		Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), n))
	case string(got) != s:
		Panic(0, fmt.Errorf(
			"ReadString not match: %v, got: %v",
			strconv.Quote(s),
			strconv.Quote(string(got)),
		))
	}
}
// WriteAt writes b to w at offset at and fails the test on an error or a
// short write.
func WriteAt(w io.WriterAt, b []byte, at int64) {
	written, err := w.WriteAt(b, at)
	switch {
	case err != nil:
		Panic(0, err)
	case written != len(b):
		Panic(0, "short write")
	}
}
// Write writes b to w and fails the test on an error or a short write.
func Write(w io.Writer, b []byte) {
	written, err := w.Write(b)
	if err != nil {
		Panic(0, err)
	} else if written != len(b) {
		Panic(0, "short write")
	}
}
// WriteString writes s to w through w.Write and fails the test on an error or
// a short write.
func WriteString(w io.Writer, s string) {
	switch n, err := w.Write([]byte(s)); {
	case err != nil:
		Panic(0, err)
	case n != len(s):
		Panic(0, "short write")
	}
}
// Equals asserts equality pairwise: Equals(a1, b1, a2, b2, ...). An odd number
// of arguments is itself a test failure.
func Equals(o ...interface{}) {
	if len(o)%2 == 1 {
		Panic(0, "invalid Equals arguments")
	}
	for len(o) > 0 {
		equal(1, o[0], o[1], nil)
		o = o[2:]
	}
}
// NotEqual fails the test when a and b are equal. Two errors are compared
// with logex.Equal; anything else with reflect.DeepEqual. An optional error
// in e replaces the default failure message prefix.
func NotEqual(a, b interface{}, e ...error) {
	const wrapperFrames = 1 // this function sits between notEqual and the test
	notEqual(wrapperFrames, a, b, e)
}
// toInt widens any signed integer type (and uintptr) to int64 so values of
// different integer widths can be compared. For any other type it reports
// ok == false and returns -1.
func toInt(a interface{}) (int64, bool) {
	var v int64
	switch x := a.(type) {
	case int:
		v = int64(x)
	case int8:
		v = int64(x)
	case int16:
		v = int64(x)
	case int32:
		v = int64(x)
	case int64:
		v = x
	case uintptr:
		v = int64(x)
	default:
		return -1, false
	}
	return v, true
}
func MarkLine() {
r := strings.Repeat("-", 20)
println(r)
}
// globalMarkInfo, when non-empty, is prepended by Panic to every failure
// message as "[info:<globalMarkInfo>] ".
var globalMarkInfo string

// Mark records context, formatted with fmt.Sprint, to attach to later failure
// messages. Calling Mark() with no arguments clears it.
func Mark(obj ...interface{}) {
	info := fmt.Sprint(obj...)
	globalMarkInfo = info
}
// EqualBytes fails the test unless got and want contain identical bytes. On a
// mismatch it reports the first differing chunk. It passes depth 1 because
// this wrapper adds one frame between equalBytes and the test. That keeps the
// failure pointing at EqualBytes' caller, as Equal and NotEqual do.
func EqualBytes(got, want []byte) {
	equalBytes(1, got, want)
}
func equalBytes(n int, got, want []byte) {
a := got
b := want
size := 16
if len(a) != len(b) {
Panic(n, fmt.Sprintf("equal bytes, %v != %v", len(a), len(b)))
}
if bytes.Equal(a, b) {
return
}
for off := 0; off < len(a); off += size {
end := off + size
if end > len(a) {
end = len(a)
}
if !bytes.Equal(a[off:end], b[off:end]) {
Panic(n, fmt.Sprintf(
"equal [%v]byte in [%v, %v]:\n\tgot: %v\n\twant: %v",
len(a),
off, off+size,
a[off:end], b[off:end],
))
}
}
}
// Equal fails the test unless a equals b. When both are signed integers (or
// uintptr) they are compared by value after widening to int64, so int(1)
// equals int64(1). Two errors are compared with logex.Equal; anything else
// with reflect.DeepEqual. An optional error in e replaces the default prefix.
func Equal(a, b interface{}, e ...error) {
	ai, aIsInt := toInt(a)
	bi, bIsInt := toInt(b)
	if aIsInt && bIsInt {
		equal(1, ai, bi, e)
		return
	}
	equal(1, a, b, e)
}
// CheckError fails the test unless e is non-nil and its message contains s.
func CheckError(e error, s string) {
	if e == nil {
		Panic(0, ErrNotExcept)
	}
	if msg := e.Error(); !strings.Contains(msg, s) {
		Panic(0, fmt.Errorf(
			"want: %s, got %s",
			strconv.Quote(s),
			strconv.Quote(msg),
		))
	}
}
// formatMax renders o with fmt.Sprint. If the result is longer than max bytes,
// it is truncated and " ..." is appended. The cut point moves back to the
// start of a rune, so a multi-byte UTF-8 character is never split. A negative
// max is treated as 0.
func formatMax(o interface{}, max int) string {
	s := fmt.Sprint(o)
	if len(s) <= max {
		return s
	}
	cut := max
	if cut < 0 {
		cut = 0
	}
	// cut < len(s) here, so s[cut] is in range.
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + " ..."
}
// notEqual is the shared body of NotEqual. It fails (with Panic depth d) when
// a and b are equal. Two errors are compared with logex.Equal; any other pair
// with reflect.DeepEqual. Non-error values are truncated to 100 bytes in the
// message.
func notEqual(d int, a, b interface{}, e []error) {
	errA, aIsErr := a.(error)
	errB, bIsErr := b.(error)
	if !(aIsErr && bIsErr) {
		if reflect.DeepEqual(a, b) {
			Panic(d, fmt.Sprintf("%v: (%v, %v)",
				getErr(ErrRequireNotEqual, e),
				formatMax(a, 100),
				formatMax(b, 100),
			))
		}
		return
	}
	if logex.Equal(errA, errB) {
		Panic(d, fmt.Sprintf("%v: %v", getErr(ErrRequireNotEqual, e), a))
	}
}
// equal is the shared body of Equal and Equals. It fails (with Panic depth d)
// when a and b differ. Two errors are compared with logex.Equal and shown
// truncated to 100 bytes; any other pair is compared with reflect.DeepEqual
// and shown in full with %+v.
func equal(d int, a, b interface{}, e []error) {
	errA, aIsErr := a.(error)
	errB, bIsErr := b.(error)
	switch {
	case aIsErr && bIsErr:
		if !logex.Equal(errA, errB) {
			Panic(d, fmt.Sprintf("%v: (%v, %v)",
				getErr(ErrNotEqual, e),
				formatMax(a, 100), formatMax(b, 100),
			))
		}
	case !reflect.DeepEqual(a, b):
		Panic(d, fmt.Sprintf("%v: (%+v, %+v)", getErr(ErrNotEqual, e), a, b))
	}
}
// Should fails the test when b is false. The failure uses the first error in
// e, or ErrNotExcept when none is given.
func Should(b bool, e ...error) {
	if b {
		return
	}
	Panic(0, getErr(ErrNotExcept, e))
}
// NotNil fails the test when obj is nil. This includes an interface holding a
// typed nil, such as a nil *T, nil slice, nil map, nil chan or nil func. Such
// a value compares != nil as an interface but is still nil underneath.
func NotNil(obj interface{}) {
	if obj == nil {
		Panic(0, "should not nil")
	}
	v := reflect.ValueOf(obj)
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		// IsNil is only defined for these kinds; it panics for the rest.
		if v.IsNil() {
			Panic(0, "should not nil")
		}
	}
}
// False fails the test when obj is true.
func False(obj bool) {
	if !obj {
		return
	}
	Panic(0, "should false")
}
// True fails the test when obj is false.
func True(obj bool) {
	if obj {
		return
	}
	Panic(0, "should true")
}
// Nil fails the test unless obj is nil. An interface holding a typed nil (for
// example a nil *T or nil slice) counts as nil. Any other value, including
// non-nillable ones such as ints or structs, is reported as a failure; errors
// are rendered with logex.DecodeError.
func Nil(obj interface{}) {
	if obj == nil {
		return
	}
	// Double-check for a typed nil hidden behind a non-nil interface. IsNil
	// panics for kinds that cannot be nil, so only call it where it is defined.
	v := reflect.ValueOf(obj)
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		if v.IsNil() {
			return
		}
	}
	str := fmt.Sprint(obj)
	if err, ok := obj.(error); ok {
		str = logex.DecodeError(err)
	}
	Panic(0, fmt.Sprintf("should nil: %v", str))
}
func Panic(depth int, obj interface{}) {
t := &testException{
depth: depth,
}
if err, ok := obj.(error); ok {
t.info = logex.DecodeError(err)
} else {
t.info = fmt.Sprint(obj)
}
if globalMarkInfo != "" {
t.info = "[info:" + globalMarkInfo + "] " + t.info
}
panic(t)
}
// CleanTmp deletes the calling test's scratch directory, as computed by root.
// Any removal error is ignored: cleanup is best-effort.
func CleanTmp() {
	dir := root(2)
	os.RemoveAll(dir)
}
// TmpFile creates the calling test's scratch directory if needed (mode 0744)
// and opens a new uniquely named temporary file inside it.
func TmpFile() (*os.File, error) {
	dir := root(2)
	if mkErr := os.MkdirAll(dir, 0744); mkErr != nil {
		return nil, mkErr
	}
	return ioutil.TempFile(dir, "")
}
// Root returns the calling test's scratch directory path after removing
// anything already there. The directory itself is not recreated, and removal
// errors are ignored.
func Root() string {
	dir := root(2)
	os.RemoveAll(dir)
	return dir
}
// root maps the function n frames up the stack (0 is root itself) to a
// directory under $TEST_ROOT, or under RootPath when TEST_ROOT is unset. The
// final "." in the function's qualified name becomes "/", so
// "example.com/pkg.TestFoo" maps to "<base>/example.com/pkg/TestFoo".
func root(n int) string {
	base := os.Getenv("TEST_ROOT")
	if base == "" {
		base = RootPath
	}
	pc, _, _, _ := runtime.Caller(n)
	fn := runtime.FuncForPC(pc).Name()
	if dot := strings.LastIndex(fn, "."); dot > 0 {
		fn = fn[:dot] + "/" + fn[dot+1:]
	}
	return filepath.Join(base, fn)
}
// RandBytes returns n bytes read from crypto/rand. It panics if the
// randomness source fails, because otherwise the caller would silently get a
// zero-filled buffer.
func RandBytes(n int) []byte {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Errorf("test: RandBytes: %v", err))
	}
	return buf
}
// SeqBytes returns an n-byte slice in which each byte equals its index modulo
// 256, i.e. 0, 1, ..., 255, 0, 1, ... — a recognizable test pattern.
func SeqBytes(n int) []byte {
	seq := make([]byte, n)
	for i := len(seq) - 1; i >= 0; i-- {
		seq[i] = byte(i)
	}
	return seq
}