diff --git a/proxy/vless/encryption/client.go b/proxy/vless/encryption/client.go index 2faf745a..1ea2ecc5 100644 --- a/proxy/vless/encryption/client.go +++ b/proxy/vless/encryption/client.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "io" "net" + "strings" "sync" "time" @@ -26,13 +27,12 @@ func init() { type ClientInstance struct { sync.RWMutex - nfsEKey *mlkem.EncapsulationKey768 - nfsEKeySha256 [32]byte - xor uint32 - minutes time.Duration - expire time.Time - baseKey []byte - ticket []byte + nfsEKey *mlkem.EncapsulationKey768 + xorKey []byte + minutes time.Duration + expire time.Time + baseKey []byte + ticket []byte } type ClientConn struct { @@ -49,10 +49,17 @@ type ClientConn struct { } func (i *ClientInstance) Init(nfsEKeyBytes []byte, xor uint32, minutes time.Duration) (err error) { + if i.nfsEKey != nil { + err = errors.New("already initialized") + return + } i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes) + if err != nil { + return + } if xor > 0 { - i.nfsEKeySha256 = sha256.Sum256(nfsEKeyBytes) - i.xor = xor + xorKey := sha256.Sum256(nfsEKeyBytes) + i.xorKey = xorKey[:] } i.minutes = minutes return @@ -62,8 +69,8 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { if i.nfsEKey == nil { return nil, errors.New("uninitialized") } - if i.xor > 0 { - conn = NewXorConn(conn, i.nfsEKeySha256[:]) + if i.xorKey != nil { + conn = NewXorConn(conn, i.xorKey) } c := &ClientConn{Conn: conn} @@ -99,14 +106,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { } // client can send more padding / NFS AEAD messages if needed - _, t, l, err := ReadAndDecodeHeader(c.Conn) + _, t, l, err := ReadAndDiscardPaddings(c.Conn) if err != nil { return nil, err } + if t != 1 { return nil, errors.New("unexpected type ", t, ", expect server hello") } - peerServerHello := make([]byte, 1088+21) if l != len(peerServerHello) { return nil, errors.New("unexpected length ", l, " for server hello") @@ -183,27 +190,9 @@ func (c *ClientConn) Read(b []byte) (int, error) { return 0, nil } if c.peerAead == nil { - var t byte - var l int - var err error - if c.instance == nil { // from 1-RTT - for { - if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { - return 0, err - } - if t != 23 { - break - } - if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { - return 0, err - } - } - } else { - h := make([]byte, 5) - if _, err := io.ReadFull(c.Conn, h); err != nil { - return 0, err - } - if t, l, err = DecodeHeader(h); err != nil { + _, t, l, err := ReadAndDiscardPaddings(c.Conn) + if err != nil { + if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // from 0-RTT c.instance.Lock() if bytes.Equal(c.ticket, c.instance.ticket) { c.instance.expire = time.Now() // expired @@ -211,6 +200,7 @@ func (c *ClientConn) Read(b []byte) (int, error) { c.instance.Unlock() return 0, errors.New("new handshake needed") } + return 0, err } if t != 0 { return 0, errors.New("unexpected type ", t, ", expect random hello") diff --git a/proxy/vless/encryption/common.go b/proxy/vless/encryption/common.go index 2141f2d9..7f879dc1 100644 --- a/proxy/vless/encryption/common.go +++ b/proxy/vless/encryption/common.go @@ -44,10 +44,10 @@ func DecodeHeader(h []byte) (t byte, l int, err error) { } else if h[0] == 1 && h[1] == 1 && h[2] == 1 { t = 1 } else { - h = nil + l = 0 } - if h == nil || l < 17 || l > 17000 { // TODO: TLSv1.3 max length - err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) + if l < 17 || l > 17000 { // TODO: TLSv1.3 max length + err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) // relied by client's Read() } return } @@ -61,6 +61,17 @@ func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) { return } +func ReadAndDiscardPaddings(conn net.Conn) (h []byte, t byte, l int, err error) { + for { + if h, t, l, err = ReadAndDecodeHeader(conn); err != nil || t != 23 { + return + } + if _, err = io.ReadFull(conn, make([]byte, l)); err != nil { + return + } + } +} + func NewAead(c byte, secret, salt, info []byte) (aead cipher.AEAD) { key := make([]byte, 32) hkdf.New(sha256.New, secret, salt, info).Read(key) diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 49e0b9df..72346575 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -24,12 +24,11 @@ type ServerSession struct { type ServerInstance struct { sync.RWMutex - nfsDKey *mlkem.DecapsulationKey768 - nfsEKeySha256 [32]byte - xor uint32 - minutes time.Duration - sessions map[[21]byte]*ServerSession - closed bool + nfsDKey *mlkem.DecapsulationKey768 + xorKey []byte + minutes time.Duration + sessions map[[21]byte]*ServerSession + closed bool } type ServerConn struct { @@ -46,10 +45,17 @@ type ServerConn struct { } func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Duration) (err error) { + if i.nfsDKey != nil { + err = errors.New("already initialized") + return + } i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed) + if err != nil { + return + } if xor > 0 { - i.nfsEKeySha256 = sha256.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) - i.xor = xor + xorKey := sha256.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) + i.xorKey = xorKey[:] } if minutes > 0 { i.minutes = minutes @@ -86,18 +92,15 @@ func (i *ServerInstance) Handshake(conn net.Conn) (net.Conn, error) { if i.nfsDKey == nil { return nil, errors.New("uninitialized") } - if i.xor > 0 { - conn = NewXorConn(conn, i.nfsEKeySha256[:]) + if i.xorKey != nil { + conn = NewXorConn(conn, i.xorKey) } c := &ServerConn{Conn: conn} - _, t, l, err := ReadAndDecodeHeader(c.Conn) + _, t, l, err := ReadAndDiscardPaddings(c.Conn) if err != nil { return nil, err } - if t == 23 { - return nil, errors.New("unexpected data") - } if t == 0 { if i.minutes == 0 { @@ -187,19 +190,9 @@ func (c *ServerConn) Read(b []byte) (int, error) { } if c.peerAead == nil { if c.peerRandom == nil { // from 1-RTT - var t byte - var l int - var err error - for { - if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { - return 0, err - } - if t != 23 { - break - } - if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { - return 0, err - } + _, t, l, err := ReadAndDiscardPaddings(c.Conn) + if err != nil { + return 0, err } if t != 0 { return 0, errors.New("unexpected type ", t, ", expect ticket hello") diff --git a/proxy/vless/encryption/xor.go b/proxy/vless/encryption/xor.go index 696702bc..bbe489ef 100644 --- a/proxy/vless/encryption/xor.go +++ b/proxy/vless/encryption/xor.go @@ -18,7 +18,7 @@ type XorConn struct { } func NewXorConn(conn net.Conn, key []byte) *XorConn { - return &XorConn{Conn: conn, key: key[:16]} + return &XorConn{Conn: conn, key: key} //chacha20.NewUnauthenticatedCipher() }