| package test |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "os" |
| "path/filepath" |
| "reflect" |
| "runtime" |
| "strconv" |
| "strings" |
| |
| "github.com/chzyer/logex" |
| ) |
| |
| var ( |
| mainRoot = "" |
| RootPath = os.TempDir() |
| ErrNotExcept = logex.Define("result not expect") |
| ErrNotEqual = logex.Define("result not equals") |
| ErrRequireNotEqual = logex.Define("result require not equals") |
| StrNotSuchFile = "no such file or directory" |
| ) |
| |
| func init() { |
| println("tmpdir:", RootPath) |
| } |
| |
| type testException struct { |
| depth int |
| info string |
| } |
| |
| func getMainRoot() string { |
| if mainRoot != "" { |
| return mainRoot |
| } |
| |
| cwd, err := os.Getwd() |
| if err != nil { |
| return "" |
| } |
| |
| for len(cwd) > 1 { |
| _, err := os.Stat(filepath.Join(cwd, ".git")) |
| if err == nil { |
| mainRoot = cwd + string([]rune{filepath.Separator}) |
| break |
| } |
| cwd = filepath.Dir(cwd) |
| } |
| return mainRoot |
| } |
| |
| func Skip() { |
| panic(nil) |
| } |
| |
| type Failer interface { |
| FailNow() |
| } |
| |
| func New(t Failer) { |
| err := recover() |
| if err == nil { |
| return |
| } |
| te, ok := err.(*testException) |
| if !ok { |
| panic(err) |
| } |
| |
| _, file, line, _ := runtime.Caller(5 + te.depth) |
| if strings.HasPrefix(file, getMainRoot()) { |
| file = file[len(getMainRoot()):] |
| } |
| println(fmt.Sprintf("%s:%d: %s", file, line, te.info)) |
| t.FailNow() |
| } |
| |
| func getErr(def error, e []error) error { |
| if len(e) == 0 { |
| return def |
| } |
| return e[0] |
| } |
| |
| func ReadAt(r io.ReaderAt, b []byte, at int64) { |
| n, err := r.ReadAt(b, at) |
| if err != nil { |
| Panic(0, fmt.Errorf("ReadAt error: %v", err)) |
| } |
| if n != len(b) { |
| Panic(0, fmt.Errorf("ReadAt short read: %v, want: %v", n, len(b))) |
| } |
| } |
| |
| func ReadAndCheck(r io.Reader, b []byte) { |
| buf := make([]byte, len(b)) |
| Read(r, buf) |
| equalBytes(1, buf, b) |
| } |
| |
| func Read(r io.Reader, b []byte) { |
| n, err := r.Read(b) |
| if err != nil && !logex.Equal(err, io.EOF) { |
| Panic(0, fmt.Errorf("Read error: %v", err)) |
| } |
| if n != len(b) { |
| Panic(0, fmt.Errorf("Read: %v, want: %v", n, len(b))) |
| } |
| } |
| |
| func ReadStringAt(r io.ReaderAt, off int64, s string) { |
| buf := make([]byte, len(s)) |
| n, err := r.ReadAt(buf, off) |
| buf = buf[:n] |
| if err != nil { |
| Panic(0, fmt.Errorf("ReadStringAt: %v", err)) |
| } |
| if string(buf) != s { |
| Panic(0, fmt.Errorf( |
| "ReadStringAt not match: %v, got: %v", |
| strconv.Quote(s), |
| strconv.Quote(string(buf)), |
| )) |
| } |
| } |
| |
| func ReadString(r io.Reader, s string) { |
| buf := make([]byte, len(s)) |
| n, err := r.Read(buf) |
| if err != nil && !logex.Equal(err, io.EOF) { |
| Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), err)) |
| } |
| if n != len(buf) { |
| Panic(0, fmt.Errorf("ReadString: %v, got: %v", strconv.Quote(s), n)) |
| } |
| if string(buf) != s { |
| Panic(0, fmt.Errorf( |
| "ReadString not match: %v, got: %v", |
| strconv.Quote(s), |
| strconv.Quote(string(buf)), |
| )) |
| } |
| } |
| |
| func WriteAt(w io.WriterAt, b []byte, at int64) { |
| n, err := w.WriteAt(b, at) |
| if err != nil { |
| Panic(0, err) |
| } |
| if n != len(b) { |
| Panic(0, "short write") |
| } |
| } |
| |
| func Write(w io.Writer, b []byte) { |
| n, err := w.Write(b) |
| if err != nil { |
| Panic(0, err) |
| } |
| if n != len(b) { |
| Panic(0, "short write") |
| } |
| } |
| |
| func WriteString(w io.Writer, s string) { |
| n, err := w.Write([]byte(s)) |
| if err != nil { |
| Panic(0, err) |
| } |
| if n != len(s) { |
| Panic(0, "short write") |
| } |
| } |
| |
| func Equals(o ...interface{}) { |
| if len(o)%2 != 0 { |
| Panic(0, "invalid Equals arguments") |
| } |
| for i := 0; i < len(o); i += 2 { |
| equal(1, o[i], o[i+1], nil) |
| } |
| } |
| |
| func NotEqual(a, b interface{}, e ...error) { |
| notEqual(1, a, b, e) |
| } |
| |
| func toInt(a interface{}) (int64, bool) { |
| switch n := a.(type) { |
| case int: |
| return int64(n), true |
| case int8: |
| return int64(n), true |
| case int16: |
| return int64(n), true |
| case int32: |
| return int64(n), true |
| case int64: |
| return int64(n), true |
| case uintptr: |
| return int64(n), true |
| default: |
| return -1, false |
| } |
| } |
| |
| func MarkLine() { |
| r := strings.Repeat("-", 20) |
| println(r) |
| } |
| |
| var globalMarkInfo string |
| |
| func Mark(obj ...interface{}) { |
| globalMarkInfo = fmt.Sprint(obj...) |
| } |
| |
| func EqualBytes(got, want []byte) { |
| equalBytes(0, got, want) |
| } |
| |
| func equalBytes(n int, got, want []byte) { |
| a := got |
| b := want |
| size := 16 |
| if len(a) != len(b) { |
| Panic(n, fmt.Sprintf("equal bytes, %v != %v", len(a), len(b))) |
| } |
| if bytes.Equal(a, b) { |
| return |
| } |
| |
| for off := 0; off < len(a); off += size { |
| end := off + size |
| if end > len(a) { |
| end = len(a) |
| } |
| if !bytes.Equal(a[off:end], b[off:end]) { |
| Panic(n, fmt.Sprintf( |
| "equal [%v]byte in [%v, %v]:\n\tgot: %v\n\twant: %v", |
| len(a), |
| off, off+size, |
| a[off:end], b[off:end], |
| )) |
| } |
| } |
| } |
| |
| func Equal(a, b interface{}, e ...error) { |
| if ai, ok := toInt(a); ok { |
| if bi, ok := toInt(b); ok { |
| equal(1, ai, bi, e) |
| return |
| } |
| } |
| equal(1, a, b, e) |
| } |
| |
| func CheckError(e error, s string) { |
| if e == nil { |
| Panic(0, ErrNotExcept) |
| } |
| if !strings.Contains(e.Error(), s) { |
| Panic(0, fmt.Errorf( |
| "want: %s, got %s", |
| strconv.Quote(s), |
| strconv.Quote(e.Error()), |
| )) |
| } |
| } |
| |
| func formatMax(o interface{}, max int) string { |
| aStr := fmt.Sprint(o) |
| if len(aStr) > max { |
| aStr = aStr[:max] + " ..." |
| } |
| return aStr |
| } |
| |
| func notEqual(d int, a, b interface{}, e []error) { |
| _, oka := a.(error) |
| _, okb := b.(error) |
| if oka && okb { |
| if logex.Equal(a.(error), b.(error)) { |
| Panic(d, fmt.Sprintf("%v: %v", |
| getErr(ErrRequireNotEqual, e), |
| a, |
| )) |
| } |
| return |
| } |
| if reflect.DeepEqual(a, b) { |
| Panic(d, fmt.Sprintf("%v: (%v, %v)", |
| getErr(ErrRequireNotEqual, e), |
| formatMax(a, 100), |
| formatMax(b, 100), |
| )) |
| } |
| } |
| |
| func equal(d int, a, b interface{}, e []error) { |
| _, oka := a.(error) |
| _, okb := b.(error) |
| if oka && okb { |
| if !logex.Equal(a.(error), b.(error)) { |
| Panic(d, fmt.Sprintf("%v: (%v, %v)", |
| getErr(ErrNotEqual, e), |
| formatMax(a, 100), formatMax(b, 100), |
| )) |
| } |
| return |
| } |
| if !reflect.DeepEqual(a, b) { |
| Panic(d, fmt.Sprintf("%v: (%+v, %+v)", getErr(ErrNotEqual, e), a, b)) |
| } |
| } |
| |
| func Should(b bool, e ...error) { |
| if !b { |
| Panic(0, getErr(ErrNotExcept, e)) |
| } |
| } |
| |
| func NotNil(obj interface{}) { |
| if obj == nil { |
| Panic(0, "should not nil") |
| } |
| } |
| |
| func False(obj bool) { |
| if obj { |
| Panic(0, "should false") |
| } |
| } |
| |
| func True(obj bool) { |
| if !obj { |
| Panic(0, "should true") |
| } |
| } |
| |
| func Nil(obj interface{}) { |
| if obj != nil { |
| // double check, incase different type with nil value |
| if !reflect.ValueOf(obj).IsNil() { |
| str := fmt.Sprint(obj) |
| if err, ok := obj.(error); ok { |
| str = logex.DecodeError(err) |
| } |
| Panic(0, fmt.Sprintf("should nil: %v", str)) |
| } |
| } |
| } |
| |
| func Panic(depth int, obj interface{}) { |
| t := &testException{ |
| depth: depth, |
| } |
| if err, ok := obj.(error); ok { |
| t.info = logex.DecodeError(err) |
| } else { |
| t.info = fmt.Sprint(obj) |
| } |
| if globalMarkInfo != "" { |
| t.info = "[info:" + globalMarkInfo + "] " + t.info |
| } |
| panic(t) |
| } |
| |
| func CleanTmp() { |
| os.RemoveAll(root(2)) |
| } |
| |
| func TmpFile() (*os.File, error) { |
| dir := root(2) |
| if err := os.MkdirAll(dir, 0744); err != nil { |
| return nil, err |
| } |
| return ioutil.TempFile(dir, "") |
| } |
| |
| func Root() string { |
| p := root(2) |
| os.RemoveAll(root(2)) |
| return p |
| } |
| |
| func root(n int) string { |
| pc, _, _, _ := runtime.Caller(n) |
| name := runtime.FuncForPC(pc).Name() |
| if idx := strings.LastIndex(name, "."); idx > 0 { |
| name = name[:idx] + "/" + name[idx+1:] |
| } |
| |
| root := os.Getenv("TEST_ROOT") |
| if root == "" { |
| root = RootPath |
| } |
| return filepath.Join(root, name) |
| } |
| |
| func RandBytes(n int) []byte { |
| buf := make([]byte, n) |
| rand.Read(buf) |
| return buf |
| } |
| |
| func SeqBytes(n int) []byte { |
| buf := make([]byte, n) |
| for idx := range buf { |
| buf[idx] = byte(idx) |
| } |
| return buf |
| } |