autocert: validate SNI values more, add tests
Change-Id: I810c8dcc90c056d7fa66bba59c0936f54aabdfc7
Reviewed-on: https://go-review.googlesource.com/42497
Run-TryBot: Alex Vaghin <ddos@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Vaghin <ddos@google.com>
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index 2b5d068..12a98a4 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -177,6 +177,10 @@
return nil, errors.New("acme/autocert: missing server name")
}
+ if strings.ContainsAny(name, `/\`) {
+ return nil, errors.New("acme/autocert: bogus SNI value")
+ }
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
diff --git a/acme/autocert/autocert_test.go b/acme/autocert/autocert_test.go
index 4bb4640..643ab2b 100644
--- a/acme/autocert/autocert_test.go
+++ b/acme/autocert/autocert_test.go
@@ -560,3 +560,42 @@
}
}
}
+
+type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
+
+func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
+ return f(ctx, key)
+}
+
+func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
+ return fmt.Errorf("unsupported Put of %q = %q", key, data)
+}
+
+func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
+ return fmt.Errorf("unsupported Delete of %q", key)
+}
+
+func TestManagerGetCertificateBogusSNI(t *testing.T) {
+ m := Manager{
+ Prompt: AcceptTOS,
+ Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
+ return nil, fmt.Errorf("cache.Get of %s", key)
+ }),
+ }
+ tests := []struct {
+ name string
+ wantErr string
+ }{
+ {"foo.com", "cache.Get of foo.com"},
+ {"foo.com.", "cache.Get of foo.com"},
+ {`a\b`, "acme/autocert: bogus SNI value"},
+ {"", "acme/autocert: missing server name"},
+ }
+ for _, tt := range tests {
+ _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
+ got := fmt.Sprint(err)
+ if got != tt.wantErr {
+ t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
+ }
+ }
+}