diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 94a9225b..e624fc9b 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -34,7 +34,7 @@ func ApplyECH(c *Config, config *tls.Config) error { if len(c.EchConfigList) != 0 { // direct base64 config if strings.Contains(c.EchConfigList, "://") { - // query config from dns + // query config from dns parts := strings.Split(c.EchConfigList, "+") if len(parts) == 2 { // parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query" @@ -71,62 +71,62 @@ func ApplyECH(c *Config, config *tls.Config) error { } config.EncryptedClientHelloKeys = KeySets } - + return nil } -type record struct { - echConfig []byte - expire time.Time +type ECHConfigCache struct { + echConfig []byte + expire time.Time + updateLock sync.Mutex } var ( - dnsCache sync.Map - // global Lock? I'm not sure if this needs finer grained locks. - // If we do this, we will need to nest another layer of struct - updating sync.Mutex + GlobalECHConfigCache map[string]*ECHConfigCache + GlobalECHConfigCacheAccess sync.Mutex ) // QueryRecord returns the ECH config for given domain. // If the record is not in cache or expired, it will query the DNS server and update the cache. func QueryRecord(domain string, server string) ([]byte, error) { - val, found := dnsCache.Load(domain) - rec, _ := val.(record) - if found && rec.expire.After(time.Now()) { + // Global cache init + GlobalECHConfigCacheAccess.Lock() + if GlobalECHConfigCache == nil { + GlobalECHConfigCache = make(map[string]*ECHConfigCache) + } + + echConfigCache := GlobalECHConfigCache[domain] + if echConfigCache != nil && echConfigCache.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) - return rec.echConfig, nil + GlobalECHConfigCacheAccess.Unlock() + return echConfigCache.echConfig, nil } - - updating.Lock() - defer updating.Unlock() - // Try to get cache again after lock, in case another goroutine has updated it - // This might happen when the core tring is just stared and multiple goroutines are trying to query the same domain - val, found = dnsCache.Load(domain) - rec, _ = val.(record) - if found && rec.expire.After(time.Now()) { - errors.LogDebug(context.Background(), "ECH Config cache hit for domain: ", domain, " after trying to get update lock") - return rec.echConfig, nil + if echConfigCache == nil { + echConfigCache = &ECHConfigCache{} + GlobalECHConfigCache[domain] = echConfigCache } + GlobalECHConfigCacheAccess.Unlock() + echConfigCache.updateLock.Lock() + defer echConfigCache.updateLock.Unlock() + // Double check cache after acquiring lock + if echConfigCache.expire.After(time.Now()) { + errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) + return echConfigCache.echConfig, nil + } // Query ECH config from DNS server errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) echConfig, ttl, err := dnsQuery(server, domain) if err != nil { - return []byte{}, err + return nil, err } - // Set minimum TTL to 600 seconds if ttl < 600 { ttl = 600 } - - // Update cache - newRecored := record{ - echConfig: echConfig, - expire: time.Now().Add(time.Second * time.Duration(ttl)), - } - dnsCache.Store(domain, newRecored) - return echConfig, nil + echConfigCache.echConfig = echConfig + echConfigCache.expire = time.Now().Add(time.Second * time.Duration(ttl)) + return echConfigCache.echConfig, nil } // dnsQuery is the real func for sending type65 query for given domain to given DNS server. @@ -301,10 +301,10 @@ const KDF_HKDF_SHA512 = 0x0003 func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchConfig, []byte, error) { config := reality.EchConfig{ - Version: ExtensionEncryptedClientHello, - ConfigID: configID, - PublicName: []byte(domain), - KemID: kem, + Version: ExtensionEncryptedClientHello, + ConfigID: configID, + PublicName: []byte(domain), + KemID: kem, SymmetricCipherSuite: []reality.EchCipher{ {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_128_GCM}, {KDFID: hpke.KDF_HKDF_SHA256, AEADID: hpke.AEAD_AES_256_GCM}, @@ -316,19 +316,19 @@ func GenerateECHKeySet(configID uint8, domain string, kem uint16) (reality.EchCo {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_AES_256_GCM}, {KDFID: KDF_HKDF_SHA512, AEADID: hpke.AEAD_ChaCha20Poly1305}, }, - MaxNameLength: 0, - Extensions: nil, + MaxNameLength: 0, + Extensions: nil, } // if kem == hpke.DHKEM_X25519_HKDF_SHA256 { - curve := ecdh.X25519() - priv := make([]byte, 32) //x25519 - _, err := io.ReadFull(rand.Reader, priv) - if err != nil { - return config, nil, err - } - privKey, _ := curve.NewPrivateKey(priv) - config.PublicKey = privKey.PublicKey().Bytes(); - return config, priv, nil + curve := ecdh.X25519() + priv := make([]byte, 32) //x25519 + _, err := io.ReadFull(rand.Reader, priv) + if err != nil { + return config, nil, err + } + privKey, _ := curve.NewPrivateKey(priv) + config.PublicKey = privKey.PublicKey().Bytes() + return config, priv, nil // } // TODO: add mlkem768 (former kyber768 draft00). The golang mlkem private key is 64 bytes seed? }