1package csync
2
3import (
4 "encoding/json"
5 "iter"
6 "maps"
7 "sync"
8)
9
// Map is a concurrent map implementation that provides thread-safe access.
//
// Methods in this file guard every access to inner with mu. The zero value has
// a nil inner map, so writing to it (e.g. via Set) panics; construct instances
// with NewMap or NewMapFrom. Because Map contains a sync.RWMutex, it must not be
// copied after first use — pass *Map around instead.
type Map[K comparable, V any] struct {
	inner map[K]V      // underlying storage; guarded by mu
	mu    sync.RWMutex // guards inner
}
15
16// NewMap creates a new thread-safe map with the specified key and value types.
17func NewMap[K comparable, V any]() *Map[K, V] {
18 return &Map[K, V]{
19 inner: make(map[K]V),
20 }
21}
22
23// NewMapFrom creates a new thread-safe map from an existing map.
24func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
25 return &Map[K, V]{
26 inner: m,
27 }
28}
29
30// Set sets the value for the specified key in the map.
31func (m *Map[K, V]) Set(key K, value V) {
32 m.mu.Lock()
33 defer m.mu.Unlock()
34 m.inner[key] = value
35}
36
37// Del deletes the specified key from the map.
38func (m *Map[K, V]) Del(key K) {
39 m.mu.Lock()
40 defer m.mu.Unlock()
41 delete(m.inner, key)
42}
43
44// Get gets the value for the specified key from the map.
45func (m *Map[K, V]) Get(key K) (V, bool) {
46 m.mu.RLock()
47 defer m.mu.RUnlock()
48 v, ok := m.inner[key]
49 return v, ok
50}
51
52// Len returns the number of items in the map.
53func (m *Map[K, V]) Len() int {
54 m.mu.RLock()
55 defer m.mu.RUnlock()
56 return len(m.inner)
57}
58
59// GetOrSet gets and returns the key if it exists, otherwise, it executes the
60// given function, set its return value for the given key, and returns it.
61func (m *Map[K, V]) GetOrSet(key K, fn func() V) V {
62 got, ok := m.Get(key)
63 if ok {
64 return got
65 }
66 value := fn()
67 m.Set(key, value)
68 return value
69}
70
71// Take gets an item and then deletes it.
72func (m *Map[K, V]) Take(key K) (V, bool) {
73 v, ok := m.Get(key)
74 if ok {
75 m.Del(key)
76 }
77 return v, ok
78}
79
// Seq2 returns an iterator over the map's key-value pairs.
//
// The pairs are copied under the read lock when Seq2 is called. Changes made to
// the map afterwards are not visible to the iterator, and no lock is held while
// the caller's loop body runs. Iteration order is unspecified.
func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
	m.mu.RLock()
	snapshot := maps.Clone(m.inner)
	m.mu.RUnlock()

	return func(yield func(K, V) bool) {
		for key, value := range snapshot {
			if !yield(key, value) {
				break
			}
		}
	}
}
94
// Seq returns an iterator over the map's values. Each run of the iterator
// obtains a fresh snapshot through Seq2 and stops as soon as yield returns
// false. Iteration order is unspecified.
func (m *Map[K, V]) Seq() iter.Seq[V] {
	return func(yield func(V) bool) {
		m.Seq2()(func(_ K, v V) bool {
			return yield(v)
		})
	}
}
105
106var (
107 _ json.Unmarshaler = &Map[string, any]{}
108 _ json.Marshaler = &Map[string, any]{}
109)
110
111func (Map[K, V]) JSONSchemaAlias() any { //nolint
112 m := map[K]V{}
113 return m
114}
115
116// UnmarshalJSON implements json.Unmarshaler.
117func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
118 m.mu.Lock()
119 defer m.mu.Unlock()
120 m.inner = make(map[K]V)
121 return json.Unmarshal(data, &m.inner)
122}
123
// MarshalJSON implements json.Marshaler by passing the underlying map to
// json.Marshal. The read lock is held until encoding finishes, so writers
// cannot mutate the map mid-encode.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	encoded, err := json.Marshal(m.inner)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}