diff --git a/infra/conf/dns.go b/infra/conf/dns.go index e1071781..d532a781 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -1,7 +1,11 @@ package conf import ( + "bufio" "encoding/json" + "os" + "path/filepath" + "runtime" "sort" "strings" @@ -192,6 +196,7 @@ type DNSConfig struct { DisableCache bool `json:"disableCache"` DisableFallback bool `json:"disableFallback"` DisableFallbackIfMatch bool `json:"disableFallbackIfMatch"` + UseSystemHosts bool `json:"useSystemHosts"` } type HostAddress struct { @@ -413,6 +418,15 @@ func (c *DNSConfig) Build() (*dns.Config, error) { } config.StaticHosts = append(config.StaticHosts, staticHosts...) } + if c.UseSystemHosts { + systemHosts, err := readSystemHosts() + if err != nil { + return nil, errors.New("failed to read system hosts").Base(err) + } + for domain, ips := range systemHosts { + config.StaticHosts = append(config.StaticHosts, &dns.Config_HostMapping{Ip: ips, Domain: domain, Type: dns.DomainMatchingType_Full}) + } + } return config, nil } @@ -431,3 +445,85 @@ func resolveQueryStrategy(queryStrategy string) dns.QueryStrategy { return dns.QueryStrategy_USE_IP } } + +func readSystemHosts() (map[string][][]byte, error) { + var hostsPath string + switch runtime.GOOS { + case "windows": + hostsPath = filepath.Join(os.Getenv("SystemRoot"), "System32", "drivers", "etc", "hosts") + default: + hostsPath = "/etc/hosts" + } + + file, err := os.Open(hostsPath) + if err != nil { + return nil, err + } + defer file.Close() + + hostsMap := make(map[string][][]byte) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if i := strings.IndexByte(line, '#'); i >= 0 { + // Discard comments. + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + addr := net.ParseAddress(f[0]) + if addr.Family().IsDomain() { + continue + } + ip := addr.IP() + for i := 1; i < len(f); i++ { + domain := strings.TrimSuffix(f[i], ".") + domain = strings.ToLower(domain) + if v, ok := hostsMap[domain]; ok { + hostsMap[domain] = append(v, ip) + } else { + hostsMap[domain] = [][]byte{ip} + } + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + return hostsMap, nil +} + +func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } + +// Count occurrences in s of any bytes in t. +func countAnyByte(s string, t string) int { + n := 0 + for i := 0; i < len(s); i++ { + if strings.IndexByte(t, s[i]) >= 0 { + n++ + } + } + return n +} + +// Split s at any bytes in t. +func splitAtBytes(s string, t string) []string { + a := make([]string, 1+countAnyByte(s, t)) + n := 0 + last := 0 + for i := 0; i < len(s); i++ { + if strings.IndexByte(t, s[i]) >= 0 { + if last < i { + a[n] = s[last:i] + n++ + } + last = i + 1 + } + } + if last < len(s) { + a[n] = s[last:] + n++ + } + return a[0:n] +}