blob: f5d2fadfbb3ba6c547499de3390c8b2b00cd0ccf [file] [log] [blame]
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package io_test
6
7import (
8 . "io"
Rob Pike3f5966d2010-08-03 08:04:33 +10009 "bytes"
10 "crypto/sha1"
11 "fmt"
Brad Fitzpatrick719cde22010-07-28 11:30:00 -070012 "os"
13 "strings"
14 "testing"
15)
16
17func TestMultiReader(t *testing.T) {
18 var mr Reader
19 var buf []byte
20 nread := 0
21 withFooBar := func(tests func()) {
22 r1 := strings.NewReader("foo ")
23 r2 := strings.NewReader("bar")
24 mr = MultiReader(r1, r2)
25 buf = make([]byte, 20)
26 tests()
27 }
28 expectRead := func(size int, expected string, eerr os.Error) {
29 nread++
30 n, gerr := mr.Read(buf[0:size])
31 if n != len(expected) {
32 t.Errorf("#%d, expected %d bytes; got %d",
33 nread, len(expected), n)
34 }
35 got := string(buf[0:n])
36 if got != expected {
37 t.Errorf("#%d, expected %q; got %q",
38 nread, expected, got)
39 }
40 if gerr != eerr {
41 t.Errorf("#%d, expected error %v; got %v",
42 nread, eerr, gerr)
43 }
44 buf = buf[n:]
45 }
46 withFooBar(func() {
47 expectRead(2, "fo", nil)
48 expectRead(5, "o ", nil)
49 expectRead(5, "bar", nil)
50 expectRead(5, "", os.EOF)
51 })
52 withFooBar(func() {
53 expectRead(4, "foo ", nil)
54 expectRead(1, "b", nil)
55 expectRead(3, "ar", nil)
56 expectRead(1, "", os.EOF)
57 })
58 withFooBar(func() {
59 expectRead(5, "foo ", nil)
60 })
61}
Rob Pike3f5966d2010-08-03 08:04:33 +100062
63func TestMultiWriter(t *testing.T) {
64 sha1 := sha1.New()
65 sink := new(bytes.Buffer)
66 mw := MultiWriter(sha1, sink)
67
68 sourceString := "My input text."
69 source := strings.NewReader(sourceString)
70 written, err := Copy(mw, source)
71
72 if written != int64(len(sourceString)) {
73 t.Errorf("short write of %d, not %d", written, len(sourceString))
74 }
75
76 if err != nil {
77 t.Errorf("unexpected error: %v", err)
78 }
79
80 sha1hex := fmt.Sprintf("%x", sha1.Sum())
81 if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
82 t.Error("incorrect sha1 value")
83 }
84
85 if sink.String() != sourceString {
86 t.Error("expected %q; got %q", sourceString, sink.String())
87 }
88}