diff --git a/app/dns/dns.go b/app/dns/dns.go index 3b9cfbcb..6d810a70 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -187,7 +187,12 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er } // Static host lookup - switch addrs := s.hosts.Lookup(domain, option); { + switch addrs, err := s.hosts.Lookup(domain, option); { + case err != nil: + if go_errors.Is(err, dns.ErrEmptyResponse) { + return nil, 0, dns.ErrEmptyResponse + } + return nil, 0, errors.New("returning nil for domain ", domain).Base(err) case addrs == nil: // Domain not recorded in static host break case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled) @@ -250,13 +255,12 @@ func (s *DNS) LookupHosts(domain string) *net.Address { return nil } // Normalize the FQDN form query - addrs := s.hosts.Lookup(domain, *s.ipOption) - if len(addrs) > 0 { - errors.LogInfo(s.ctx, "domain replaced: ", domain, " -> ", addrs[0].String()) - return &addrs[0] + addrs, err := s.hosts.Lookup(domain, *s.ipOption) + if err != nil || len(addrs) == 0 { + return nil } - - return nil + errors.LogInfo(s.ctx, "domain replaced: ", domain, " -> ", addrs[0].String()) + return &addrs[0] } func (s *DNS) sortClients(domain string) []*Client { diff --git a/app/dns/hosts.go b/app/dns/hosts.go index f4e06dbb..c2f7649d 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -2,6 +2,8 @@ package dns import ( "context" + "strconv" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/strmatcher" @@ -31,7 +33,15 @@ func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) { ips := make([]net.Address, 0, len(mapping.Ip)+1) switch { case len(mapping.ProxiedDomain) > 0: - ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) + if mapping.ProxiedDomain[0] == '#' { + rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:]) + if err != nil { + return nil, err + } + ips = append(ips, dns.RCodeError(rcode)) + } else { + ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) + } case len(mapping.Ip) > 0: for _, ip := range mapping.Ip { addr := net.IPAddress(ip) @@ -58,38 +68,51 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address { return filtered } -func (h *StaticHosts) lookupInternal(domain string) []net.Address { +func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) { ips := make([]net.Address, 0) found := false for _, id := range h.matchers.Match(domain) { + for _, v := range h.ips[id] { + if err, ok := v.(dns.RCodeError); ok { + if uint16(err) == 0 { + return nil, dns.ErrEmptyResponse + } + return nil, err + } + } ips = append(ips, h.ips[id]...) found = true } if !found { - return nil + return nil, nil } - return ips + return ips, nil } -func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address { - switch addrs := h.lookupInternal(domain); { +func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) { + switch addrs, err := h.lookupInternal(domain); { + case err != nil: + return nil, err case len(addrs) == 0: // Not recorded in static hosts, return nil - return addrs + return addrs, nil case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it") if maxDepth > 0 { - unwrapped := h.lookup(addrs[0].Domain(), option, maxDepth-1) + unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1) + if err != nil { + return nil, err + } if unwrapped != nil { - return unwrapped + return unwrapped, nil } } - return addrs + return addrs, nil default: // IP record found, return a non-nil IP array - return filterIP(addrs, option) + return filterIP(addrs, option), nil } } // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts. -func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address { +func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) { return h.lookup(domain, option, 5) } diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 380c7cb2..2c7f8b69 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -12,6 +12,11 @@ import ( func TestStaticHosts(t *testing.T) { pb := []*Config_HostMapping{ + { + Type: DomainMatchingType_Subdomain, + Domain: "lan", + ProxiedDomain: "#3", + }, { Type: DomainMatchingType_Full, Domain: "example.com", @@ -54,7 +59,14 @@ func TestStaticHosts(t *testing.T) { common.Must(err) { - ips := hosts.Lookup("example.com", dns.IPOption{ + _, err := hosts.Lookup("example.com.lan", dns.IPOption{}) + if dns.RCodeFromError(err) != 3 { + t.Error(err) + } + } + + { + ips, _ := hosts.Lookup("example.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -67,7 +79,7 @@ func TestStaticHosts(t *testing.T) { } { - domain := hosts.Lookup("proxy.xray.com", dns.IPOption{ + domain, _ := hosts.Lookup("proxy.xray.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: false, }) @@ -80,7 +92,7 @@ func TestStaticHosts(t *testing.T) { } { - domain := hosts.Lookup("proxy2.xray.com", dns.IPOption{ + domain, _ := hosts.Lookup("proxy2.xray.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: false, }) @@ -93,7 +105,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("www.example.cn", dns.IPOption{ + ips, _ := hosts.Lookup("www.example.cn", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -106,7 +118,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("baidu.com", dns.IPOption{ + ips, _ := hosts.Lookup("baidu.com", dns.IPOption{ IPv4Enable: false, IPv6Enable: true, }) diff --git a/features/dns/client.go b/features/dns/client.go index 7dc576fb..db202e3d 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -46,6 +46,24 @@ func (e RCodeError) Error() string { return serial.Concat("rcode: ", uint16(e)) } +func (RCodeError) IP() net.IP { + panic("Calling IP() on a RCodeError.") +} + +func (RCodeError) Domain() string { + panic("Calling Domain() on a RCodeError.") +} + +func (RCodeError) Family() net.AddressFamily { + panic("Calling Family() on a RCodeError.") +} + +func (e RCodeError) String() string { + return e.Error() +} + +var _ net.Address = (*RCodeError)(nil) + func RCodeFromError(err error) uint16 { if err == nil { return 0