frontend: add tests and redirect to https on frontend
Change-Id: I1c116b0899a1ea37870a147277c4460366c2ac5d
Reviewed-on: https://go-review.googlesource.com/84976
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/frontend/edit.go b/frontend/edit.go
index 5edd633..2e8809f 100644
--- a/frontend/edit.go
+++ b/frontend/edit.go
@@ -36,7 +36,6 @@
w.Write([]byte(`<h1>Unavailable For Legal Reasons</h1><p>Viewing and/or sharing code snippets is not available in your country for legal reasons. This message might also appear if your country is misdetected. If you believe this is an error, please <a href="https://golang.org/issue">file an issue</a>.</p>`))
return
}
- ctx := r.Context()
id := r.URL.Path[3:]
serveText := false
if strings.HasSuffix(id, ".go") {
@@ -44,7 +43,7 @@
serveText = true
}
- if err := s.db.GetSnippet(ctx, id, snip); err != nil {
+ if err := s.db.GetSnippet(r.Context(), id, snip); err != nil {
if err != datastore.ErrNoSuchEntity {
s.log.Errorf("loading Snippet: %v", err)
}
@@ -62,7 +61,15 @@
return
}
}
- editTemplate.Execute(w, &editData{snip, allowShare(r)})
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ data := &editData{
+ Snippet: snip,
+ Share: allowShare(r),
+ }
+ if err := editTemplate.Execute(w, data); err != nil {
+ s.log.Errorf("editTemplate.Execute(w, %+v): %v", data, err)
+ return
+ }
}
const hello = `package main
diff --git a/frontend/server.go b/frontend/server.go
index 881c6d7..136124c 100644
--- a/frontend/server.go
+++ b/frontend/server.go
@@ -74,7 +74,13 @@
}
func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- if os.Getenv("GAE_INSTANCE") != "" {
+ if r.Header.Get("X-Forwarded-Proto") == "http" {
+ r.URL.Scheme = "https"
+ r.URL.Host = r.Host
+ http.Redirect(w, r, r.URL.String(), http.StatusFound)
+ return
+ }
+ if r.Header.Get("X-Forwarded-Proto") == "https" {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; preload")
}
s.mux.ServeHTTP(w, r)
diff --git a/frontend/server_test.go b/frontend/server_test.go
index 19bc261..4eec012 100644
--- a/frontend/server_test.go
+++ b/frontend/server_test.go
@@ -3,7 +3,14 @@
// license that can be found in the LICENSE file.
package main
-import "testing"
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
type testLogger struct {
t *testing.T
@@ -28,4 +35,94 @@
}
func TestEdit(t *testing.T) {
+ s, err := newServer(testingOptions(t))
+ if err != nil {
+ t.Fatalf("newServer(testingOptions(t)): %v", err)
+ }
+ id := "bar"
+ barBody := []byte("Snippy McSnipface")
+ snip := &snippet{Body: barBody}
+ if err := s.db.PutSnippet(nil, id, snip); err != nil {
+ t.Fatalf("s.dbPutSnippet(nil, %+v, %+v): %v", id, snip, err)
+ }
+
+ testCases := []struct {
+ desc string
+ url string
+ statusCode int
+ headers map[string]string
+ respBody []byte
+ }{
+ {"foo.play.golang.org to play.golang.org", "https://foo.play.golang.org", http.StatusFound, map[string]string{"Location": "https://play.golang.org"}, nil},
+ {"Unknown snippet", "https://play.golang.org/p/foo", http.StatusNotFound, nil, nil},
+ {"Existing snippet", "https://play.golang.org/p/" + id, http.StatusOK, nil, nil},
+ {"Plaintext snippet", "https://play.golang.org/p/" + id + ".go", http.StatusOK, nil, barBody},
+ {"Download snippet", "https://play.golang.org/p/" + id + ".go?download=true", http.StatusOK, map[string]string{"Content-Disposition": fmt.Sprintf(`attachment; filename="%s.go"`, id)}, barBody},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, tc.url, nil)
+ w := httptest.NewRecorder()
+ s.handleEdit(w, req)
+ resp := w.Result()
+ if got, want := resp.StatusCode, tc.statusCode; got != want {
+ t.Errorf("%s: got unexpected status code %d; want %d", tc.desc, got, want)
+ }
+ for k, v := range tc.headers {
+ if got, want := resp.Header.Get(k), v; got != want {
+ t.Errorf("Got header value %q of %q; want %q", k, got, want)
+ }
+ }
+ if tc.respBody != nil {
+ defer resp.Body.Close()
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%s: ioutil.ReadAll(resp.Body): %v", tc.desc, err)
+ }
+ if !bytes.Equal(b, tc.respBody) {
+ t.Errorf("%s: got unexpected body %q; want %q", tc.desc, b, tc.respBody)
+ }
+ }
+ }
+}
+
+func TestShare(t *testing.T) {
+ s, err := newServer(testingOptions(t))
+ if err != nil {
+ t.Fatalf("newServer(testingOptions(t)): %v", err)
+ }
+
+ const url = "https://play.golang.org/share"
+ testCases := []struct {
+ desc string
+ method string
+ statusCode int
+ reqBody []byte
+ respBody []byte
+ }{
+ {"OPTIONS no-op", http.MethodOptions, http.StatusOK, nil, nil},
+ {"Non-POST request", http.MethodGet, http.StatusMethodNotAllowed, nil, nil},
+ {"Standard flow", http.MethodPost, http.StatusOK, []byte("Snippy McSnipface"), []byte("ti55j8ibFJ")},
+ {"Snippet too large", http.MethodPost, http.StatusRequestEntityTooLarge, make([]byte, maxSnippetSize+1), nil},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(tc.method, url, bytes.NewReader(tc.reqBody))
+ w := httptest.NewRecorder()
+ s.handleShare(w, req)
+ resp := w.Result()
+ if got, want := resp.StatusCode, tc.statusCode; got != want {
+ t.Errorf("%s: got unexpected status code %d; want %d", tc.desc, got, want)
+ }
+ if tc.respBody != nil {
+ defer resp.Body.Close()
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%s: ioutil.ReadAll(resp.Body): %v", tc.desc, err)
+ }
+ if !bytes.Equal(b, tc.respBody) {
+ t.Errorf("%s: got unexpected body %q; want %q", tc.desc, b, tc.respBody)
+ }
+ }
+ }
}