1// Package sync provides synchronization utilities.
2package sync
3
4import (
5 "context"
6 "sync"
7
8 "golang.org/x/sync/semaphore"
9)
10
// WorkPool is a pool of work to be done.
//
// Jobs are registered with Add under a string id and executed by Run, which
// limits how many of them run at the same time. Create pools with
// NewWorkPool: sem and ctx are only initialized there, so the zero value
// cannot be Run.
type WorkPool struct {
	// workers is the maximum number of jobs Run executes concurrently.
	// NewWorkPool clamps it to at least 1 and sizes sem from it.
	workers int
	// work maps a job id (string) to its job (func()). Add inserts entries;
	// Run removes an entry once its job has returned.
	work sync.Map
	// sem bounds the number of concurrently running jobs to workers.
	sem *semaphore.Weighted
	// ctx is the context passed to NewWorkPool; Run uses it while waiting for
	// a free worker slot.
	// NOTE(review): the context package docs discourage storing a Context in
	// a struct; kept because NewWorkPool's public signature depends on it.
	ctx context.Context
	// logger, when non-nil, receives printf-style diagnostics through logf.
	logger func(string, ...interface{})
}
19
20// WorkPoolOption is a function that configures a WorkPool.
21type WorkPoolOption func(*WorkPool)
22
23// WithWorkPoolLogger sets the logger to use.
24func WithWorkPoolLogger(logger func(string, ...interface{})) WorkPoolOption {
25 return func(wq *WorkPool) {
26 wq.logger = logger
27 }
28}
29
30// NewWorkPool creates a new work pool. The workers argument specifies the
31// number of concurrent workers to run the work.
32// The queue will chunk the work into batches of workers size.
33func NewWorkPool(ctx context.Context, workers int, opts ...WorkPoolOption) *WorkPool {
34 wq := &WorkPool{
35 workers: workers,
36 ctx: ctx,
37 }
38
39 for _, opt := range opts {
40 opt(wq)
41 }
42
43 if wq.workers <= 0 {
44 wq.workers = 1
45 }
46
47 wq.sem = semaphore.NewWeighted(int64(wq.workers))
48
49 return wq
50}
51
52// Run starts the workers and waits for them to finish.
53func (wq *WorkPool) Run() {
54 wq.work.Range(func(key, value any) bool {
55 id := key.(string)
56 fn := value.(func())
57 if err := wq.sem.Acquire(wq.ctx, 1); err != nil {
58 wq.logf("workpool: %v", err)
59 return false
60 }
61
62 go func(id string, fn func()) {
63 defer wq.sem.Release(1)
64 fn()
65 wq.work.Delete(id)
66 }(id, fn)
67
68 return true
69 })
70
71 if err := wq.sem.Acquire(wq.ctx, int64(wq.workers)); err != nil {
72 wq.logf("workpool: %v", err)
73 }
74}
75
76// Add adds a new job to the pool.
77// If the job already exists, it is a no-op.
78func (wq *WorkPool) Add(id string, fn func()) {
79 if _, ok := wq.work.Load(id); ok {
80 return
81 }
82 wq.work.Store(id, fn)
83}
84
85// Status checks if a job is in the queue.
86func (wq *WorkPool) Status(id string) bool {
87 _, ok := wq.work.Load(id)
88 return ok
89}
90
91func (wq *WorkPool) logf(format string, args ...interface{}) {
92 if wq.logger != nil {
93 wq.logger(format, args...)
94 }
95}