diff --git a/app/dns/hosts.go b/app/dns/hosts.go index f4e06dbb..01a052cc 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -6,6 +6,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/dns" + "sort" ) // StaticHosts represents static domain-ip mapping in DNS server. @@ -59,16 +60,14 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address { } func (h *StaticHosts) lookupInternal(domain string) []net.Address { - ips := make([]net.Address, 0) - found := false - for _, id := range h.matchers.Match(domain) { - ips = append(ips, h.ips[id]...) - found = true - } - if !found { + MatchSlice := h.matchers.Match(domain) + sort.Slice(MatchSlice, func(i, j int) bool { + return MatchSlice[i] < MatchSlice[j] + }) + if len(MatchSlice) == 0 { return nil } - return ips + return h.ips[MatchSlice[0]] } func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address { diff --git a/infra/conf/dns.go b/infra/conf/dns.go index 7baeda87..5e870d1d 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -1,8 +1,8 @@ package conf import ( + "bytes" "encoding/json" - "sort" "strings" "github.com/xtls/xray-core/app/dns" @@ -192,7 +192,8 @@ func (h *HostAddress) UnmarshalJSON(data []byte) error { } type HostsWrapper struct { - Hosts map[string]*HostAddress + Domains []string + Hosts map[string]*HostAddress } func getHostMapping(ha *HostAddress) *dns.Config_HostMapping { @@ -223,31 +224,70 @@ func getHostMapping(ha *HostAddress) *dns.Config_HostMapping { // MarshalJSON implements encoding/json.Marshaler.MarshalJSON func (m *HostsWrapper) MarshalJSON() ([]byte, error) { - return json.Marshal(m.Hosts) + var buf bytes.Buffer + buf.WriteString("{") + for i, domain := range m.Domains { + if i > 0 { + buf.WriteString(",") + } + keyBytes, err := json.Marshal(domain) + if err != nil { + return nil, err + } + buf.Write(keyBytes) + buf.WriteString(":") + valueBytes, err := json.Marshal(m.Hosts[domain]) + if err != nil { + return nil, err + } + buf.Write(valueBytes) + } + buf.WriteString("}") + return buf.Bytes(), nil } // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (m *HostsWrapper) UnmarshalJSON(data []byte) error { - hosts := make(map[string]*HostAddress) - err := json.Unmarshal(data, &hosts) - if err == nil { - m.Hosts = hosts - return nil + m.Hosts = make(map[string]*HostAddress) + m.Domains = []string{} + + var tempMap map[string]*HostAddress + if err := json.Unmarshal(data, &tempMap); err != nil { + return err } - return errors.New("invalid DNS hosts").Base(err) + + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("unexpected token") + } + for dec.More() { + key, err := dec.Token() + if err != nil { + return err + } + domain, ok := key.(string) + if !ok { + return errors.New("invalid key") + } + m.Domains = append(m.Domains, domain) + var ha *HostAddress + if err := dec.Decode(&ha); err != nil { + return err + } + m.Hosts[domain] = ha + } + return nil } // Build implements Buildable func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { mappings := make([]*dns.Config_HostMapping, 0, 20) - domains := make([]string, 0, len(m.Hosts)) - for domain := range m.Hosts { - domains = append(domains, domain) - } - sort.Strings(domains) - - for _, domain := range domains { + for _, domain := range m.Domains { switch { case strings.HasPrefix(domain, "domain:"): domainName := domain[7:]