1package sync
 2
 3import (
 4	"context"
 5	"sync"
 6
 7	"golang.org/x/sync/semaphore"
 8)
 9
10// WorkPool is a pool of work to be done.
11type WorkPool struct {
12	workers int
13	work    sync.Map
14	sem     *semaphore.Weighted
15	ctx     context.Context
16	logger  func(string, ...interface{})
17}
18
19// WorkPoolOption is a function that configures a WorkPool.
20type WorkPoolOption func(*WorkPool)
21
22// WithWorkPoolLogger sets the logger to use.
23func WithWorkPoolLogger(logger func(string, ...interface{})) WorkPoolOption {
24	return func(wq *WorkPool) {
25		wq.logger = logger
26	}
27}
28
29// NewWorkPool creates a new work pool. The workers argument specifies the
30// number of concurrent workers to run the work.
31// The queue will chunk the work into batches of workers size.
32func NewWorkPool(ctx context.Context, workers int, opts ...WorkPoolOption) *WorkPool {
33	wq := &WorkPool{
34		workers: workers,
35		ctx:     ctx,
36	}
37
38	for _, opt := range opts {
39		opt(wq)
40	}
41
42	if wq.workers <= 0 {
43		wq.workers = 1
44	}
45
46	wq.sem = semaphore.NewWeighted(int64(wq.workers))
47
48	return wq
49}
50
51// Run starts the workers and waits for them to finish.
52func (wq *WorkPool) Run() {
53	wq.work.Range(func(key, value any) bool {
54		id := key.(string)
55		fn := value.(func())
56		if err := wq.sem.Acquire(wq.ctx, 1); err != nil {
57			wq.logf("workpool: %v", err)
58			return false
59		}
60
61		go func(id string, fn func()) {
62			defer wq.sem.Release(1)
63			fn()
64			wq.work.Delete(id)
65		}(id, fn)
66
67		return true
68	})
69
70	if err := wq.sem.Acquire(wq.ctx, int64(wq.workers)); err != nil {
71		wq.logf("workpool: %v", err)
72	}
73}
74
75// Add adds a new job to the pool.
76// If the job already exists, it is a no-op.
77func (wq *WorkPool) Add(id string, fn func()) {
78	if _, ok := wq.work.Load(id); ok {
79		return
80	}
81	wq.work.Store(id, fn)
82}
83
84// Status checks if a job is in the queue.
85func (wq *WorkPool) Status(id string) bool {
86	_, ok := wq.work.Load(id)
87	return ok
88}
89
90func (wq *WorkPool) logf(format string, args ...interface{}) {
91	if wq.logger != nil {
92		wq.logger(format, args...)
93	}
94}