add useSystemHosts

This commit is contained in:
patterniha 2025-05-07 06:10:58 +03:30
parent 67d0a2df71
commit c564abafec

View File

@ -1,7 +1,11 @@
package conf package conf
import ( import (
"bufio"
"encoding/json" "encoding/json"
"os"
"path/filepath"
"runtime"
"sort" "sort"
"strings" "strings"
@ -192,6 +196,7 @@ type DNSConfig struct {
DisableCache bool `json:"disableCache"` DisableCache bool `json:"disableCache"`
DisableFallback bool `json:"disableFallback"` DisableFallback bool `json:"disableFallback"`
DisableFallbackIfMatch bool `json:"disableFallbackIfMatch"` DisableFallbackIfMatch bool `json:"disableFallbackIfMatch"`
UseSystemHosts bool `json:"useSystemHosts"`
} }
type HostAddress struct { type HostAddress struct {
@ -413,6 +418,15 @@ func (c *DNSConfig) Build() (*dns.Config, error) {
} }
config.StaticHosts = append(config.StaticHosts, staticHosts...) config.StaticHosts = append(config.StaticHosts, staticHosts...)
} }
if c.UseSystemHosts {
systemHosts, err := readSystemHosts()
if err != nil {
return nil, errors.New("failed to read system hosts").Base(err)
}
for domain, ips := range systemHosts {
config.StaticHosts = append(config.StaticHosts, &dns.Config_HostMapping{Ip: ips, Domain: domain, Type: dns.DomainMatchingType_Full})
}
}
return config, nil return config, nil
} }
@ -431,3 +445,85 @@ func resolveQueryStrategy(queryStrategy string) dns.QueryStrategy {
return dns.QueryStrategy_USE_IP return dns.QueryStrategy_USE_IP
} }
} }
func readSystemHosts() (map[string][][]byte, error) {
var hostsPath string
switch runtime.GOOS {
case "windows":
hostsPath = filepath.Join(os.Getenv("SystemRoot"), "System32", "drivers", "etc", "hosts")
default:
hostsPath = "/etc/hosts"
}
file, err := os.Open(hostsPath)
if err != nil {
return nil, err
}
defer file.Close()
hostsMap := make(map[string][][]byte)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if i := strings.IndexByte(line, '#'); i >= 0 {
// Discard comments.
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
addr := net.ParseAddress(f[0])
if addr.Family().IsDomain() {
continue
}
ip := addr.IP()
for i := 1; i < len(f); i++ {
domain := strings.TrimSuffix(f[i], ".")
domain = strings.ToLower(domain)
if v, ok := hostsMap[domain]; ok {
hostsMap[domain] = append(v, ip)
} else {
hostsMap[domain] = [][]byte{ip}
}
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
return hostsMap, nil
}
func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") }
// Count occurrences in s of any bytes in t.
func countAnyByte(s string, t string) int {
n := 0
for i := 0; i < len(s); i++ {
if strings.IndexByte(t, s[i]) >= 0 {
n++
}
}
return n
}
// Split s at any bytes in t.
func splitAtBytes(s string, t string) []string {
a := make([]string, 1+countAnyByte(s, t))
n := 0
last := 0
for i := 0; i < len(s); i++ {
if strings.IndexByte(t, s[i]) >= 0 {
if last < i {
a[n] = s[last:i]
n++
}
last = i + 1
}
}
if last < len(s) {
a[n] = s[last:]
n++
}
return a[0:n]
}