frontend: add tests and redirect to https on frontend

Change-Id: I1c116b0899a1ea37870a147277c4460366c2ac5d
Reviewed-on: https://go-review.googlesource.com/84976
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/frontend/edit.go b/frontend/edit.go
index 5edd633..2e8809f 100644
--- a/frontend/edit.go
+++ b/frontend/edit.go
@@ -36,7 +36,6 @@
 			w.Write([]byte(`<h1>Unavailable For Legal Reasons</h1><p>Viewing and/or sharing code snippets is not available in your country for legal reasons. This message might also appear if your country is misdetected. If you believe this is an error, please <a href="https://golang.org/issue">file an issue</a>.</p>`))
 			return
 		}
-		ctx := r.Context()
 		id := r.URL.Path[3:]
 		serveText := false
 		if strings.HasSuffix(id, ".go") {
@@ -44,7 +43,7 @@
 			serveText = true
 		}
 
-		if err := s.db.GetSnippet(ctx, id, snip); err != nil {
+		if err := s.db.GetSnippet(r.Context(), id, snip); err != nil {
 			if err != datastore.ErrNoSuchEntity {
 				s.log.Errorf("loading Snippet: %v", err)
 			}
@@ -62,7 +61,15 @@
 			return
 		}
 	}
-	editTemplate.Execute(w, &editData{snip, allowShare(r)})
+	w.Header().Set("Content-Type", "text/html; charset=utf-8")
+	data := &editData{
+		Snippet: snip,
+		Share:   allowShare(r),
+	}
+	if err := editTemplate.Execute(w, data); err != nil {
+		s.log.Errorf("editTemplate.Execute(w, %+v): %v", data, err)
+		return
+	}
 }
 
 const hello = `package main
diff --git a/frontend/server.go b/frontend/server.go
index 881c6d7..136124c 100644
--- a/frontend/server.go
+++ b/frontend/server.go
@@ -74,7 +74,13 @@
 }
 
 func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if os.Getenv("GAE_INSTANCE") != "" {
+	if r.Header.Get("X-Forwarded-Proto") == "http" {
+		r.URL.Scheme = "https"
+		r.URL.Host = r.Host
+		http.Redirect(w, r, r.URL.String(), http.StatusFound)
+		return
+	}
+	if r.Header.Get("X-Forwarded-Proto") == "https" {
 		w.Header().Set("Strict-Transport-Security", "max-age=31536000; preload")
 	}
 	s.mux.ServeHTTP(w, r)
diff --git a/frontend/server_test.go b/frontend/server_test.go
index 19bc261..4eec012 100644
--- a/frontend/server_test.go
+++ b/frontend/server_test.go
@@ -3,7 +3,14 @@
 // license that can be found in the LICENSE file.
 package main
 
-import "testing"
+import (
+	"bytes"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
 
 type testLogger struct {
 	t *testing.T
@@ -28,4 +35,94 @@
 }
 
 func TestEdit(t *testing.T) {
+	s, err := newServer(testingOptions(t))
+	if err != nil {
+		t.Fatalf("newServer(testingOptions(t)): %v", err)
+	}
+	id := "bar"
+	barBody := []byte("Snippy McSnipface")
+	snip := &snippet{Body: barBody}
+	if err := s.db.PutSnippet(nil, id, snip); err != nil {
+		t.Fatalf("s.dbPutSnippet(nil, %+v, %+v): %v", id, snip, err)
+	}
+
+	testCases := []struct {
+		desc       string
+		url        string
+		statusCode int
+		headers    map[string]string
+		respBody   []byte
+	}{
+		{"foo.play.golang.org to play.golang.org", "https://foo.play.golang.org", http.StatusFound, map[string]string{"Location": "https://play.golang.org"}, nil},
+		{"Unknown snippet", "https://play.golang.org/p/foo", http.StatusNotFound, nil, nil},
+		{"Existing snippet", "https://play.golang.org/p/" + id, http.StatusOK, nil, nil},
+		{"Plaintext snippet", "https://play.golang.org/p/" + id + ".go", http.StatusOK, nil, barBody},
+		{"Download snippet", "https://play.golang.org/p/" + id + ".go?download=true", http.StatusOK, map[string]string{"Content-Disposition": fmt.Sprintf(`attachment; filename="%s.go"`, id)}, barBody},
+	}
+
+	for _, tc := range testCases {
+		req := httptest.NewRequest(http.MethodGet, tc.url, nil)
+		w := httptest.NewRecorder()
+		s.handleEdit(w, req)
+		resp := w.Result()
+		if got, want := resp.StatusCode, tc.statusCode; got != want {
+			t.Errorf("%s: got unexpected status code %d; want %d", tc.desc, got, want)
+		}
+		for k, v := range tc.headers {
+			if got, want := resp.Header.Get(k), v; got != want {
+				t.Errorf("Got header value %q of %q; want %q", k, got, want)
+			}
+		}
+		if tc.respBody != nil {
+			defer resp.Body.Close()
+			b, err := ioutil.ReadAll(resp.Body)
+			if err != nil {
+				t.Errorf("%s: ioutil.ReadAll(resp.Body): %v", tc.desc, err)
+			}
+			if !bytes.Equal(b, tc.respBody) {
+				t.Errorf("%s: got unexpected body %q; want %q", tc.desc, b, tc.respBody)
+			}
+		}
+	}
+}
+
+func TestShare(t *testing.T) {
+	s, err := newServer(testingOptions(t))
+	if err != nil {
+		t.Fatalf("newServer(testingOptions(t)): %v", err)
+	}
+
+	const url = "https://play.golang.org/share"
+	testCases := []struct {
+		desc       string
+		method     string
+		statusCode int
+		reqBody    []byte
+		respBody   []byte
+	}{
+		{"OPTIONS no-op", http.MethodOptions, http.StatusOK, nil, nil},
+		{"Non-POST request", http.MethodGet, http.StatusMethodNotAllowed, nil, nil},
+		{"Standard flow", http.MethodPost, http.StatusOK, []byte("Snippy McSnipface"), []byte("ti55j8ibFJ")},
+		{"Snippet too large", http.MethodPost, http.StatusRequestEntityTooLarge, make([]byte, maxSnippetSize+1), nil},
+	}
+
+	for _, tc := range testCases {
+		req := httptest.NewRequest(tc.method, url, bytes.NewReader(tc.reqBody))
+		w := httptest.NewRecorder()
+		s.handleShare(w, req)
+		resp := w.Result()
+		if got, want := resp.StatusCode, tc.statusCode; got != want {
+			t.Errorf("%s: got unexpected status code %d; want %d", tc.desc, got, want)
+		}
+		if tc.respBody != nil {
+			defer resp.Body.Close()
+			b, err := ioutil.ReadAll(resp.Body)
+			if err != nil {
+				t.Errorf("%s: ioutil.ReadAll(resp.Body): %v", tc.desc, err)
+			}
+			if !bytes.Equal(b, tc.respBody) {
+				t.Errorf("%s: got unexpected body %q; want %q", tc.desc, b, tc.respBody)
+			}
+		}
+	}
 }