1package sync
 2
 3import (
 4	"context"
 5	"sync"
 6
 7	"golang.org/x/sync/semaphore"
 8)
 9
10// WorkPool is a pool of work to be done.
11type WorkPool struct {
12	workers int
13	work    map[string]func()
14	mu      sync.RWMutex
15	sem     *semaphore.Weighted
16	ctx     context.Context
17	logger  func(string, ...interface{})
18}
19
20// WorkPoolOption is a function that configures a WorkPool.
21type WorkPoolOption func(*WorkPool)
22
23// WithWorkPoolLogger sets the logger to use.
24func WithWorkPoolLogger(logger func(string, ...interface{})) WorkPoolOption {
25	return func(wq *WorkPool) {
26		wq.logger = logger
27	}
28}
29
30// NewWorkPool creates a new work pool. The workers argument specifies the
31// number of concurrent workers to run the work.
32// The queue will chunk the work into batches of workers size.
33func NewWorkPool(ctx context.Context, workers int, opts ...WorkPoolOption) *WorkPool {
34	wq := &WorkPool{
35		workers: workers,
36		work:    make(map[string]func()),
37		ctx:     ctx,
38	}
39
40	for _, opt := range opts {
41		opt(wq)
42	}
43
44	if wq.workers <= 0 {
45		wq.workers = 1
46	}
47
48	wq.sem = semaphore.NewWeighted(int64(wq.workers))
49
50	return wq
51}
52
53// Run starts the workers and waits for them to finish.
54func (wq *WorkPool) Run() {
55	for id, fn := range wq.work {
56		if err := wq.sem.Acquire(wq.ctx, 1); err != nil {
57			wq.logf("workpool: %v", err)
58			return
59		}
60
61		go func(id string, fn func()) {
62			defer wq.sem.Release(1)
63			fn()
64			wq.mu.Lock()
65			delete(wq.work, id)
66			wq.mu.Unlock()
67		}(id, fn)
68	}
69
70	if err := wq.sem.Acquire(wq.ctx, int64(wq.workers)); err != nil {
71		wq.logf("workpool: %v", err)
72	}
73}
74
75// Add adds a new job to the pool.
76// If the job already exists, it is a no-op.
77func (wq *WorkPool) Add(id string, fn func()) {
78	wq.mu.Lock()
79	defer wq.mu.Unlock()
80	if _, ok := wq.work[id]; ok {
81		return
82	}
83	wq.work[id] = fn
84}
85
86// Status checks if a job is in the queue.
87func (wq *WorkPool) Status(id string) bool {
88	wq.mu.RLock()
89	defer wq.mu.RUnlock()
90	_, ok := wq.work[id]
91	return ok
92}
93
94func (wq *WorkPool) logf(format string, args ...interface{}) {
95	if wq.logger != nil {
96		wq.logger(format, args...)
97	}
98}