1package sync
2
3import (
4 "context"
5 "sync"
6
7 "golang.org/x/sync/semaphore"
8)
9
10// WorkPool is a pool of work to be done.
11type WorkPool struct {
12 workers int
13 work map[string]func()
14 mu sync.RWMutex
15 sem *semaphore.Weighted
16 ctx context.Context
17 logger func(string, ...interface{})
18}
19
20// WorkPoolOption is a function that configures a WorkPool.
21type WorkPoolOption func(*WorkPool)
22
23// WithWorkPoolLogger sets the logger to use.
24func WithWorkPoolLogger(logger func(string, ...interface{})) WorkPoolOption {
25 return func(wq *WorkPool) {
26 wq.logger = logger
27 }
28}
29
30// NewWorkPool creates a new work pool. The workers argument specifies the
31// number of concurrent workers to run the work.
32// The queue will chunk the work into batches of workers size.
33func NewWorkPool(ctx context.Context, workers int, opts ...WorkPoolOption) *WorkPool {
34 wq := &WorkPool{
35 workers: workers,
36 work: make(map[string]func()),
37 ctx: ctx,
38 }
39
40 for _, opt := range opts {
41 opt(wq)
42 }
43
44 if wq.workers <= 0 {
45 wq.workers = 1
46 }
47
48 wq.sem = semaphore.NewWeighted(int64(wq.workers))
49
50 return wq
51}
52
53// Run starts the workers and waits for them to finish.
54func (wq *WorkPool) Run() {
55 for id, fn := range wq.work {
56 if err := wq.sem.Acquire(wq.ctx, 1); err != nil {
57 wq.logf("workpool: %v", err)
58 return
59 }
60
61 go func(id string, fn func()) {
62 defer wq.sem.Release(1)
63 fn()
64 wq.mu.Lock()
65 delete(wq.work, id)
66 wq.mu.Unlock()
67 }(id, fn)
68 }
69
70 if err := wq.sem.Acquire(wq.ctx, int64(wq.workers)); err != nil {
71 wq.logf("workpool: %v", err)
72 }
73}
74
75// Add adds a new job to the pool.
76// If the job already exists, it is a no-op.
77func (wq *WorkPool) Add(id string, fn func()) {
78 wq.mu.Lock()
79 defer wq.mu.Unlock()
80 if _, ok := wq.work[id]; ok {
81 return
82 }
83 wq.work[id] = fn
84}
85
86// Status checks if a job is in the queue.
87func (wq *WorkPool) Status(id string) bool {
88 wq.mu.RLock()
89 defer wq.mu.RUnlock()
90 _, ok := wq.work[id]
91 return ok
92}
93
94func (wq *WorkPool) logf(format string, args ...interface{}) {
95 if wq.logger != nil {
96 wq.logger(format, args...)
97 }
98}