DNS-Hosts: return only first matched-result

This commit is contained in:
patterniha 2025-05-07 23:04:17 +03:30
parent 59aa5e1b88
commit 253cd99122
2 changed files with 63 additions and 24 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/dns"
"sort"
) )
// StaticHosts represents static domain-ip mapping in DNS server. // StaticHosts represents static domain-ip mapping in DNS server.
@ -59,16 +60,14 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
} }
func (h *StaticHosts) lookupInternal(domain string) []net.Address { func (h *StaticHosts) lookupInternal(domain string) []net.Address {
ips := make([]net.Address, 0) MatchSlice := h.matchers.Match(domain)
found := false sort.Slice(MatchSlice, func(i, j int) bool {
for _, id := range h.matchers.Match(domain) { return MatchSlice[i] < MatchSlice[j]
ips = append(ips, h.ips[id]...) })
found = true if len(MatchSlice) == 0 {
}
if !found {
return nil return nil
} }
return ips return h.ips[MatchSlice[0]]
} }
func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address { func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address {

View File

@ -1,8 +1,8 @@
package conf package conf
import ( import (
"bytes"
"encoding/json" "encoding/json"
"sort"
"strings" "strings"
"github.com/xtls/xray-core/app/dns" "github.com/xtls/xray-core/app/dns"
@ -192,6 +192,7 @@ func (h *HostAddress) UnmarshalJSON(data []byte) error {
} }
type HostsWrapper struct { type HostsWrapper struct {
Domains []string
Hosts map[string]*HostAddress Hosts map[string]*HostAddress
} }
@ -223,31 +224,70 @@ func getHostMapping(ha *HostAddress) *dns.Config_HostMapping {
// MarshalJSON implements encoding/json.Marshaler.MarshalJSON // MarshalJSON implements encoding/json.Marshaler.MarshalJSON
func (m *HostsWrapper) MarshalJSON() ([]byte, error) { func (m *HostsWrapper) MarshalJSON() ([]byte, error) {
return json.Marshal(m.Hosts) var buf bytes.Buffer
buf.WriteString("{")
for i, domain := range m.Domains {
if i > 0 {
buf.WriteString(",")
}
keyBytes, err := json.Marshal(domain)
if err != nil {
return nil, err
}
buf.Write(keyBytes)
buf.WriteString(":")
valueBytes, err := json.Marshal(m.Hosts[domain])
if err != nil {
return nil, err
}
buf.Write(valueBytes)
}
buf.WriteString("}")
return buf.Bytes(), nil
} }
// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON
func (m *HostsWrapper) UnmarshalJSON(data []byte) error { func (m *HostsWrapper) UnmarshalJSON(data []byte) error {
hosts := make(map[string]*HostAddress) m.Hosts = make(map[string]*HostAddress)
err := json.Unmarshal(data, &hosts) m.Domains = []string{}
if err == nil {
m.Hosts = hosts var tempMap map[string]*HostAddress
return nil if err := json.Unmarshal(data, &tempMap); err != nil {
return err
} }
return errors.New("invalid DNS hosts").Base(err)
dec := json.NewDecoder(bytes.NewReader(data))
t, err := dec.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("unexpected token")
}
for dec.More() {
key, err := dec.Token()
if err != nil {
return err
}
domain, ok := key.(string)
if !ok {
return errors.New("invalid key")
}
m.Domains = append(m.Domains, domain)
var ha *HostAddress
if err := dec.Decode(&ha); err != nil {
return err
}
m.Hosts[domain] = ha
}
return nil
} }
// Build implements Buildable // Build implements Buildable
func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) { func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) {
mappings := make([]*dns.Config_HostMapping, 0, 20) mappings := make([]*dns.Config_HostMapping, 0, 20)
domains := make([]string, 0, len(m.Hosts)) for _, domain := range m.Domains {
for domain := range m.Hosts {
domains = append(domains, domain)
}
sort.Strings(domains)
for _, domain := range domains {
switch { switch {
case strings.HasPrefix(domain, "domain:"): case strings.HasPrefix(domain, "domain:"):
domainName := domain[7:] domainName := domain[7:]