1package sync
2
3import (
4 "context"
5 "sync"
6
7 "golang.org/x/sync/semaphore"
8)
9
10// WorkPool is a pool of work to be done.
11type WorkPool struct {
12 workers int
13 work sync.Map
14 sem *semaphore.Weighted
15 ctx context.Context
16 logger func(string, ...interface{})
17}
18
19// WorkPoolOption is a function that configures a WorkPool.
20type WorkPoolOption func(*WorkPool)
21
22// WithWorkPoolLogger sets the logger to use.
23func WithWorkPoolLogger(logger func(string, ...interface{})) WorkPoolOption {
24 return func(wq *WorkPool) {
25 wq.logger = logger
26 }
27}
28
29// NewWorkPool creates a new work pool. The workers argument specifies the
30// number of concurrent workers to run the work.
31// The queue will chunk the work into batches of workers size.
32func NewWorkPool(ctx context.Context, workers int, opts ...WorkPoolOption) *WorkPool {
33 wq := &WorkPool{
34 workers: workers,
35 ctx: ctx,
36 }
37
38 for _, opt := range opts {
39 opt(wq)
40 }
41
42 if wq.workers <= 0 {
43 wq.workers = 1
44 }
45
46 wq.sem = semaphore.NewWeighted(int64(wq.workers))
47
48 return wq
49}
50
51// Run starts the workers and waits for them to finish.
52func (wq *WorkPool) Run() {
53 wq.work.Range(func(key, value any) bool {
54 id := key.(string)
55 fn := value.(func())
56 if err := wq.sem.Acquire(wq.ctx, 1); err != nil {
57 wq.logf("workpool: %v", err)
58 return false
59 }
60
61 go func(id string, fn func()) {
62 defer wq.sem.Release(1)
63 fn()
64 wq.work.Delete(id)
65 }(id, fn)
66
67 return true
68 })
69
70 if err := wq.sem.Acquire(wq.ctx, int64(wq.workers)); err != nil {
71 wq.logf("workpool: %v", err)
72 }
73}
74
75// Add adds a new job to the pool.
76// If the job already exists, it is a no-op.
77func (wq *WorkPool) Add(id string, fn func()) {
78 if _, ok := wq.work.Load(id); ok {
79 return
80 }
81 wq.work.Store(id, fn)
82}
83
84// Status checks if a job is in the queue.
85func (wq *WorkPool) Status(id string) bool {
86 _, ok := wq.work.Load(id)
87 return ok
88}
89
90func (wq *WorkPool) logf(format string, args ...interface{}) {
91 if wq.logger != nil {
92 wq.logger(format, args...)
93 }
94}