Allow concurrent DNS queries

This commit is contained in:
风扇滑翔翼 2025-06-28 15:39:42 +00:00 committed by GitHub
parent 27af360726
commit 54774ceca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,7 +34,7 @@ func ApplyECH(c *Config, config *tls.Config) error {
if len(c.EchConfigList) != 0 { if len(c.EchConfigList) != 0 {
// direct base64 config // direct base64 config
if strings.Contains(c.EchConfigList, "://") { if strings.Contains(c.EchConfigList, "://") {
// query config from dns // query config from dns
parts := strings.Split(c.EchConfigList, "+") parts := strings.Split(c.EchConfigList, "+")
if len(parts) == 2 { if len(parts) == 2 {
// parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query" // parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query"
@ -71,62 +71,62 @@ func ApplyECH(c *Config, config *tls.Config) error {
} }
config.EncryptedClientHelloKeys = KeySets config.EncryptedClientHelloKeys = KeySets
} }
return nil return nil
} }
type record struct { type ECHConfigCache struct {
echConfig []byte echConfig []byte
expire time.Time expire time.Time
updateLock sync.Mutex
} }
var ( var (
dnsCache sync.Map GlobalECHConfigCache map[string]*ECHConfigCache
// global Lock? I'm not sure if this needs finer grained locks. GlobalECHConfigCacheAccess sync.Mutex
// If we do this, we will need to nest another layer of struct
updating sync.Mutex
) )
// QueryRecord returns the ECH config for given domain. // QueryRecord returns the ECH config for given domain.
// If the record is not in cache or expired, it will query the DNS server and update the cache. // If the record is not in cache or expired, it will query the DNS server and update the cache.
func QueryRecord(domain string, server string) ([]byte, error) { func QueryRecord(domain string, server string) ([]byte, error) {
val, found := dnsCache.Load(domain) // Global cache init
rec, _ := val.(record) GlobalECHConfigCacheAccess.Lock()
if found && rec.expire.After(time.Now()) { if GlobalECHConfigCache == nil {
GlobalECHConfigCache = make(map[string]*ECHConfigCache)
}
echConfigCache := GlobalECHConfigCache[domain]
if echConfigCache != nil && echConfigCache.expire.After(time.Now()) {
errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
return rec.echConfig, nil GlobalECHConfigCacheAccess.Unlock()
return echConfigCache.echConfig, nil
} }
if echConfigCache == nil {
updating.Lock() echConfigCache = &ECHConfigCache{}
defer updating.Unlock() GlobalECHConfigCache[domain] = echConfigCache
// Try to get cache again after lock, in case another goroutine has updated it
// This might happen when the core tring is just stared and multiple goroutines are trying to query the same domain
val, found = dnsCache.Load(domain)
rec, _ = val.(record)
if found && rec.expire.After(time.Now()) {
errors.LogDebug(context.Background(), "ECH Config cache hit for domain: ", domain, " after trying to get update lock")
return rec.echConfig, nil
} }
GlobalECHConfigCacheAccess.Unlock()
echConfigCache.updateLock.Lock()
defer echConfigCache.updateLock.Unlock()
// Double check cache after acquiring lock
if echConfigCache.expire.After(time.Now()) {
errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
return echConfigCache.echConfig, nil
}
// Query ECH config from DNS server // Query ECH config from DNS server
errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
echConfig, ttl, err := dnsQuery(server, domain) echConfig, ttl, err := dnsQuery(server, domain)
if err != nil { if err != nil {
return []byte{}, err return nil, err
} }
// Set minimum TTL to 600 seconds // Set minimum TTL to 600 seconds
if ttl < 600 { if ttl < 600 {
ttl = 600 ttl = 600
} }
echConfigCache.echConfig = echConfig
// Update cache echConfigCache.expire = time.Now().Add(time.Second * time.Duration(ttl))
newRecored := record{ return echConfigCache.echConfig, nil
echConfig: echConfig,
expire: time.Now().Add(time.Second * time.Duration(ttl)),
}
dnsCache.Store(domain, newRecored)
return echConfig, nil
} }
// dnsQuery is the real func for sending type65 query for given domain to given DNS server. // dnsQuery is the real func for sending type65 query for given domain to given DNS server.
@ -301,10 +301,10 @@ const KDF_HKDF_SHA512 = 0x0003
func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchConfig, []byte, error) { func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchConfig, []byte, error) {
config := reality.EchConfig{ config := reality.EchConfig{
Version: ExtensionEncryptedClientHello, Version: ExtensionEncryptedClientHello,
ConfigID: configID, ConfigID: configID,
PublicName: []byte(domain), PublicName: []byte(domain),
KemID: kem, KemID: kem,
SymmetricCipherSuite: []reality.EchCipher{ SymmetricCipherSuite: []reality.EchCipher{
{KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_128_GCM}, {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_128_GCM},
{KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_256_GCM}, {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_256_GCM},
@ -316,19 +316,19 @@ func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchCo
{KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_256_GCM}, {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_256_GCM},
{KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_ChaCha20Poly1305}, {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_ChaCha20Poly1305},
}, },
MaxNameLength: 0, MaxNameLength: 0,
Extensions: nil, Extensions: nil,
} }
// if kem == hpke.DHKEM_X25519_HKDF_SHA256 { // if kem == hpke.DHKEM_X25519_HKDF_SHA256 {
curve := ecdh.X25519() curve := ecdh.X25519()
priv := make([]byte, 32) //x25519 priv := make([]byte, 32) //x25519
_, err := io.ReadFull(rand.Reader, priv) _, err := io.ReadFull(rand.Reader, priv)
if err != nil { if err != nil {
return config, nil, err return config, nil, err
} }
privKey, _ := curve.NewPrivateKey(priv) privKey, _ := curve.NewPrivateKey(priv)
config.PublicKey = privKey.PublicKey().Bytes(); config.PublicKey = privKey.PublicKey().Bytes()
return config, priv, nil return config, priv, nil
// } // }
// TODO: add mlkem768 (former kyber768 draft00). The golang mlkem private key is 64 bytes seed? // TODO: add mlkem768 (former kyber768 draft00). The golang mlkem private key is 64 bytes seed?
} }