diff --git a/app/dns/dns.go b/app/dns/dns.go index 3b9cfbcb..cd610907 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -187,7 +187,9 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er } // Static host lookup - switch addrs := s.hosts.Lookup(domain, option); { + switch addrs, err := s.hosts.Lookup(domain, option); { + case err != nil: + return nil, 0, err case addrs == nil: // Domain not recorded in static host break case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled) @@ -250,13 +252,12 @@ func (s *DNS) LookupHosts(domain string) *net.Address { return nil } // Normalize the FQDN form query - addrs := s.hosts.Lookup(domain, *s.ipOption) - if len(addrs) > 0 { - errors.LogInfo(s.ctx, "domain replaced: ", domain, " -> ", addrs[0].String()) - return &addrs[0] + addrs, err := s.hosts.Lookup(domain, *s.ipOption) + if err != nil || len(addrs) == 0 { + return nil } - - return nil + errors.LogInfo(s.ctx, "domain replaced: ", domain, " -> ", addrs[0].String()) + return &addrs[0] } func (s *DNS) sortClients(domain string) []*Client { diff --git a/app/dns/hosts.go b/app/dns/hosts.go index d0485be2..3b38268f 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -3,6 +3,7 @@ package dns import ( "context" "sort" + "strconv" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" @@ -33,7 +34,15 @@ func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) { ips := make([]net.Address, 0, len(mapping.Ip)+1) switch { case len(mapping.ProxiedDomain) > 0: - ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) + if mapping.ProxiedDomain[0] == '#' { + rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:]) + if err != nil { + return nil, err + } + ips = append(ips, dns.RCodeError(rcode)) + } else { + ips = append(ips, net.DomainAddress(mapping.ProxiedDomain)) + } case len(mapping.Ip) > 0: for _, ip := range mapping.Ip { addr := net.IPAddress(ip) @@ -71,25 +80,34 @@ func (h *StaticHosts) lookupInternal(domain string) []net.Address { return h.ips[MatchSlice[0]] } -func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address { +func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) { switch addrs := h.lookupInternal(domain); { case len(addrs) == 0: // Not recorded in static hosts, return nil - return addrs - case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain - errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it") - if maxDepth > 0 { - unwrapped := h.lookup(addrs[0].Domain(), option, maxDepth-1) - if unwrapped != nil { - return unwrapped - } + return addrs, nil + case len(addrs) == 1: + if err, ok := addrs[0].(dns.RCodeError); ok { + return nil, err } - return addrs + if addrs[0].Family().IsDomain() { // Try to unwrap domain + errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it") + if maxDepth > 0 { + unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1) + if err != nil { + return nil, err + } + if unwrapped != nil { + return unwrapped, nil + } + } + return addrs, nil + } + fallthrough default: // IP record found, return a non-nil IP array - return filterIP(addrs, option) + return filterIP(addrs, option), nil } } // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts. -func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address { +func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) { return h.lookup(domain, option, 5) } diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 380c7cb2..2c7f8b69 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -12,6 +12,11 @@ import ( func TestStaticHosts(t *testing.T) { pb := []*Config_HostMapping{ + { + Type: DomainMatchingType_Subdomain, + Domain: "lan", + ProxiedDomain: "#3", + }, { Type: DomainMatchingType_Full, Domain: "example.com", @@ -54,7 +59,14 @@ func TestStaticHosts(t *testing.T) { common.Must(err) { - ips := hosts.Lookup("example.com", dns.IPOption{ + _, err := hosts.Lookup("example.com.lan", dns.IPOption{}) + if dns.RCodeFromError(err) != 3 { + t.Error(err) + } + } + + { + ips, _ := hosts.Lookup("example.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -67,7 +79,7 @@ func TestStaticHosts(t *testing.T) { } { - domain := hosts.Lookup("proxy.xray.com", dns.IPOption{ + domain, _ := hosts.Lookup("proxy.xray.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: false, }) @@ -80,7 +92,7 @@ func TestStaticHosts(t *testing.T) { } { - domain := hosts.Lookup("proxy2.xray.com", dns.IPOption{ + domain, _ := hosts.Lookup("proxy2.xray.com", dns.IPOption{ IPv4Enable: true, IPv6Enable: false, }) @@ -93,7 +105,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("www.example.cn", dns.IPOption{ + ips, _ := hosts.Lookup("www.example.cn", dns.IPOption{ IPv4Enable: true, IPv6Enable: true, }) @@ -106,7 +118,7 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.Lookup("baidu.com", dns.IPOption{ + ips, _ := hosts.Lookup("baidu.com", dns.IPOption{ IPv4Enable: false, IPv6Enable: true, }) diff --git a/features/dns/client.go b/features/dns/client.go index 7dc576fb..db202e3d 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -46,6 +46,24 @@ func (e RCodeError) Error() string { return serial.Concat("rcode: ", uint16(e)) } +func (RCodeError) IP() net.IP { + panic("Calling IP() on a RCodeError.") +} + +func (RCodeError) Domain() string { + panic("Calling Domain() on a RCodeError.") +} + +func (RCodeError) Family() net.AddressFamily { + panic("Calling Family() on a RCodeError.") +} + +func (e RCodeError) String() string { + return e.Error() +} + +var _ net.Address = (*RCodeError)(nil) + func RCodeFromError(err error) uint16 { if err == nil { return 0