diff --git a/daemon/daemon.go b/daemon/daemon.go index 0fcbe5b46bb5db92ed76834abaafa343a06003e0..bd0b20c65ab92c3a9fff1eeb77c8272a31cd502a 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -132,7 +132,9 @@ func (d *Daemon) Run() error { // Cleanup. log.Println("daemon: shutting down") d.listener.Close() //nolint:errcheck,gosec - d.idleWatcher.StopAll() + if err := d.idleWatcher.StopAllAndWaitTimeout(5 * time.Second); err != nil { + log.Printf("daemon: %v", err) + } cancel() d.closeAllClients() d.closeProviders() diff --git a/fetcher/idle.go b/fetcher/idle.go index 7757b04e531d984ba1be08a3cfdf365253bf876a..6daf59e0d5d0517cda58809f0d72ead6f9e7c254 100644 --- a/fetcher/idle.go +++ b/fetcher/idle.go @@ -1,6 +1,7 @@ package fetcher import ( + "errors" "log" "strings" "sync" @@ -19,10 +20,14 @@ type IdleUpdate struct { // IdleWatcher manages IDLE connections for multiple accounts. type IdleWatcher struct { mu sync.Mutex + wg sync.WaitGroup watchers map[string]*accountIdle // key: account ID notify chan<- IdleUpdate } +// ErrStopTimeout is returned when IDLE watcher goroutines do not stop before the timeout. +var ErrStopTimeout = errors.New("idle watcher: stop timed out") + // accountIdle manages a single IDLE connection for one account. type accountIdle struct { account *config.Account @@ -60,7 +65,11 @@ func (w *IdleWatcher) Watch(account *config.Account, folder string) { done: make(chan struct{}), } w.watchers[account.ID] = a - go a.run() + w.wg.Add(1) + go func() { + defer w.wg.Done() + a.run() + }() } // Stop stops the IDLE watcher for a specific account. @@ -100,6 +109,30 @@ func (w *IdleWatcher) StopAllAndWait() { for _, done := range pending { <-done } + w.wg.Wait() +} + +// StopAllAndWaitTimeout stops all IDLE watchers and waits for them to finish up to d. +func (w *IdleWatcher) StopAllAndWaitTimeout(d time.Duration) error { + w.mu.Lock() + for id, a := range w.watchers { + close(a.stop) + delete(w.watchers, id) + } + w.mu.Unlock() + + done := make(chan struct{}) + go func() { + w.wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-time.After(d): + return ErrStopTimeout + } } func (a *accountIdle) run() { diff --git a/fetcher/idle_test.go b/fetcher/idle_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a7ac96a21cf66316af5b38bacdd6010fbe30b1f2 --- /dev/null +++ b/fetcher/idle_test.go @@ -0,0 +1,72 @@ +package fetcher + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/floatpane/matcha/config" +) + +func TestIdleWatcher_StopAllAndWait_TracksReplacedGoroutines(t *testing.T) { + w := NewIdleWatcher(make(chan IdleUpdate)) + stopCh := make(chan struct{}) + doneCh := make(chan struct{}) + var exits atomic.Int64 + + w.wg.Add(1) + go func() { + defer w.wg.Done() + defer close(doneCh) + <-stopCh + exits.Add(1) + }() + + w.watchers["acct"] = &accountIdle{ + account: &config.Account{ID: "acct"}, + stop: stopCh, + done: doneCh, + } + + if err := w.StopAllAndWaitTimeout(time.Second); err != nil { + t.Fatalf("StopAllAndWaitTimeout returned error: %v", err) + } + if got := exits.Load(); got != 1 { + t.Fatalf("expected synthetic watcher to exit once, got %d", got) + } +} + +func TestIdleWatcher_StopAllAndWaitTimeout_ReturnsOnSlowExit(t *testing.T) { + w := NewIdleWatcher(make(chan IdleUpdate)) + stopCh := make(chan struct{}) + doneCh := make(chan struct{}) + release := make(chan struct{}) + exited := make(chan struct{}) + + w.wg.Add(1) + go func() { + defer w.wg.Done() + defer close(doneCh) + defer close(exited) + <-release + }() + + w.watchers["acct"] = &accountIdle{ + account: &config.Account{ID: "acct"}, + stop: stopCh, + done: doneCh, + } + + err := w.StopAllAndWaitTimeout(50 * time.Millisecond) + if !errors.Is(err, ErrStopTimeout) { + t.Fatalf("expected ErrStopTimeout, got %v", err) + } + + close(release) + select { + case <-exited: + case <-time.After(time.Second): + t.Fatal("synthetic watcher did not exit during cleanup") + } +}