| // Copyright 2016 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package autocert |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "crypto/rsa" |
| "crypto/tls" |
| "crypto/x509" |
| "crypto/x509/pkix" |
| "encoding/pem" |
| "fmt" |
| "html/template" |
| "math/big" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| "time" |
| |
| "golang.org/x/crypto/acme/internal/acme" |
| "golang.org/x/net/context" |
| ) |
| |
// JSON response templates used by the ACME CA stub in TestGetCertificate.
// Both are executed with the stub server's base URL as their data.
var (
	// discoTmpl renders the directory (discovery) document listing the
	// new-reg, new-authz and new-cert endpoints.
	discoTmpl = template.Must(template.New("disco").Parse(`{
	"new-reg": "{{.}}/new-reg",
	"new-authz": "{{.}}/new-authz",
	"new-cert": "{{.}}/new-cert"
}`))

	// authzTmpl renders a pending authorization that offers two challenges:
	// tls-sni-01 under /challenge/1 and tls-sni-02 under /challenge/2.
	authzTmpl = template.Must(template.New("authz").Parse(`{
	"status": "pending",
	"challenges": [
		{
			"uri": "{{.}}/challenge/1",
			"type": "tls-sni-01",
			"token": "token-01"
		},
		{
			"uri": "{{.}}/challenge/2",
			"type": "tls-sni-02",
			"token": "token-02"
		}
	]
}`))
)
| |
// dummyCert creates a self-signed certificate whose DNS subject alternative
// names are san and which is valid for 24 hours starting now.
// It returns the certificate in DER encoding.
func dummyCert(san ...string) ([]byte, error) {
	// Use the smallest key crypto/rsa still accepts so the tests stay fast
	// on slow platforms such as 386. Since Go 1.24, GenerateKey returns an
	// error for keys shorter than 1024 bits, so 512 is no longer usable.
	key, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		return nil, err
	}
	now := time.Now()
	t := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		NotBefore:             now,
		NotAfter:              now.Add(24 * time.Hour),
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageKeyEncipherment,
		DNSNames:              san,
	}
	// The template doubles as the parent, which makes the result self-signed.
	return x509.CreateCertificate(rand.Reader, t, t, &key.PublicKey, key)
}
| |
| func TestGetCertificate(t *testing.T) { |
| const domain = "example.org" |
| man := &Manager{Prompt: AcceptTOS} |
| |
| // echo token-02 | shasum -a 256 |
| // then divide result in 2 parts separated by dot |
| tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" |
| verifyTokenCert := func() { |
| hello := &tls.ClientHelloInfo{ServerName: tokenCertName} |
| _, err := man.GetCertificate(hello) |
| if err != nil { |
| t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err) |
| return |
| } |
| } |
| |
| // ACME CA server stub |
| var ca *httptest.Server |
| ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("replay-nonce", "nonce") |
| switch r.URL.Path { |
| // discovery |
| case "/": |
| if err := discoTmpl.Execute(w, ca.URL); err != nil { |
| t.Fatalf("discoTmpl: %v", err) |
| } |
| // client key registration |
| case "/new-reg": |
| w.Write([]byte("{}")) |
| // domain authorization |
| case "/new-authz": |
| w.Header().Set("location", ca.URL+"/authz/1") |
| w.WriteHeader(http.StatusCreated) |
| if err := authzTmpl.Execute(w, ca.URL); err != nil { |
| t.Fatalf("authzTmpl: %v", err) |
| } |
| // accept tls-sni-02 challenge |
| case "/challenge/2": |
| verifyTokenCert() |
| w.Write([]byte("{}")) |
| // authorization status |
| case "/authz/1": |
| w.Write([]byte(`{"status": "valid"}`)) |
| // cert request |
| case "/new-cert": |
| der, err := dummyCert(domain) |
| if err != nil { |
| t.Fatalf("new-cert: dummyCert: %v", err) |
| } |
| chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) |
| w.Header().Set("link", chainUp) |
| w.WriteHeader(http.StatusCreated) |
| w.Write(der) |
| // CA chain cert |
| case "/ca-cert": |
| der, err := dummyCert("ca") |
| if err != nil { |
| t.Fatalf("ca-cert: dummyCert: %v", err) |
| } |
| w.Write(der) |
| default: |
| t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) |
| } |
| })) |
| defer ca.Close() |
| |
| // use smaller key to run faster on 386 |
| key, kerr := rsa.GenerateKey(rand.Reader, 512) |
| if kerr != nil { |
| t.Fatal(kerr) |
| } |
| man.Client = &acme.Client{ |
| Key: key, |
| DirectoryURL: ca.URL, |
| } |
| |
| // simulate tls.Config.GetCertificate |
| var ( |
| tlscert *tls.Certificate |
| err error |
| done = make(chan struct{}) |
| ) |
| go func() { |
| hello := &tls.ClientHelloInfo{ServerName: domain} |
| tlscert, err = man.GetCertificate(hello) |
| close(done) |
| }() |
| select { |
| case <-time.After(15 * time.Second): |
| t.Fatal("man.GetCertificate took too long to return") |
| case <-done: |
| } |
| if err != nil { |
| t.Fatalf("man.GetCertificate: %v", err) |
| } |
| |
| // verify the tlscert is the same we responded with from the CA stub |
| if len(tlscert.Certificate) == 0 { |
| t.Fatal("len(tlscert.Certificate) is 0") |
| } |
| cert, err := x509.ParseCertificate(tlscert.Certificate[0]) |
| if err != nil { |
| t.Fatalf("x509.ParseCertificate: %v", err) |
| } |
| if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain { |
| t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) |
| } |
| |
| // make sure token cert was removed |
| done = make(chan struct{}) |
| go func() { |
| for { |
| hello := &tls.ClientHelloInfo{ServerName: tokenCertName} |
| if _, err := man.GetCertificate(hello); err != nil { |
| break |
| } |
| time.Sleep(100 * time.Millisecond) |
| } |
| close(done) |
| }() |
| select { |
| case <-time.After(5 * time.Second): |
| t.Error("token cert was not removed") |
| case <-done: |
| } |
| } |
| |
| type memCache map[string][]byte |
| |
| func (m memCache) Get(ctx context.Context, key string) ([]byte, error) { |
| v, ok := m[key] |
| if !ok { |
| return nil, ErrCacheMiss |
| } |
| return v, nil |
| } |
| |
| func (m memCache) Put(ctx context.Context, key string, data []byte) error { |
| m[key] = data |
| return nil |
| } |
| |
| func (m memCache) Delete(ctx context.Context, key string) error { |
| delete(m, key) |
| return nil |
| } |
| |
| func TestCache(t *testing.T) { |
| privKey, err := rsa.GenerateKey(rand.Reader, 2048) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tmpl := &x509.Certificate{ |
| SerialNumber: big.NewInt(1), |
| Subject: pkix.Name{CommonName: "example.org"}, |
| NotAfter: time.Now().Add(time.Hour), |
| } |
| pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tlscert := &tls.Certificate{ |
| Certificate: [][]byte{pub}, |
| PrivateKey: privKey, |
| } |
| |
| cache := make(memCache) |
| man := Manager{Cache: cache} |
| if err := man.cachePut("example.org", tlscert); err != nil { |
| t.Fatalf("man.cachePut: %v", err) |
| } |
| res, err := man.cacheGet("example.org") |
| if err != nil { |
| t.Fatalf("man.cacheGet: %v", err) |
| } |
| if res == nil { |
| t.Fatal("res is nil") |
| } |
| |
| priv := x509.MarshalPKCS1PrivateKey(privKey) |
| dummy, err := dummyCert("dummy") |
| if err != nil { |
| t.Fatalf("dummyCert: %v", err) |
| } |
| tt := []struct { |
| key string |
| prv, pub []byte |
| }{ |
| {"dummy", priv, dummy}, |
| {"bad1", priv, []byte{1}}, |
| {"bad2", []byte{1}, pub}, |
| } |
| for i, test := range tt { |
| var buf bytes.Buffer |
| pb := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: test.prv} |
| if err := pem.Encode(&buf, pb); err != nil { |
| t.Errorf("%d: pem.Encode: %v", i, err) |
| } |
| pb = &pem.Block{Type: "CERTIFICATE", Bytes: test.pub} |
| if err := pem.Encode(&buf, pb); err != nil { |
| t.Errorf("%d: pem.Encode: %v", i, err) |
| } |
| |
| cache.Put(nil, test.key, buf.Bytes()) |
| if _, err := man.cacheGet(test.key); err == nil { |
| t.Errorf("%d: err is nil", i) |
| } |
| } |
| } |
| |
| func TestDNSNames(t *testing.T) { |
| man := Manager{ |
| DNSNames: []string{"example.com"}, |
| // prevent network round-trips, just in case |
| Client: &acme.Client{DirectoryURL: "dummy"}, |
| } |
| hello := &tls.ClientHelloInfo{ServerName: "example.org"} |
| _, err := man.GetCertificate(hello) |
| if err == nil || !strings.Contains(err.Error(), "not allowed") { |
| t.Errorf("err = %v; want 'not allowed'", err) |
| } |
| } |