| // Copyright 2022 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| //go:build !js |
| |
| // Test that Resolver.Dial can be a func returning an in-memory net.Conn |
| // speaking DNS. |
| |
| package net |
| |
| import ( |
| "bytes" |
| "context" |
| "errors" |
| "fmt" |
| "reflect" |
| "sort" |
| "testing" |
| "time" |
| |
| "golang.org/x/net/dns/dnsmessage" |
| ) |
| |
| func TestResolverDialFunc(t *testing.T) { |
| r := &Resolver{ |
| PreferGo: true, |
| Dial: newResolverDialFunc(&resolverDialHandler{ |
| StartDial: func(network, address string) error { |
| t.Logf("StartDial(%q, %q) ...", network, address) |
| return nil |
| }, |
| Question: func(h dnsmessage.Header, q dnsmessage.Question) { |
| t.Logf("Header: %+v for %q (type=%v, class=%v)", h, |
| q.Name.String(), q.Type, q.Class) |
| }, |
| // TODO: add test without HandleA* hooks specified at all, that Go |
| // doesn't issue retries; map to something terminal. |
| HandleA: func(w AWriter, name string) error { |
| w.AddIP([4]byte{1, 2, 3, 4}) |
| w.AddIP([4]byte{5, 6, 7, 8}) |
| return nil |
| }, |
| HandleAAAA: func(w AAAAWriter, name string) error { |
| w.AddIP([16]byte{1: 1, 15: 15}) |
| w.AddIP([16]byte{2: 2, 14: 14}) |
| return nil |
| }, |
| HandleSRV: func(w SRVWriter, name string) error { |
| w.AddSRV(1, 2, 80, "foo.bar.") |
| w.AddSRV(2, 3, 81, "bar.baz.") |
| return nil |
| }, |
| }), |
| } |
| ctx := context.Background() |
| const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." |
| |
| t.Run("LookupIP", func(t *testing.T) { |
| ips, err := r.LookupIP(ctx, "ip", fakeDomain) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) { |
| t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want) |
| } |
| }) |
| |
| t.Run("LookupSRV", func(t *testing.T) { |
| _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain) |
| if err != nil { |
| t.Fatal(err) |
| } |
| want := []*SRV{ |
| { |
| Target: "foo.bar.", |
| Port: 80, |
| Priority: 1, |
| Weight: 2, |
| }, |
| { |
| Target: "bar.baz.", |
| Port: 81, |
| Priority: 2, |
| Weight: 3, |
| }, |
| } |
| if !reflect.DeepEqual(got, want) { |
| t.Errorf("wrong result. got:") |
| for _, r := range got { |
| t.Logf(" - %+v", r) |
| } |
| } |
| }) |
| } |
| |
| func sortedIPStrings(ips []IP) []string { |
| ret := make([]string, len(ips)) |
| for i, ip := range ips { |
| ret[i] = ip.String() |
| } |
| sort.Strings(ret) |
| return ret |
| } |
| |
| func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { |
| return func(ctx context.Context, network, address string) (Conn, error) { |
| a := &resolverFuncConn{ |
| h: h, |
| network: network, |
| address: address, |
| ttl: 10, // 10 second default if unset |
| } |
| if h.StartDial != nil { |
| if err := h.StartDial(network, address); err != nil { |
| return nil, err |
| } |
| } |
| return a, nil |
| } |
| } |
| |
// resolverDialHandler holds the hooks consulted by the connections that
// newResolverDialFunc's Dial function returns. Nil hooks are skipped.
type resolverDialHandler struct {
	// StartDial, if non-nil, is called at the start of every call to the
	// Dial function (i.e. each time Go calls Resolver.Dial).
	// Any error returned aborts the dial and is returned unwrapped.
	StartDial func(network, address string) error

	// Question, if non-nil, is called with the header and the first
	// question of each query that contains a question.
	Question func(dnsmessage.Header, dnsmessage.Question)

	// HandleA, HandleAAAA and HandleSRV, if non-nil, answer questions of the
	// corresponding type by adding records through w; name is the question
	// name in text form.
	//
	// The returned error is converted to a response code by mapRCode:
	// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
	// A nil error means success.
	HandleA    func(w AWriter, name string) error
	HandleAAAA func(w AAAAWriter, name string) error
	HandleSRV  func(w SRVWriter, name string) error
}
| |
| type ResponseWriter struct{ a *resolverFuncConn } |
| |
| func (w ResponseWriter) header() dnsmessage.ResourceHeader { |
| q := w.a.q |
| return dnsmessage.ResourceHeader{ |
| Name: q.Name, |
| Type: q.Type, |
| Class: q.Class, |
| TTL: w.a.ttl, |
| } |
| } |
| |
| // SetTTL sets the TTL for subsequent written resources. |
| // Once a resource has been written, SetTTL calls are no-ops. |
| // That is, it can only be called at most once, before anything |
| // else is written. |
| func (w ResponseWriter) SetTTL(seconds uint32) { |
| // ... intention is last one wins and mutates all previously |
| // written records too, but that's a little annoying. |
| // But it's also annoying if the requirement is it needs to be set |
| // last. |
| // And it's also annoying if it's possible for users to set |
| // different TTLs per Answer. |
| if w.a.wrote { |
| return |
| } |
| w.a.ttl = seconds |
| |
| } |
| |
| type AWriter struct{ ResponseWriter } |
| |
| func (w AWriter) AddIP(v4 [4]byte) { |
| w.a.wrote = true |
| err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) |
| if err != nil { |
| panic(err) |
| } |
| } |
| |
| type AAAAWriter struct{ ResponseWriter } |
| |
| func (w AAAAWriter) AddIP(v6 [16]byte) { |
| w.a.wrote = true |
| err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) |
| if err != nil { |
| panic(err) |
| } |
| } |
| |
| type SRVWriter struct{ ResponseWriter } |
| |
| // AddSRV adds a SRV record. The target name must end in a period and |
| // be 63 bytes or fewer. |
| func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { |
| targetName, err := dnsmessage.NewName(target) |
| if err != nil { |
| return err |
| } |
| w.a.wrote = true |
| err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ |
| Priority: priority, |
| Weight: weight, |
| Port: port, |
| Target: targetName, |
| }) |
| if err != nil { |
| panic(err) // internal fault, not user |
| } |
| return nil |
| } |
| |
// Errors that resolverDialHandler hooks can return to request a specific
// DNS response code; see mapRCode.
var (
	// ErrNotExist is mapped by mapRCode to NXDOMAIN (RCode 3).
	ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
	// ErrRefused is mapped by mapRCode to REFUSED (RCode 5).
	ErrRefused = errors.New("refused") // maps to RCode5, REFUSED
)
| |
// resolverFuncConn is the in-memory Conn returned by newResolverDialFunc.
// Queries written to it are answered by calling the hooks in h, and each
// encoded response is buffered for Read.
type resolverFuncConn struct {
	h       *resolverDialHandler // hooks that answer queries
	network string               // network argument passed to Dial
	address string               // address argument passed to Dial
	builder *dnsmessage.Builder  // response under construction during Write
	q       dnsmessage.Question  // question currently being answered
	ttl     uint32               // TTL for written records; see SetTTL
	wrote   bool                 // set once a record has been written

	rbuf bytes.Buffer // encoded responses waiting to be Read
}
| |
| func (*resolverFuncConn) Close() error { return nil } |
| func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } |
| func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } |
| func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } |
| func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } |
| func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } |
| |
| func (a *resolverFuncConn) Read(p []byte) (n int, err error) { |
| return a.rbuf.Read(p) |
| } |
| |
| func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { |
| if len(packet) < 2 { |
| return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) |
| } |
| reqLen := int(packet[0])<<8 | int(packet[1]) |
| req := packet[2:] |
| if len(req) != reqLen { |
| return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) |
| } |
| |
| var parser dnsmessage.Parser |
| h, err := parser.Start(req) |
| if err != nil { |
| // TODO: hook |
| return 0, err |
| } |
| q, err := parser.Question() |
| hadQ := (err == nil) |
| if err == nil && a.h.Question != nil { |
| a.h.Question(h, q) |
| } |
| if err != nil && err != dnsmessage.ErrSectionDone { |
| return 0, err |
| } |
| |
| resh := h |
| resh.Response = true |
| resh.Authoritative = true |
| if hadQ { |
| resh.RCode = dnsmessage.RCodeSuccess |
| } else { |
| resh.RCode = dnsmessage.RCodeNotImplemented |
| } |
| a.rbuf.Grow(514) |
| a.rbuf.WriteByte('X') // reserved header for beu16 length |
| a.rbuf.WriteByte('Y') // reserved header for beu16 length |
| builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) |
| a.builder = &builder |
| if hadQ { |
| a.q = q |
| a.builder.StartQuestions() |
| err := a.builder.Question(q) |
| if err != nil { |
| return 0, fmt.Errorf("Question: %w", err) |
| } |
| a.builder.StartAnswers() |
| switch q.Type { |
| case dnsmessage.TypeA: |
| if a.h.HandleA != nil { |
| resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) |
| } |
| case dnsmessage.TypeAAAA: |
| if a.h.HandleAAAA != nil { |
| resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) |
| } |
| case dnsmessage.TypeSRV: |
| if a.h.HandleSRV != nil { |
| resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) |
| } |
| } |
| } |
| tcpRes, err := builder.Finish() |
| if err != nil { |
| return 0, fmt.Errorf("Finish: %w", err) |
| } |
| |
| n = len(tcpRes) - 2 |
| tcpRes[0] = byte(n >> 8) |
| tcpRes[1] = byte(n) |
| a.rbuf.Write(tcpRes[2:]) |
| |
| return len(packet), nil |
| } |
| |
// someaddr is the placeholder Addr that resolverFuncConn reports for both
// its local and remote ends.
type someaddr struct{}

// Network returns a fixed placeholder network name.
func (someaddr) Network() string {
	const network = "unused"
	return network
}

// String returns a fixed placeholder address.
func (someaddr) String() string {
	const addr = "unused-someaddr"
	return addr
}
| |
| func mapRCode(err error) dnsmessage.RCode { |
| switch err { |
| case nil: |
| return dnsmessage.RCodeSuccess |
| case ErrNotExist: |
| return dnsmessage.RCodeNameError |
| case ErrRefused: |
| return dnsmessage.RCodeRefused |
| default: |
| return dnsmessage.RCodeServerFailure |
| } |
| } |