webdav: implement COPY and MOVE.

Also add a -port flag to litmus_test_server.

13 of 13 copymove tests from the litmus suite pass, as do 16 of 16
basic tests.

Change-Id: Idf92cad281e15db7d4d62e28e366ea7bfa89e564
Reviewed-on: https://go-review.googlesource.com/3470
Reviewed-by: Nick Cooper <nmvc@google.com>
Reviewed-by: Robert Stepanek <robert.stepanek@gmail.com>
Reviewed-by: Nigel Tao <nigeltao@golang.org>
diff --git a/webdav/file.go b/webdav/file.go
index eb6abd5..4069f24 100644
--- a/webdav/file.go
+++ b/webdav/file.go
@@ -547,3 +547,88 @@
 	f.n.modTime = time.Now()
 	return lenp, nil
 }
+
+// copyFiles copies files and/or directories from src to dst.
+//
+// See section 9.8.5 for when various HTTP status codes apply.
+func copyFiles(fs FileSystem, src, dst string, overwrite bool, depth int, recursion int) (status int, err error) {
+	if recursion == 1000 {
+		return http.StatusInternalServerError, errRecursionTooDeep
+	}
+	recursion++
+
+	// TODO: section 9.8.3 says that "Note that an infinite-depth COPY of /A/
+	// into /A/B/ could lead to infinite recursion if not handled correctly."
+
+	srcFile, err := fs.OpenFile(src, os.O_RDONLY, 0)
+	if err != nil {
+		return http.StatusNotFound, err
+	}
+	defer srcFile.Close()
+	srcStat, err := srcFile.Stat()
+	if err != nil {
+		return http.StatusNotFound, err
+	}
+	srcPerm := srcStat.Mode() & os.ModePerm
+
+	created := false
+	if _, err := fs.Stat(dst); err != nil {
+		if os.IsNotExist(err) {
+			created = true
+		} else {
+			return http.StatusForbidden, err
+		}
+	} else {
+		if !overwrite {
+			return http.StatusPreconditionFailed, os.ErrExist
+		}
+		if err := fs.RemoveAll(dst); err != nil && !os.IsNotExist(err) {
+			return http.StatusForbidden, err
+		}
+	}
+
+	if srcStat.IsDir() {
+		if err := fs.Mkdir(dst, srcPerm); err != nil {
+			return http.StatusForbidden, err
+		}
+		if depth == infiniteDepth {
+			children, err := srcFile.Readdir(-1)
+			if err != nil {
+				return http.StatusForbidden, err
+			}
+			for _, c := range children {
+				name := c.Name()
+				s := path.Join(src, name)
+				d := path.Join(dst, name)
+				cStatus, cErr := copyFiles(fs, s, d, overwrite, depth, recursion)
+				if cErr != nil {
+					// TODO: MultiStatus.
+					return cStatus, cErr
+				}
+			}
+		}
+
+	} else {
+		dstFile, err := fs.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, srcPerm)
+		if err != nil {
+			if os.IsNotExist(err) {
+				return http.StatusConflict, err
+			}
+			return http.StatusForbidden, err
+
+		}
+		_, copyErr := io.Copy(dstFile, srcFile)
+		closeErr := dstFile.Close()
+		if copyErr != nil {
+			return http.StatusForbidden, copyErr
+		}
+		if closeErr != nil {
+			return http.StatusForbidden, closeErr
+		}
+	}
+
+	if created {
+		return http.StatusCreated, nil
+	}
+	return http.StatusNoContent, nil
+}
diff --git a/webdav/file_test.go b/webdav/file_test.go
index 6601fce..03bcb8d 100644
--- a/webdav/file_test.go
+++ b/webdav/file_test.go
@@ -335,7 +335,7 @@
 		"  stat /d/m want errNotExist",
 		"  stat /d/n want dir",
 		"  stat /d/n/q want 4",
-		"rename /d /d/n/x want err",
+		"rename /d /d/n/z want err",
 		"rename /c /d/n/q want ok",
 		"  stat /c want errNotExist",
 		"  stat /d/n/q want 2",
@@ -358,8 +358,50 @@
 		"rename /t / want err",
 		"rename /t /u/v want ok",
 		"  stat /u/v/r want 5",
-		"rename / /x want err",
+		"rename / /z want err",
 		"  find / /a /d /u /u/v /u/v/q /u/v/r",
+		"  stat /a want 1",
+		"  stat /b want errNotExist",
+		"  stat /c want errNotExist",
+		"  stat /u/v/r want 5",
+		"copy__ o=F d=0 /a /b want ok",
+		"copy__ o=T d=0 /a /c want ok",
+		"  stat /a want 1",
+		"  stat /b want 1",
+		"  stat /c want 1",
+		"  stat /u/v/r want 5",
+		"copy__ o=F d=0 /u/v/r /b want errExist",
+		"  stat /b want 1",
+		"copy__ o=T d=0 /u/v/r /b want ok",
+		"  stat /a want 1",
+		"  stat /b want 5",
+		"  stat /u/v/r want 5",
+		"rm-all /a want ok",
+		"rm-all /b want ok",
+		"mk-dir /u/v/w want ok",
+		"create /u/v/w/s SSSSSSSS want ok",
+		"  stat /d want dir",
+		"  stat /d/x want errNotExist",
+		"  stat /d/y want errNotExist",
+		"  stat /u/v/r want 5",
+		"  stat /u/v/w/s want 8",
+		"  find / /c /d /u /u/v /u/v/q /u/v/r /u/v/w /u/v/w/s",
+		"copy__ o=T d=0 /u/v /d/x want ok",
+		"copy__ o=T d=∞ /u/v /d/y want ok",
+		"rm-all /u want ok",
+		"  stat /d/x want dir",
+		"  stat /d/x/q want errNotExist",
+		"  stat /d/x/r want errNotExist",
+		"  stat /d/x/w want errNotExist",
+		"  stat /d/x/w/s want errNotExist",
+		"  stat /d/y want dir",
+		"  stat /d/y/q want 2",
+		"  stat /d/y/r want 5",
+		"  stat /d/y/w want dir",
+		"  stat /d/y/w/s want 8",
+		"  stat /u want errNotExist",
+		"  find / /c /d /d/x /d/y /d/y/q /d/y/r /d/y/w /d/y/w/s",
+		"copy__ o=F d=∞ /d/y /d/x want errExist",
 	}
 
 	for i, tc := range testCases {
@@ -403,9 +445,12 @@
 				t.Fatalf("test case #%d %q:\ngot  %s\nwant %s", i, tc, got, want)
 			}
 
-		case "mk-dir", "rename", "rm-all", "stat":
+		case "copy__", "mk-dir", "rename", "rm-all", "stat":
 			nParts := 3
-			if op == "rename" {
+			switch op {
+			case "copy__":
+				nParts = 6
+			case "rename":
 				nParts = 4
 			}
 			parts := strings.Split(arg, " ")
@@ -415,6 +460,15 @@
 
 			got, opErr := "", error(nil)
 			switch op {
+			case "copy__":
+				overwrite, depth := false, 0
+				if parts[0] == "o=T" {
+					overwrite = true
+				}
+				if parts[1] == "d=∞" {
+					depth = infiniteDepth
+				}
+				_, opErr = copyFiles(fs, parts[2], parts[3], overwrite, depth, 0)
 			case "mk-dir":
 				opErr = fs.Mkdir(parts[0], 0777)
 			case "rename":
diff --git a/webdav/litmus_test_server.go b/webdav/litmus_test_server.go
index 95df5e6..48ca718 100644
--- a/webdav/litmus_test_server.go
+++ b/webdav/litmus_test_server.go
@@ -18,6 +18,8 @@
 package main
 
 import (
+	"flag"
+	"fmt"
 	"log"
 	"net/http"
 	"net/url"
@@ -25,7 +27,10 @@
 	"golang.org/x/net/webdav"
 )
 
+var port = flag.Int("port", 9999, "server port") // listen port; the default matches the previously hard-coded ":9999"
+
 func main() {
+	flag.Parse()
 	http.Handle("/", &webdav.Handler{
 		FileSystem: webdav.NewMemFS(),
 		LockSystem: webdav.NewMemLS(),
@@ -36,15 +41,15 @@
 				if u, err := url.Parse(r.Header.Get("Destination")); err == nil {
 					dst = u.Path
 				}
-				ow := r.Header.Get("Overwrite")
-				log.Printf("%-8s%-25s%-25sow=%-2s%v", r.Method, r.URL.Path, dst, ow, err)
+				o := r.Header.Get("Overwrite")
+				log.Printf("%-10s%-25s%-25so=%-2s%v", r.Method, r.URL.Path, dst, o, err)
 			default:
-				log.Printf("%-8s%-30s%v", r.Method, r.URL.Path, err)
+				log.Printf("%-10s%-30s%v", r.Method, r.URL.Path, err)
 			}
 		},
 	})
 
-	const addr = ":9999"
+	addr := fmt.Sprintf(":%d", *port)
 	log.Printf("Serving %v", addr)
 	log.Fatal(http.ListenAndServe(addr, nil))
 }
diff --git a/webdav/webdav.go b/webdav/webdav.go
index 93d971f..501b6aa 100644
--- a/webdav/webdav.go
+++ b/webdav/webdav.go
@@ -11,6 +11,7 @@
 	"errors"
 	"io"
 	"net/http"
+	"net/url"
 	"os"
 	"time"
 )
@@ -37,8 +38,7 @@
 	} else if h.LockSystem == nil {
 		status, err = http.StatusInternalServerError, errNoLockSystem
 	} else {
-		// TODO: COPY, MOVE, PROPFIND, PROPPATCH methods.
-		// MOVE needs to enforce its Depth constraint. See the parseDepth comment.
+		// TODO: PROPFIND, PROPPATCH methods.
 		switch r.Method {
 		case "OPTIONS":
 			status, err = h.handleOptions(w, r)
@@ -50,6 +50,8 @@
 			status, err = h.handlePut(w, r)
 		case "MKCOL":
 			status, err = h.handleMkcol(w, r)
+		case "COPY", "MOVE":
+			status, err = h.handleCopyMove(w, r)
 		case "LOCK":
 			status, err = h.handleLock(w, r)
 		case "UNLOCK":
@@ -193,6 +195,91 @@
 	return http.StatusCreated, nil
 }
 
+func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status int, err error) {
+	// TODO: COPY/MOVE for Properties, as per sections 9.8.2 and 9.9.1.
+
+	hdr := r.Header.Get("Destination")
+	if hdr == "" {
+		return http.StatusBadRequest, errInvalidDestination
+	}
+	u, err := url.Parse(hdr)
+	if err != nil {
+		return http.StatusBadRequest, errInvalidDestination
+	}
+	if u.Host != r.Host {
+		return http.StatusBadGateway, errInvalidDestination
+	}
+	// TODO: do we need a webdav.StripPrefix HTTP handler that's like the
+	// standard library's http.StripPrefix handler, but also strips the
+	// prefix in the Destination header?
+
+	dst, src := u.Path, r.URL.Path
+	if dst == src {
+		return http.StatusForbidden, errDestinationEqualsSource
+	}
+
+	// TODO: confirmLocks should also check dst.
+	releaser, status, err := h.confirmLocks(r)
+	if err != nil {
+		return status, err
+	}
+	defer releaser.Release()
+
+	if r.Method == "COPY" {
+		// Section 9.8.3 says that "The COPY method on a collection without a Depth
+		// header must act as if a Depth header with value "infinity" was included".
+		depth := infiniteDepth
+		if hdr := r.Header.Get("Depth"); hdr != "" {
+			depth = parseDepth(hdr)
+			if depth != 0 && depth != infiniteDepth {
+				// Section 9.8.3 says that "A client may submit a Depth header on a
+				// COPY on a collection with a value of "0" or "infinity"."
+				return http.StatusBadRequest, errInvalidDepth
+			}
+		}
+		return copyFiles(h.FileSystem, src, dst, r.Header.Get("Overwrite") != "F", depth, 0)
+	}
+
+	// Section 9.9.2 says that "The MOVE method on a collection must act as if
+	// a "Depth: infinity" header was used on it. A client must not submit a
+	// Depth header on a MOVE on a collection with any value but "infinity"."
+	if hdr := r.Header.Get("Depth"); hdr != "" {
+		if parseDepth(hdr) != infiniteDepth {
+			return http.StatusBadRequest, errInvalidDepth
+		}
+	}
+
+	created := false
+	if _, err := h.FileSystem.Stat(dst); err != nil {
+		if !os.IsNotExist(err) {
+			return http.StatusForbidden, err
+		}
+		created = true
+	} else {
+		switch r.Header.Get("Overwrite") {
+		case "T":
+			// Section 9.9.3 says that "If a resource exists at the destination
+			// and the Overwrite header is "T", then prior to performing the move,
+			// the server must perform a DELETE with "Depth: infinity" on the
+			// destination resource.
+			if err := h.FileSystem.RemoveAll(dst); err != nil {
+				return http.StatusForbidden, err
+			}
+		case "F":
+			return http.StatusPreconditionFailed, os.ErrExist
+		default:
+			return http.StatusBadRequest, errInvalidOverwrite
+		}
+	}
+	if err := h.FileSystem.Rename(src, dst); err != nil {
+		return http.StatusForbidden, err
+	}
+	if created {
+		return http.StatusCreated, nil
+	}
+	return http.StatusNoContent, nil
+}
+
 func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus int, retErr error) {
 	duration, err := parseTimeout(r.Header.Get("Timeout"))
 	if err != nil {
@@ -308,7 +395,8 @@
 //
 // Different WebDAV methods have further constraints on valid depths:
 //	- PROPFIND has no further restrictions, as per section 9.1.
-//	- MOVE accepts only "infinity", as per section 9.2.2.
+//	- COPY accepts only "0" or "infinity", as per section 9.8.3.
+//	- MOVE accepts only "infinity", as per section 9.9.2.
 //	- LOCK accepts only "0" or "infinity", as per section 9.10.3.
 // These constraints are enforced by the handleXxx methods.
 func parseDepth(s string) int {
@@ -349,16 +437,20 @@
 }
 
 var (
-	errDirectoryNotEmpty   = errors.New("webdav: directory not empty")
-	errInvalidDepth        = errors.New("webdav: invalid depth")
-	errInvalidIfHeader     = errors.New("webdav: invalid If header")
-	errInvalidLockInfo     = errors.New("webdav: invalid lock info")
-	errInvalidLockToken    = errors.New("webdav: invalid lock token")
-	errInvalidPropfind     = errors.New("webdav: invalid propfind")
-	errInvalidResponse     = errors.New("webdav: invalid response")
-	errInvalidTimeout      = errors.New("webdav: invalid timeout")
-	errNoFileSystem        = errors.New("webdav: no file system")
-	errNoLockSystem        = errors.New("webdav: no lock system")
-	errNotADirectory       = errors.New("webdav: not a directory")
-	errUnsupportedLockInfo = errors.New("webdav: unsupported lock info")
+	errDestinationEqualsSource = errors.New("webdav: destination equals source") // COPY/MOVE Destination path equals the request path
+	errDirectoryNotEmpty       = errors.New("webdav: directory not empty")
+	errInvalidDepth            = errors.New("webdav: invalid depth")
+	errInvalidDestination      = errors.New("webdav: invalid destination") // missing, unparseable or other-host Destination header
+	errInvalidIfHeader         = errors.New("webdav: invalid If header")
+	errInvalidLockInfo         = errors.New("webdav: invalid lock info")
+	errInvalidLockToken        = errors.New("webdav: invalid lock token")
+	errInvalidOverwrite        = errors.New("webdav: invalid overwrite") // Overwrite header value is not "T" or "F"
+	errInvalidPropfind         = errors.New("webdav: invalid propfind")
+	errInvalidResponse         = errors.New("webdav: invalid response")
+	errInvalidTimeout          = errors.New("webdav: invalid timeout")
+	errNoFileSystem            = errors.New("webdav: no file system")
+	errNoLockSystem            = errors.New("webdav: no lock system")
+	errNotADirectory           = errors.New("webdav: not a directory")
+	errRecursionTooDeep        = errors.New("webdav: recursion too deep") // copyFiles hit its 1000-level nesting limit
+	errUnsupportedLockInfo     = errors.New("webdav: unsupported lock info")
 )