diff --git a/common/common.go b/common/common.go index a09f6fbe..c3bfa944 100644 --- a/common/common.go +++ b/common/common.go @@ -23,7 +23,9 @@ func Must(err error) { } // Must2 panics if the second parameter is not nil, otherwise returns the first parameter. -func Must2(v interface{}, err error) interface{} { +// This is useful when function returned "sth, err" and avoid many "if err != nil" +// Internal usage only, if user input can cause err, it must be handled +func Must2[T any](v T, err error) T { Must(err) return v } diff --git a/common/crypto/aes.go b/common/crypto/aes.go index 3205a207..bbc974d9 100644 --- a/common/crypto/aes.go +++ b/common/crypto/aes.go @@ -32,9 +32,7 @@ func NewAesCTRStream(key []byte, iv []byte) cipher.Stream { // NewAesGcm creates a AEAD cipher based on AES-GCM. func NewAesGcm(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + block := common.Must2(aes.NewCipher(key)) + aead := common.Must2(cipher.NewGCM(block)) return aead } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 6af8e0ad..82063852 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -18,11 +18,8 @@ import ( func TestAuthenticationReaderWriter(t *testing.T) { key := make([]byte, 16) rand.Read(key) - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + aead := NewAesGcm(key) const payloadSize = 1024 * 80 rawPayload := make([]byte, payloadSize) @@ -71,7 +68,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { t.Error(r) } - _, err = reader.ReadMultiBuffer() + _, err := reader.ReadMultiBuffer() if err != io.EOF { t.Error("error: ", err) } @@ -80,11 +77,8 @@ func TestAuthenticationReaderWriter(t *testing.T) { func TestAuthenticationReaderWriterPacket(t *testing.T) { key := make([]byte, 16) common.Must2(rand.Read(key)) - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) + aead := NewAesGcm(key) cache := buf.New() iv := make([]byte, 12) diff --git a/proxy/shadowsocks/config.go b/proxy/shadowsocks/config.go index a6d2ef87..391e050c 100644 --- a/proxy/shadowsocks/config.go +++ b/proxy/shadowsocks/config.go @@ -58,11 +58,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error { } func createAesGcm(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - gcm, err := cipher.NewGCM(block) - common.Must(err) - return gcm + return crypto.NewAesGcm(key) } func createChaCha20Poly1305(key []byte) cipher.AEAD { diff --git a/proxy/vmess/aead/encrypt.go b/proxy/vmess/aead/encrypt.go index 8995f2ea..9b94c691 100644 --- a/proxy/vmess/aead/encrypt.go +++ b/proxy/vmess/aead/encrypt.go @@ -10,6 +10,7 @@ import ( "time" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" ) func SealVMessAEADHeader(key [16]byte, data []byte) []byte { @@ -34,15 +35,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte { payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] - payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderLengthAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey) payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:]) } @@ -54,15 +47,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte { payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12] - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey) payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:]) } @@ -104,15 +89,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12] - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderLengthAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderLengthAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey) decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:]) @@ -145,15 +122,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte, return nil, false, bytesRead, err } - payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey) - if err != nil { - panic(err.Error()) - } - - payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock) - if err != nil { - panic(err.Error()) - } + payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey) decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:]) diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d678646b..472d933c 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -182,8 +182,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] - aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) - aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey) var aeadEncryptedResponseHeaderLength [18]byte var decryptedResponseHeaderLength int @@ -205,8 +204,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] - aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) - aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey) encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 99e7abc9..328b09fe 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -350,8 +350,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] - aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) - aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey) aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil) @@ -365,8 +364,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] - aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) - aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) + aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey) aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil) common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload))) diff --git a/transport/internet/kcp/cryptreal.go b/transport/internet/kcp/cryptreal.go index e86bba98..6131a2d9 100644 --- a/transport/internet/kcp/cryptreal.go +++ b/transport/internet/kcp/cryptreal.go @@ -6,10 +6,10 @@ import ( "crypto/sha256" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/crypto" ) func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD { hashedSeed := sha256.Sum256([]byte(seed)) - aesBlock := common.Must2(aes.NewCipher(hashedSeed[:16])).(cipher.Block) - return common.Must2(cipher.NewGCM(aesBlock)).(cipher.AEAD) + return crypto.NewAesGcm(hashedSeed[:]) } diff --git a/transport/internet/reality/reality.go b/transport/internet/reality/reality.go index dca4e951..20f13ba5 100644 --- a/transport/internet/reality/reality.go +++ b/transport/internet/reality/reality.go @@ -3,8 +3,6 @@ package reality import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" "crypto/ecdh" "crypto/ed25519" "crypto/hmac" @@ -169,8 +167,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil { return nil, err } - block, _ := aes.NewCipher(uConn.AuthKey) - aead, _ := cipher.NewGCM(block) + aead := crypto.NewAesGcm(uConn.AuthKey) if config.Show { fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead) } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index dee5f486..e409e61c 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -297,7 +297,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if transportConfiguration.DownloadSettings != nil { globalDialerAccess.Lock() if streamSettings.DownloadSettings == nil { - streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig) + streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)) if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate { streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings }