internal/sanitizer: address jba's comments on CL 543858

For #61399

Change-Id: Ia4c4103a0d649172a82b482b6947b4ba2f2e785a
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/546315
kokoro-CI: kokoro <noreply+kokoro@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/sanitizer/sanitizer.go b/internal/sanitizer/sanitizer.go
index 741fd9e..577f458 100644
--- a/internal/sanitizer/sanitizer.go
+++ b/internal/sanitizer/sanitizer.go
@@ -8,7 +8,6 @@
 
 import (
 	"bytes"
-	"fmt"
 	"net/url"
 	"regexp"
 	"strings"
@@ -19,14 +18,15 @@
 // SanitizeBytes returns a sanitized version of the input.
 // It throws out any attributes or tags that are not explicitly
 // allowed in allowElems or allowAttributes below, including
-// any child nodes of elements that are not allowed.
+// any child nodes of elements that are not allowed. It returns
+// nil if there was an error parsing the input.
 func SanitizeBytes(b []byte) []byte {
 	// TODO(matloob): We want to sanitize a fragment that would
 	// appear in the body. Can we call ParseFragment without
 	// creating the body node here?
 	document, err := html.Parse(strings.NewReader("<html><head></head><body></body></html>"))
 	if err != nil {
-		panic(fmt.Errorf("error parsing document: %v", err))
+		return nil
 	}
 	body := document.FirstChild.LastChild // document.FirstChild is the <html> node
 
@@ -47,10 +47,6 @@
 // of parent-less nodes the node should be replaced with.
 func sanitize(n *html.Node) ([]*html.Node, bool) {
 	switch n.Type {
-	case html.CommentNode:
-		return nil, false
-	case html.DoctypeNode:
-		return nil, false
 	case html.TextNode:
 		return nil, true // Assume text nodes are safe
 	case html.ElementNode:
@@ -119,7 +115,7 @@
 		}
 		return nil, true
 	default:
-		return extractSanitizedChildren(n), false
+		return nil, false
 	}
 }
 
@@ -146,6 +142,9 @@
 	return keepNodes
 }
 
+// addRelNoFollow adds a rel="nofollow" attribute to attrs if an
+// href attribute is present. If a rel attribute is already present
+// alongside the href, its value is replaced with "nofollow".
 func addRelNoFollow(attrs []html.Attribute) []html.Attribute {
 	hasHref := false
 	for _, attr := range attrs {
@@ -382,10 +381,14 @@
 	return regexp.MustCompile(rx).MatchString
 }
 
+// validURL reports whether rawurl is valid according to the
+// following rules: it must be parsable by net/url once surrounding
+// whitespace is trimmed, it cannot contain interior newlines, tabs, or
+// spaces, and its scheme, if present, must be mailto, http, or https.
 func validURL(rawurl string) bool {
 	rawurl = strings.TrimSpace(rawurl)
 
-	if strings.ContainsAny(rawurl, " \t\n") {
+	if strings.ContainsAny(rawurl, " \t") {
 		return false
 	}
 
diff --git a/internal/sanitizer/sanitizer_test.go b/internal/sanitizer/sanitizer_test.go
index 00a3f06..ba91d96 100644
--- a/internal/sanitizer/sanitizer_test.go
+++ b/internal/sanitizer/sanitizer_test.go
@@ -5,7 +5,10 @@
 package sanitizer
 
 import (
+	"reflect"
 	"testing"
+
+	"golang.org/x/net/html"
 )
 
 func TestSanitizeBytes(t *testing.T) {
@@ -152,3 +155,71 @@
 		}
 	}
 }
+
+func TestAddRelNoFollow(t *testing.T) { // table-driven check of addRelNoFollow
+	testCases := []struct {
+		input []html.Attribute // attributes handed to addRelNoFollow
+		want  []html.Attribute // attributes expected back
+	}{
+		{ // no attributes at all: expected back unchanged
+			[]html.Attribute{},
+			[]html.Attribute{},
+		},
+		{ // no href: expected back unchanged
+			[]html.Attribute{{Key: "id", Val: "foo"}},
+			[]html.Attribute{{Key: "id", Val: "foo"}},
+		},
+		{ // href without rel: rel="nofollow" is expected at the end
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}},
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}, {Key: "rel", Val: "nofollow"}},
+		},
+		{ // href with rel="nofollow" already set: expected back unchanged
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}, {Key: "rel", Val: "nofollow"}},
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}, {Key: "rel", Val: "nofollow"}},
+		},
+		{ // href with a different rel value: the value is expected to become "nofollow"
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}, {Key: "rel", Val: "canonical"}},
+			[]html.Attribute{{Key: "href", Val: "https://golang.org"}, {Key: "rel", Val: "nofollow"}},
+		},
+		{ // rel without href: expected back unchanged, rel value included
+			[]html.Attribute{{Key: "id", Val: "foo"}, {Key: "rel", Val: "canonical"}},
+			[]html.Attribute{{Key: "id", Val: "foo"}, {Key: "rel", Val: "canonical"}},
+		},
+	}
+
+	for _, tc := range testCases {
+		got := addRelNoFollow(append([]html.Attribute{}, tc.input...)) // called on a copy of tc.input — presumably so an in-place change cannot alter what the failure message prints; confirm against addRelNoFollow
+		if !reflect.DeepEqual(got, tc.want) {
+			t.Errorf("addRelNoFollow(%v): got %v, want %v", tc.input, got, tc.want)
+		}
+	}
+}
+
+func TestValidURL(t *testing.T) { // table-driven check of validURL
+	testCases := []struct {
+		input string // raw URL handed to validURL
+		want  bool   // expected validURL result
+	}{
+		{"", true}, // the empty string is expected to be accepted
+		{"#", true},
+		{"https://golang.org", true},
+		{"http://golang.org", true},
+		{"mailto:golang.org", true},
+		{"unsupported:golang.org", false}, // scheme other than mailto, http, or https
+		{" https://golang.org ", true},
+		{"\thttps://golang.org ", true}, // surrounding whitespace is trimmed by validURL before checking
+		{" https://golang.org/my file", false},
+		{" https://golang.org/my\tfile", false}, // interior whitespace (space, tab, newline in the neighboring cases) is expected to be rejected
+		{" https://golang.org/my\nfile", false},
+		{"%", false}, // presumably rejected because it does not parse as a URL — the parsing step is outside this view
+		{" % ", false},
+		{" %\t% ", false},
+	}
+
+	for _, tc := range testCases {
+		got := validURL(tc.input)
+		if got != tc.want {
+			t.Errorf("validURL(%q): got %v, want %v", tc.input, got, tc.want)
+		}
+	}
+}