diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 8a159fad..5294d28a 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -75,6 +75,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } + if isTransportConn(conn) { + inbound.CanSpliceCopy = 3 + } switch network { case net.Network_TCP: @@ -199,7 +202,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ } responseDone := func() error { - inbound.CanSpliceCopy = 1 + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) v2writer := buf.NewWriter(writer) @@ -259,7 +264,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis if inbound != nil && inbound.Source.IsValid() { errors.LogInfo(ctx, "client UDP connection from ", inbound.Source) } - inbound.CanSpliceCopy = 1 + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } var dest *net.Destination @@ -308,6 +315,20 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis } } +// isTransportConn return false if the conn is a raw tcp conn without transport or tls, can process splice copy +func isTransportConn(conn stat.Connection) bool { + if conn != nil { + statConn, ok := conn.(*stat.CounterConnection) + if ok { + conn = statConn.Connection + } + if _, ok := conn.(*net.TCPConn); ok { + return false + } + } + return true +} + func init() { common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { return NewServer(ctx, config.(*ServerConfig))