diff --git a/common/signal/timer.go b/common/signal/timer.go index da108f2c..e2c052de 100644 --- a/common/signal/timer.go +++ b/common/signal/timer.go @@ -15,9 +15,10 @@ type ActivityUpdater interface { type ActivityTimer struct { sync.RWMutex - updated chan struct{} - checkTask *task.Periodic - onTimeout func() + updated chan struct{} + checkTask *task.Periodic + onTimeout func() + overridden bool } func (t *ActivityTimer) Update() { @@ -31,14 +32,16 @@ func (t *ActivityTimer) check() error { select { case <-t.updated: default: - t.finish() + t.finish(false) } return nil } -func (t *ActivityTimer) finish() { - t.Lock() - defer t.Unlock() +func (t *ActivityTimer) finish(locked bool) { + if !locked { + t.Lock() + defer t.Unlock() + } if t.onTimeout != nil { t.onTimeout() @@ -50,17 +53,15 @@ func (t *ActivityTimer) finish() { } } -func (t *ActivityTimer) SetTimeout(timeout time.Duration) { - if timeout == 0 { - t.finish() - return - } - - t.Lock() - defer t.Unlock() +func (t *ActivityTimer) setTimeout(timeout time.Duration) { if t.onTimeout == nil { return } + if timeout == 0 { + t.finish(true) + return + } + checkTask := &task.Periodic{ Interval: timeout, Execute: t.check, @@ -68,12 +69,27 @@ func (t *ActivityTimer) SetTimeout(timeout time.Duration) { if t.checkTask != nil { t.checkTask.Close() + t.overridden = true } t.checkTask = checkTask t.Update() common.Must(checkTask.Start()) } +func (t *ActivityTimer) SetTimeout(timeout time.Duration) { + t.Lock() + t.setTimeout(timeout) + t.Unlock() +} + +func (t *ActivityTimer) SetTimeoutIfNotOverridden(timeout time.Duration) { + t.Lock() + if !t.overridden { + t.setTimeout(timeout) + } + t.Unlock() +} + func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer { timer := &ActivityTimer{ updated: make(chan struct{}, 1), diff --git a/proxy/proxy.go b/proxy/proxy.go index 3fec31af..1ae9fea4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -596,9 +596,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net statWriter, _ := writer.(*dispatcher.SizeStatWriter) //runtime.Gosched() // necessary time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice - timer.SetTimeout(8 * time.Hour) // prevent leak, just in case + timer.SetTimeoutIfNotOverridden(8 * time.Hour) // prevent leak, just in case if inTimer != nil { - inTimer.SetTimeout(8 * time.Hour) + inTimer.SetTimeoutIfNotOverridden(8 * time.Hour) } w, err := tc.ReadFrom(readerConn) if readCounter != nil {