diff --git a/proxy/vless/encryption/server.go b/proxy/vless/encryption/server.go index 63c0e8a7..1cf5b2f9 100644 --- a/proxy/vless/encryption/server.go +++ b/proxy/vless/encryption/server.go @@ -28,6 +28,7 @@ type ServerInstance struct { xor uint32 minutes time.Duration sessions map[[21]byte]*ServerSession + closed bool } type ServerConn struct { @@ -57,6 +58,10 @@ func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Durat time.Sleep(time.Minute) now := time.Now() i.Lock() + if i.closed { + i.Unlock() + return + } for ticket, session := range i.sessions { if now.After(session.expire) { delete(i.sessions, ticket) @@ -69,6 +74,13 @@ func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Durat return } +func (i *ServerInstance) Close() (err error) { + i.Lock() + i.closed = true + i.Unlock() + return +} + func (i *ServerInstance) Handshake(conn net.Conn) (net.Conn, error) { if i.nfsDKey == nil { return nil, errors.New("uninitialized") @@ -215,7 +227,7 @@ func (c *ServerConn) Read(b []byte) (int, error) { } var peerAead cipher.AEAD if bytes.Equal(c.peerNonce, MaxNonce) { - peerAead = NewAead(ClientCipher, c.baseKey, peerData, peerHeader) + peerAead = NewAead(c.cipher, c.baseKey, peerData, peerHeader) } _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, peerHeader) if peerAead != nil { @@ -258,7 +270,7 @@ func (c *ServerConn) Write(b []byte) (int, error) { EncodeHeader(data, len(b)+16) c.aead.Seal(data[:5], c.nonce, b, data[:5]) if bytes.Equal(c.nonce, MaxNonce) { - c.aead = NewAead(ClientCipher, c.baseKey, data[5:], data[:5]) + c.aead = NewAead(c.cipher, c.baseKey, data[5:], data[:5]) } } IncreaseNonce(c.nonce) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 7c5e9005..06c9e518 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -170,6 +170,9 @@ func isMuxAndNotXUDP(request *protocol.RequestHeader, first *buf.Buffer) bool { // Close implements common.Closable.Close(). func (h *Handler) Close() error { + if h.decryption != nil { + h.decryption.Close() + } return errors.Combine(common.Close(h.validator)) }