webdav: implement memLS.Confirm.

LockSystem.Confirm now takes two names (a src and a dst), and returns a
func() instead of a 1-method Releaser.

We now pass 25 of 30 "locks" litmus tests:
- 4 failures are due to PROPFIND / PROPPATCH not being implemented.
- 1 failure is due to shared locks being unsupported.
- 3 warnings are also due to PROPFIND / PROPPATCH.
- 1 failure (cond_put_corrupt_token) is due to returning 412
  (Precondition Failed) instead of 423 (Locked), but IIUC
  the spec says to return 412.
- 11 tests were skipped, presumably due to earlier failures.

Change-Id: I3f4c178cdc4b99c6acb7f59783b4fd9b94f606ec
Reviewed-on: https://go-review.googlesource.com/3860
Reviewed-by: Dave Cheney <dave@cheney.net>
diff --git a/webdav/lock.go b/webdav/lock.go
index 99eb77d..344ac5c 100644
--- a/webdav/lock.go
+++ b/webdav/lock.go
@@ -32,31 +32,26 @@
 	ETag  string
 }
 
-// Releaser releases previously confirmed lock claims.
-//
-// Calling Release does not unlock the lock, in the WebDAV UNLOCK sense, but
-// once LockSystem.Confirm has confirmed that a lock claim is valid, that lock
-// cannot be Confirmed again until it has been Released.
-type Releaser interface {
-	Release()
-}
-
 // LockSystem manages access to a collection of named resources. The elements
 // in a lock name are separated by slash ('/', U+002F) characters, regardless
 // of host operating system convention.
 type LockSystem interface {
 	// Confirm confirms that the caller can claim all of the locks specified by
 	// the given conditions, and that holding the union of all of those locks
-	// gives exclusive access to the named resource.
+	// gives exclusive access to all of the named resources. Up to two resources
+	// can be named. Empty names are ignored.
 	//
-	// Exactly one of r and err will be non-nil. If r is non-nil, all of the
-	// requested locks are held until r.Release is called.
+	// Exactly one of release and err will be non-nil. If release is non-nil,
+	// all of the requested locks are held until release is called. Calling
+	// release does not unlock the lock, in the WebDAV UNLOCK sense, but once
+	// Confirm has confirmed that a lock claim is valid, that lock cannot be
+	// Confirmed again until it has been released.
 	//
 	// If Confirm returns ErrConfirmationFailed then the Handler will continue
 	// to try any other set of locks presented (a WebDAV HTTP request can
 	// present more than one set of locks). If it returns any other non-nil
 	// error, the Handler will write a "500 Internal Server Error" HTTP status.
-	Confirm(now time.Time, name string, conditions ...Condition) (r Releaser, err error)
+	Confirm(now time.Time, name0, name1 string, conditions ...Condition) (release func(), err error)
 
 	// Create creates a lock with the given depth, duration, owner and root
 	// (name). The depth will either be negative (meaning infinite) or zero.
@@ -149,14 +144,89 @@
 	}
 }
 
-func (m *memLS) Confirm(now time.Time, name string, conditions ...Condition) (Releaser, error) {
+func (m *memLS) Confirm(now time.Time, name0, name1 string, conditions ...Condition) (func(), error) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	m.collectExpiredNodes(now)
-	name = slashClean(name)
 
-	// TODO: touch n.held.
-	panic("TODO")
+	var n0, n1 *memLSNode
+	if name0 != "" {
+		if n0 = m.lookup(slashClean(name0), conditions...); n0 == nil {
+			return nil, ErrConfirmationFailed
+		}
+	}
+	if name1 != "" {
+		if n1 = m.lookup(slashClean(name1), conditions...); n1 == nil {
+			return nil, ErrConfirmationFailed
+		}
+	}
+
+	// Don't hold the same node twice.
+	if n1 == n0 {
+		n1 = nil
+	}
+
+	if n0 != nil {
+		m.hold(n0)
+	}
+	if n1 != nil {
+		m.hold(n1)
+	}
+	return func() {
+		m.mu.Lock()
+		defer m.mu.Unlock()
+		if n1 != nil {
+			m.unhold(n1)
+		}
+		if n0 != nil {
+			m.unhold(n0)
+		}
+	}, nil
+}
+
+// lookup returns the node n that locks the named resource, provided that n
+// matches at least one of the given conditions and that lock isn't held by
+// another party. Otherwise, it returns nil.
+//
+// n may be a parent of the named resource, if n is an infinite depth lock.
+func (m *memLS) lookup(name string, conditions ...Condition) (n *memLSNode) {
+	// TODO: support Condition.Not and Condition.ETag.
+	for _, c := range conditions {
+		n = m.byToken[c.Token]
+		if n == nil || n.held {
+			continue
+		}
+		if name == n.details.Root {
+			return n
+		}
+		if n.details.ZeroDepth {
+			continue
+		}
+		if n.details.Root == "/" || strings.HasPrefix(name, n.details.Root+"/") {
+			return n
+		}
+	}
+	return nil
+}
+
+func (m *memLS) hold(n *memLSNode) {
+	if n.held {
+		panic("webdav: memLS inconsistent held state")
+	}
+	n.held = true
+	if n.details.Duration >= 0 && n.byExpiryIndex >= 0 {
+		heap.Remove(&m.byExpiry, n.byExpiryIndex)
+	}
+}
+
+func (m *memLS) unhold(n *memLSNode) {
+	if !n.held {
+		panic("webdav: memLS inconsistent held state")
+	}
+	n.held = false
+	if n.details.Duration >= 0 {
+		heap.Push(&m.byExpiry, n)
+	}
 }
 
 func (m *memLS) Create(now time.Time, details LockDetails) (string, error) {
diff --git a/webdav/lock_test.go b/webdav/lock_test.go
index 7e22672..116d6c0 100644
--- a/webdav/lock_test.go
+++ b/webdav/lock_test.go
@@ -61,7 +61,7 @@
 }
 
 var lockTestDurations = []time.Duration{
-	-1,              // A negative duration means to never expire.
+	infiniteTimeout, // infiniteTimeout means to never expire.
 	0,               // A zero duration means to expire immediately.
 	100 * time.Hour, // A very large duration will not expire in these tests.
 }
@@ -102,7 +102,7 @@
 	for _, name := range lockTestNames {
 		_, err := m.Create(now, LockDetails{
 			Root:      name,
-			Duration:  -1,
+			Duration:  infiniteTimeout,
 			ZeroDepth: lockTestZeroDepth(name),
 		})
 		if err != nil {
@@ -155,6 +155,147 @@
 	check(0, "/")
 }
 
+func TestMemLSLookup(t *testing.T) {
+	now := time.Unix(0, 0)
+	m := NewMemLS().(*memLS)
+
+	badToken := m.nextToken()
+	t.Logf("badToken=%q", badToken)
+
+	for _, name := range lockTestNames {
+		token, err := m.Create(now, LockDetails{
+			Root:      name,
+			Duration:  infiniteTimeout,
+			ZeroDepth: lockTestZeroDepth(name),
+		})
+		if err != nil {
+			t.Fatalf("creating lock for %q: %v", name, err)
+		}
+		t.Logf("%-15q -> node=%p token=%q", name, m.byName[name], token)
+	}
+
+	baseNames := append([]string{"/a", "/b/c"}, lockTestNames...)
+	for _, baseName := range baseNames {
+		for _, suffix := range []string{"", "/0", "/1/2/3"} {
+			name := baseName + suffix
+
+			goodToken := ""
+			base := m.byName[baseName]
+			if base != nil && (suffix == "" || !lockTestZeroDepth(baseName)) {
+				goodToken = base.token
+			}
+
+			for _, token := range []string{badToken, goodToken} {
+				if token == "" {
+					continue
+				}
+
+				got := m.lookup(name, Condition{Token: token})
+				want := base
+				if token == badToken {
+					want = nil
+				}
+				if got != want {
+					t.Errorf("name=%-20qtoken=%q (bad=%t): got %p, want %p",
+						name, token, token == badToken, got, want)
+				}
+			}
+		}
+	}
+}
+
+func TestMemLSConfirm(t *testing.T) {
+	now := time.Unix(0, 0)
+	m := NewMemLS().(*memLS)
+	alice, err := m.Create(now, LockDetails{
+		Root:      "/alice",
+		Duration:  infiniteTimeout,
+		ZeroDepth: false,
+	})
+	tweedle, err := m.Create(now, LockDetails{
+		Root:      "/tweedle",
+		Duration:  infiniteTimeout,
+		ZeroDepth: false,
+	})
+	if err != nil {
+		t.Fatalf("Create: %v", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Create: inconsistent state: %v", err)
+	}
+
+	// Test a mismatch between name and condition.
+	_, err = m.Confirm(now, "/tweedle/dee", "", Condition{Token: alice})
+	if err != ErrConfirmationFailed {
+		t.Fatalf("Confirm (mismatch): got %v, want ErrConfirmationFailed", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Confirm (mismatch): inconsistent state: %v", err)
+	}
+
+	// Test two names (that fall under the same lock) in the one Confirm call.
+	release, err := m.Confirm(now, "/tweedle/dee", "/tweedle/dum", Condition{Token: tweedle})
+	if err != nil {
+		t.Fatalf("Confirm (twins): %v", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Confirm (twins): inconsistent state: %v", err)
+	}
+	release()
+	if err := m.consistent(); err != nil {
+		t.Fatalf("release (twins): inconsistent state: %v", err)
+	}
+
+	// Test the same two names in overlapping Confirm / release calls.
+	releaseDee, err := m.Confirm(now, "/tweedle/dee", "", Condition{Token: tweedle})
+	if err != nil {
+		t.Fatalf("Confirm (sequence #0): %v", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Confirm (sequence #0): inconsistent state: %v", err)
+	}
+
+	_, err = m.Confirm(now, "/tweedle/dum", "", Condition{Token: tweedle})
+	if err != ErrConfirmationFailed {
+		t.Fatalf("Confirm (sequence #1): got %v, want ErrConfirmationFailed", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Confirm (sequence #1): inconsistent state: %v", err)
+	}
+
+	releaseDee()
+	if err := m.consistent(); err != nil {
+		t.Fatalf("release (sequence #2): inconsistent state: %v", err)
+	}
+
+	releaseDum, err := m.Confirm(now, "/tweedle/dum", "", Condition{Token: tweedle})
+	if err != nil {
+		t.Fatalf("Confirm (sequence #3): %v", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Confirm (sequence #3): inconsistent state: %v", err)
+	}
+
+	// Test that you can't unlock a held lock.
+	err = m.Unlock(now, tweedle)
+	if err != ErrLocked {
+		t.Fatalf("Unlock (sequence #4): got %v, want ErrLocked", err)
+	}
+
+	releaseDum()
+	if err := m.consistent(); err != nil {
+		t.Fatalf("release (sequence #5): inconsistent state: %v", err)
+	}
+
+	err = m.Unlock(now, tweedle)
+	if err != nil {
+		t.Fatalf("Unlock (sequence #6): %v", err)
+	}
+	if err := m.consistent(); err != nil {
+		t.Fatalf("Unlock (sequence #6): inconsistent state: %v", err)
+	}
+}
+
 func TestMemLSNonCanonicalRoot(t *testing.T) {
 	now := time.Unix(0, 0)
 	m := NewMemLS().(*memLS)
@@ -304,29 +445,43 @@
 	}
 }
 
-func TestMemLSCreateRefreshUnlock(t *testing.T) {
+func TestMemLS(t *testing.T) {
 	now := time.Unix(0, 0)
 	m := NewMemLS().(*memLS)
 	rng := rand.New(rand.NewSource(0))
 	tokens := map[string]string{}
-	nCreate, nRefresh, nUnlock := 0, 0, 0
+	nConfirm, nCreate, nRefresh, nUnlock := 0, 0, 0, 0
 	const N = 2000
 
 	for i := 0; i < N; i++ {
 		name := lockTestNames[rng.Intn(len(lockTestNames))]
 		duration := lockTestDurations[rng.Intn(len(lockTestDurations))]
-		unlocked := false
+		confirmed, unlocked := false, false
 
-		// If the name was already locked, there's a 50-50 chance that
-		// we refresh or unlock it. Otherwise, we create a lock.
+		// If the name was already locked, we randomly confirm/release, refresh
+		// or unlock it. Otherwise, we create a lock.
 		token := tokens[name]
 		if token != "" {
-			if rng.Intn(2) == 0 {
+			switch rng.Intn(3) {
+			case 0:
+				confirmed = true
+				nConfirm++
+				release, err := m.Confirm(now, name, "", Condition{Token: token})
+				if err != nil {
+					t.Fatalf("iteration #%d: Confirm %q: %v", i, name, err)
+				}
+				if err := m.consistent(); err != nil {
+					t.Fatalf("iteration #%d: inconsistent state: %v", i, err)
+				}
+				release()
+
+			case 1:
 				nRefresh++
 				if _, err := m.Refresh(now, token, duration); err != nil {
 					t.Fatalf("iteration #%d: Refresh %q: %v", i, name, err)
 				}
-			} else {
+
+			case 2:
 				unlocked = true
 				nUnlock++
 				if err := m.Unlock(now, token); err != nil {
@@ -347,12 +502,14 @@
 			}
 		}
 
-		if duration == 0 || unlocked {
-			// A zero-duration lock should expire immediately and is
-			// effectively equivalent to being unlocked.
-			tokens[name] = ""
-		} else {
-			tokens[name] = token
+		if !confirmed {
+			if duration == 0 || unlocked {
+				// A zero-duration lock should expire immediately and is
+				// effectively equivalent to being unlocked.
+				tokens[name] = ""
+			} else {
+				tokens[name] = token
+			}
 		}
 
 		if err := m.consistent(); err != nil {
@@ -360,6 +517,9 @@
 		}
 	}
 
+	if nConfirm < N/10 {
+		t.Fatalf("too few Confirm calls: got %d, want >= %d", nConfirm, N/10)
+	}
 	if nCreate < N/10 {
 		t.Fatalf("too few Create calls: got %d, want >= %d", nCreate, N/10)
 	}
@@ -464,6 +624,11 @@
 		if _, ok := m.byName[n.details.Root]; !ok {
 			return fmt.Errorf("node at name %q in m.byExpiry but not in m.byName", n.details.Root)
 		}
+
+		// No node in m.byExpiry should be held.
+		if n.held {
+			return fmt.Errorf("node at name %q in m.byExpiry is held", n.details.Root)
+		}
 	}
 	return nil
 }
diff --git a/webdav/webdav.go b/webdav/webdav.go
index 34a872c..45484b6 100644
--- a/webdav/webdav.go
+++ b/webdav/webdav.go
@@ -70,34 +70,88 @@
 	}
 }
 
-type nopReleaser struct{}
+func (h *Handler) lock(now time.Time, root string) (token string, status int, err error) {
+	token, err = h.LockSystem.Create(now, LockDetails{
+		Root:      root,
+		Duration:  infiniteTimeout,
+		ZeroDepth: true,
+	})
+	if err != nil {
+		if err == ErrLocked {
+			return "", StatusLocked, err
+		}
+		return "", http.StatusInternalServerError, err
+	}
+	return token, 0, nil
+}
 
-func (nopReleaser) Release() {}
-
-func (h *Handler) confirmLocks(r *http.Request) (releaser Releaser, status int, err error) {
+func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func(), status int, err error) {
 	hdr := r.Header.Get("If")
 	if hdr == "" {
-		return nopReleaser{}, 0, nil
+		// An empty If header means that the client hasn't previously created locks.
+		// Even if this client doesn't care about locks, we still need to check that
+		// the resources aren't locked by another client, so we create temporary
+		// locks that would conflict with another client's locks. These temporary
+		// locks are unlocked at the end of the HTTP request.
+		now, srcToken, dstToken := time.Now(), "", ""
+		if src != "" {
+			srcToken, status, err = h.lock(now, src)
+			if err != nil {
+				return nil, status, err
+			}
+		}
+		if dst != "" {
+			dstToken, status, err = h.lock(now, dst)
+			if err != nil {
+				if srcToken != "" {
+					h.LockSystem.Unlock(now, srcToken)
+				}
+				return nil, status, err
+			}
+		}
+
+		return func() {
+			if dstToken != "" {
+				h.LockSystem.Unlock(now, dstToken)
+			}
+			if srcToken != "" {
+				h.LockSystem.Unlock(now, srcToken)
+			}
+		}, 0, nil
 	}
+
 	ih, ok := parseIfHeader(hdr)
 	if !ok {
 		return nil, http.StatusBadRequest, errInvalidIfHeader
 	}
 	// ih is a disjunction (OR) of ifLists, so any ifList will do.
 	for _, l := range ih.lists {
-		path := l.resourceTag
-		if path == "" {
-			path = r.URL.Path
+		lsrc := l.resourceTag
+		if lsrc == "" {
+			lsrc = src
+		} else {
+			u, err := url.Parse(lsrc)
+			if err != nil {
+				continue
+			}
+			if u.Host != r.Host {
+				continue
+			}
+			lsrc = u.Path
 		}
-		releaser, err = h.LockSystem.Confirm(time.Now(), path, l.conditions...)
+		release, err = h.LockSystem.Confirm(time.Now(), lsrc, dst, l.conditions...)
 		if err == ErrConfirmationFailed {
 			continue
 		}
 		if err != nil {
 			return nil, http.StatusInternalServerError, err
 		}
-		return releaser, 0, nil
+		return release, 0, nil
 	}
+	// Section 10.4.1 says that "If this header is evaluated and all state lists
+	// fail, then the request must fail with a 412 (Precondition Failed) status."
+	// We follow the spec even though the cond_put_corrupt_token test case from
+	// the litmus test warns on seeing a 412 instead of a 423 (Locked).
 	return nil, http.StatusPreconditionFailed, ErrLocked
 }
 
@@ -134,11 +188,11 @@
 }
 
 func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status int, err error) {
-	releaser, status, err := h.confirmLocks(r)
+	release, status, err := h.confirmLocks(r, r.URL.Path, "")
 	if err != nil {
 		return status, err
 	}
-	defer releaser.Release()
+	defer release()
 
 	// TODO: return MultiStatus where appropriate.
 
@@ -158,11 +212,11 @@
 }
 
 func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, err error) {
-	releaser, status, err := h.confirmLocks(r)
+	release, status, err := h.confirmLocks(r, r.URL.Path, "")
 	if err != nil {
 		return status, err
 	}
-	defer releaser.Release()
+	defer release()
 
 	f, err := h.FileSystem.OpenFile(r.URL.Path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
 	if err != nil {
@@ -176,11 +230,11 @@
 }
 
 func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request) (status int, err error) {
-	releaser, status, err := h.confirmLocks(r)
+	release, status, err := h.confirmLocks(r, r.URL.Path, "")
 	if err != nil {
 		return status, err
 	}
-	defer releaser.Release()
+	defer release()
 
 	if r.ContentLength > 0 {
 		return http.StatusUnsupportedMediaType, nil
@@ -213,18 +267,25 @@
 	// prefix in the Destination header?
 
 	dst, src := u.Path, r.URL.Path
+	if dst == "" {
+		return http.StatusBadGateway, errInvalidDestination
+	}
 	if dst == src {
 		return http.StatusForbidden, errDestinationEqualsSource
 	}
 
-	// TODO: confirmLocks should also check dst.
-	releaser, status, err := h.confirmLocks(r)
-	if err != nil {
-		return status, err
-	}
-	defer releaser.Release()
-
 	if r.Method == "COPY" {
+		// Section 7.5.1 says that a COPY only needs to lock the destination,
+		// not both destination and source. Strictly speaking, this is racy,
+		// even though a COPY doesn't modify the source, if a concurrent
+		// operation modifies the source. However, the litmus test explicitly
+		// checks that COPYing a locked-by-another source is OK.
+		release, status, err := h.confirmLocks(r, "", dst)
+		if err != nil {
+			return status, err
+		}
+		defer release()
+
 		// Section 9.8.3 says that "The COPY method on a collection without a Depth
 		// header must act as if a Depth header with value "infinity" was included".
 		depth := infiniteDepth
@@ -239,6 +300,12 @@
 		return copyFiles(h.FileSystem, src, dst, r.Header.Get("Overwrite") != "F", depth, 0)
 	}
 
+	release, status, err := h.confirmLocks(r, src, dst)
+	if err != nil {
+		return status, err
+	}
+	defer release()
+
 	// Section 9.9.2 says that "The MOVE method on a collection must act as if
 	// a "Depth: infinity" header was used on it. A client must not submit a
 	// Depth header on a MOVE on a collection with any value but "infinity"."