diff --git a/proxy/vless/encryption/client.go b/proxy/vless/encryption/client.go index 04360e2b..52b4828b 100644 --- a/proxy/vless/encryption/client.go +++ b/proxy/vless/encryption/client.go @@ -85,19 +85,31 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { nfsKey, encapsulatedNfsKey := i.nfsEKey.Encapsulate() paddingLen := crypto.RandBetween(100, 1000) - clientHello := make([]byte, 1+1184+1088+5+paddingLen) - clientHello[0] = ClientCipher - copy(clientHello[1:], pfsEKeyBytes) - copy(clientHello[1185:], encapsulatedNfsKey) - EncodeHeader(clientHello[2273:], int(paddingLen)) - rand.Read(clientHello[2278:]) + clientHello := make([]byte, 5+1+1184+1088+5+paddingLen) + EncodeHeader(clientHello, 1, 1+1184+1088) + clientHello[5] = ClientCipher + copy(clientHello[5+1:], pfsEKeyBytes) + copy(clientHello[5+1+1184:], encapsulatedNfsKey) + EncodeHeader(clientHello[5+1+1184+1088:], 23, int(paddingLen)) + rand.Read(clientHello[5+1+1184+1088+5:]) - if _, err := c.Conn.Write(clientHello); err != nil { + if n, err := c.Conn.Write(clientHello); n != len(clientHello) || err != nil { return nil, err } - // we can send more padding if needed + // client can send more padding / NFS AEAD messages if needed + + _, t, l, err := ReadAndDecodeHeader(c.Conn) + if err != nil { + return nil, err + } + if t != 1 { + return nil, errors.New("unexpected type ", t, ", expect server hello") + } peerServerHello := make([]byte, 1088+21) + if l != len(peerServerHello) { + return nil, errors.New("unexpected length ", l, " for server hello") + } if _, err := io.ReadFull(c.Conn, peerServerHello); err != nil { return nil, err } @@ -112,7 +124,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { nonce := [12]byte{ClientCipher} VLESS, _ := NewAead(ClientCipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Open(nil, nonce[:], c.ticket, pfsEKeyBytes) - if !bytes.Equal(VLESS, []byte("VLESS")) { // TODO: more messages + if !bytes.Equal(VLESS, []byte("VLESS")) { return nil, errors.New("invalid server").AtError() } @@ -143,21 +155,22 @@ func (c *ClientConn) Write(b []byte) (int, error) { rand.Read(c.random) c.aead = NewAead(ClientCipher, c.baseKey, c.random, c.ticket) c.nonce = make([]byte, 12) - data = make([]byte, 21+32+5+len(b)+16) - copy(data, c.ticket) - copy(data[21:], c.random) - EncodeHeader(data[53:], len(b)+16) - c.aead.Seal(data[:58], c.nonce, b, data[53:58]) + data = make([]byte, 5+21+32+5+len(b)+16) + EncodeHeader(data, 0, 21+32) + copy(data[5:], c.ticket) + copy(data[5+21:], c.random) + EncodeHeader(data[5+21+32:], 23, len(b)+16) + c.aead.Seal(data[:5+21+32+5], c.nonce, b, data[5+21+32:5+21+32+5]) } else { data = make([]byte, 5+len(b)+16) - EncodeHeader(data, len(b)+16) + EncodeHeader(data, 23, len(b)+16) c.aead.Seal(data[:5], c.nonce, b, data[:5]) if bytes.Equal(c.nonce, MaxNonce) { c.aead = NewAead(ClientCipher, c.baseKey, data[5:], data[:5]) } } IncreaseNonce(c.nonce) - if _, err := c.Conn.Write(data); err != nil { + if n, err := c.Conn.Write(data); n != len(data) || err != nil { return 0, err } } @@ -168,29 +181,44 @@ func (c *ClientConn) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - peerHeader := make([]byte, 5) if c.peerAead == nil { - if c.instance == nil { + var t byte + var l int + var err error + if c.instance == nil { // 1-RTT for { - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { + if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { return 0, err } - peerPaddingLen, _ := DecodeHeader(peerHeader) - if peerPaddingLen == 0 { + if t != 23 { break } - if _, err := io.ReadFull(c.Conn, make([]byte, peerPaddingLen)); err != nil { + if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { return 0, err } } } else { - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { + h := make([]byte, 5) + if _, err := io.ReadFull(c.Conn, h); err != nil { return 0, err } + if t, l, err = DecodeHeader(h); err != nil { + c.instance.Lock() + if bytes.Equal(c.ticket, c.instance.ticket) { + c.instance.expire = time.Now() // expired + } + c.instance.Unlock() + return 0, errors.New("new handshake needed") + } + } + if t != 0 { + return 0, errors.New("unexpected type ", t, ", expect server random") } peerRandom := make([]byte, 32) - copy(peerRandom, peerHeader) - if _, err := io.ReadFull(c.Conn, peerRandom[5:]); err != nil { + if l != len(peerRandom) { + return 0, errors.New("unexpected length ", l, " for server random") + } + if _, err := io.ReadFull(c.Conn, peerRandom); err != nil { return 0, err } if c.random == nil { @@ -204,33 +232,26 @@ func (c *ClientConn) Read(b []byte) (int, error) { c.peerCache = c.peerCache[n:] return n, nil } - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { - return 0, err - } - peerLength, err := DecodeHeader(peerHeader) // 17~17000 + h, t, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 if err != nil { - if c.instance != nil { - c.instance.Lock() - if bytes.Equal(c.ticket, c.instance.ticket) { - c.instance.expire = time.Now() // expired - } - c.instance.Unlock() - } return 0, err } - peerData := make([]byte, peerLength) + if t != 23 { + return 0, errors.New("unexpected type ", t, ", expect encrypted data") + } + peerData := make([]byte, l) if _, err := io.ReadFull(c.Conn, peerData); err != nil { return 0, err } - dst := peerData[:peerLength-16] + dst := peerData[:l-16] if len(dst) <= len(b) { dst = b[:len(dst)] // avoids another copy() } var peerAead cipher.AEAD if bytes.Equal(c.peerNonce, MaxNonce) { - peerAead = NewAead(ClientCipher, c.baseKey, peerData, peerHeader) + peerAead = NewAead(ClientCipher, c.baseKey, peerData, h) } - _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, peerHeader) + _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, h) if peerAead != nil { c.peerAead = peerAead } diff --git a/proxy/vless/encryption/common.go b/proxy/vless/encryption/common.go index 0cd23e16..2141f2d9 100644 --- a/proxy/vless/encryption/common.go +++ b/proxy/vless/encryption/common.go @@ -5,7 +5,9 @@ import ( "crypto/aes" "crypto/cipher" "crypto/sha256" - "strconv" + "fmt" + "io" + "net" "github.com/xtls/xray-core/common/errors" "golang.org/x/crypto/chacha20poly1305" @@ -14,23 +16,49 @@ import ( var MaxNonce = bytes.Repeat([]byte{255}, 12) -func EncodeHeader(b []byte, l int) { - b[0] = 23 - b[1] = 3 - b[2] = 3 - b[3] = byte(l >> 8) - b[4] = byte(l) +func EncodeHeader(h []byte, t byte, l int) { + switch t { + case 1: + h[0] = 1 + h[1] = 1 + h[2] = 1 + case 0: + h[0] = 0 + h[1] = 0 + h[2] = 0 + case 23: + h[0] = 23 + h[1] = 3 + h[2] = 3 + } + h[3] = byte(l >> 8) + h[4] = byte(l) } -func DecodeHeader(b []byte) (int, error) { - if b[0] == 23 && b[1] == 3 && b[2] == 3 { - l := int(b[3])<<8 | int(b[4]) - if l < 17 || l > 17000 { // TODO: TLSv1.3 max length - return 0, errors.New("invalid length in record's header: " + strconv.Itoa(l)) - } - return l, nil +func DecodeHeader(h []byte) (t byte, l int, err error) { + l = int(h[3])<<8 | int(h[4]) + if h[0] == 23 && h[1] == 3 && h[2] == 3 { + t = 23 + } else if h[0] == 0 && h[1] == 0 && h[2] == 0 { + t = 0 + } else if h[0] == 1 && h[1] == 1 && h[2] == 1 { + t = 1 + } else { + h = nil } - return 0, errors.New("invalid record's header") + if h == nil || l < 17 || l > 17000 { // TODO: TLSv1.3 max length + err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) + } + return +} + +func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) { + h = make([]byte, 5) + if _, err = io.ReadFull(conn, h); err != nil { + return + } + t, l, err = DecodeHeader(h) + return } func NewAead(c byte, secret, salt, info []byte) (aead cipher.AEAD) { diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 1cf5b2f9..71aed4a2 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -56,12 +56,12 @@ func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Durat go func() { for { time.Sleep(time.Minute) - now := time.Now() i.Lock() if i.closed { i.Unlock() return } + now := time.Now() for ticket, session := range i.sessions { if now.After(session.expire) { delete(i.sessions, ticket) @@ -90,40 +90,49 @@ func (i *ServerInstance) Handshake(conn net.Conn) (net.Conn, error) { } c := &ServerConn{Conn: conn} - peerTicketHello := make([]byte, 21+32) - if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { + _, t, l, err := ReadAndDecodeHeader(c.Conn) + if err != nil { return nil, err } - if i.minutes > 0 { + if t == 23 { + return nil, errors.New("unexpected data") + } + + if t == 0 { + if i.minutes == 0 { + return nil, errors.New("0-RTT is not allowed") + } + peerTicketHello := make([]byte, 21+32) + if l != len(peerTicketHello) { + return nil, errors.New("unexpected length ", l, " for ticket hello") + } + if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil { + return nil, err + } i.RLock() s := i.sessions[[21]byte(peerTicketHello)] i.RUnlock() - if s != nil { - if _, replay := s.randoms.LoadOrStore([32]byte(peerTicketHello[21:]), true); !replay { - c.cipher = s.cipher - c.baseKey = s.baseKey - c.ticket = peerTicketHello[:21] - c.peerRandom = peerTicketHello[21:] - return c, nil - } + if s == nil { + noise := make([]byte, crypto.RandBetween(100, 1000)) + rand.Read(noise) + c.Conn.Write(noise) // make client do new handshake + return nil, errors.New("expired ticket") } - } - - peerHeader := make([]byte, 5) - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { - return nil, err - } - if l, _ := DecodeHeader(peerHeader); l != 0 { - noise := make([]byte, crypto.RandBetween(100, 1000)) - rand.Read(noise) - c.Conn.Write(noise) // make client do new handshake - return nil, errors.New("invalid ticket") + if _, replay := s.randoms.LoadOrStore([32]byte(peerTicketHello[21:]), true); replay { + return nil, errors.New("replay detected") + } + c.cipher = s.cipher + c.baseKey = s.baseKey + c.ticket = peerTicketHello[:21] + c.peerRandom = peerTicketHello[21:] + return c, nil } peerClientHello := make([]byte, 1+1184+1088) - copy(peerClientHello, peerTicketHello) - copy(peerClientHello[53:], peerHeader) - if _, err := io.ReadFull(c.Conn, peerClientHello[58:]); err != nil { + if l != len(peerClientHello) { + return nil, errors.New("unexpected length ", l, " for client hello") + } + if _, err := io.ReadFull(c.Conn, peerClientHello); err != nil { return nil, err } c.cipher = peerClientHello[0] @@ -146,16 +155,17 @@ func (i *ServerInstance) Handshake(conn net.Conn) (net.Conn, error) { paddingLen := crypto.RandBetween(100, 1000) - serverHello := make([]byte, 1088+21+5+paddingLen) - copy(serverHello, encapsulatedPfsKey) - copy(serverHello[1088:], c.ticket) - EncodeHeader(serverHello[1109:], int(paddingLen)) - rand.Read(serverHello[1114:]) + serverHello := make([]byte, 5+1088+21+5+paddingLen) + EncodeHeader(serverHello, 1, 1088+21) + copy(serverHello[5:], encapsulatedPfsKey) + copy(serverHello[5+1088:], c.ticket) + EncodeHeader(serverHello[5+1088+21:], 23, int(paddingLen)) + rand.Read(serverHello[5+1088+21+5:]) - if _, err := c.Conn.Write(serverHello); err != nil { + if n, err := c.Conn.Write(serverHello); n != len(serverHello) || err != nil { return nil, err } - // we can send more padding if needed + // server can send more padding / PFS AEAD messages if needed if i.minutes > 0 { i.Lock() @@ -174,24 +184,30 @@ func (c *ServerConn) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - peerHeader := make([]byte, 5) if c.peerAead == nil { - if c.peerRandom == nil { + if c.peerRandom == nil { // 1-RTT + var t byte + var l int + var err error for { - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { + if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { return 0, err } - peerPaddingLen, _ := DecodeHeader(peerHeader) - if peerPaddingLen == 0 { + if t != 23 { break } - if _, err := io.ReadFull(c.Conn, make([]byte, peerPaddingLen)); err != nil { + if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { return 0, err } } + if t != 0 { + return 0, errors.New("unexpected type ", t, ", expect ticket hello") + } peerTicket := make([]byte, 21) - copy(peerTicket, peerHeader) - if _, err := io.ReadFull(c.Conn, peerTicket[5:]); err != nil { + if l != len(peerTicket) { + return 0, errors.New("unexpected length ", l, " for ticket hello") + } + if _, err := io.ReadFull(c.Conn, peerTicket); err != nil { return 0, err } if !bytes.Equal(peerTicket, c.ticket) { @@ -210,32 +226,32 @@ func (c *ServerConn) Read(b []byte) (int, error) { c.peerCache = c.peerCache[n:] return n, nil } - if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { - return 0, err - } - peerLength, err := DecodeHeader(peerHeader) // 17~17000 + h, t, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 if err != nil { return 0, err } - peerData := make([]byte, peerLength) + if t != 23 { + return 0, errors.New("unexpected type ", t, ", expect encrypted data") + } + peerData := make([]byte, l) if _, err := io.ReadFull(c.Conn, peerData); err != nil { return 0, err } - dst := peerData[:peerLength-16] + dst := peerData[:l-16] if len(dst) <= len(b) { dst = b[:len(dst)] // avoids another copy() } var peerAead cipher.AEAD if bytes.Equal(c.peerNonce, MaxNonce) { - peerAead = NewAead(c.cipher, c.baseKey, peerData, peerHeader) + peerAead = NewAead(c.cipher, c.baseKey, peerData, h) } - _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, peerHeader) + _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, h) if peerAead != nil { c.peerAead = peerAead } IncreaseNonce(c.peerNonce) if err != nil { - return 0, errors.New("error") + return 0, err } if len(dst) > len(b) { c.peerCache = dst[copy(b, dst):] @@ -259,22 +275,23 @@ func (c *ServerConn) Write(b []byte) (int, error) { if c.peerRandom == nil { return 0, errors.New("empty c.peerRandom") } - data = make([]byte, 32+5+len(b)+16) - rand.Read(data[:32]) - c.aead = NewAead(c.cipher, c.baseKey, data[:32], c.peerRandom) + data = make([]byte, 5+32+5+len(b)+16) + EncodeHeader(data, 0, 32) + rand.Read(data[5 : 5+32]) + c.aead = NewAead(c.cipher, c.baseKey, data[5:5+32], c.peerRandom) c.nonce = make([]byte, 12) - EncodeHeader(data[32:], len(b)+16) - c.aead.Seal(data[:37], c.nonce, b, data[32:37]) + EncodeHeader(data[5+32:], 23, len(b)+16) + c.aead.Seal(data[:5+32+5], c.nonce, b, data[5+32:5+32+5]) } else { data = make([]byte, 5+len(b)+16) - EncodeHeader(data, len(b)+16) + EncodeHeader(data, 23, len(b)+16) c.aead.Seal(data[:5], c.nonce, b, data[:5]) if bytes.Equal(c.nonce, MaxNonce) { c.aead = NewAead(c.cipher, c.baseKey, data[5:], data[:5]) } } IncreaseNonce(c.nonce) - if _, err := c.Conn.Write(data); err != nil { + if n, err := c.Conn.Write(data); n != len(data) || err != nil { return 0, err } }