1package csync
2
3import (
4 "encoding/json"
5 "maps"
6 "sync"
7 "testing"
8 "testing/synctest"
9 "time"
10
11 "github.com/stretchr/testify/require"
12)
13
14func TestNewMap(t *testing.T) {
15 t.Parallel()
16
17 m := NewMap[string, int]()
18 require.NotNil(t, m)
19 require.NotNil(t, m.inner)
20 require.Equal(t, 0, m.Len())
21}
22
23func TestNewMapFrom(t *testing.T) {
24 t.Parallel()
25
26 original := map[string]int{
27 "key1": 1,
28 "key2": 2,
29 }
30
31 m := NewMapFrom(original)
32 require.NotNil(t, m)
33 require.Equal(t, original, m.inner)
34 require.Equal(t, 2, m.Len())
35
36 value, ok := m.Get("key1")
37 require.True(t, ok)
38 require.Equal(t, 1, value)
39}
40
41func TestNewLazyMap(t *testing.T) {
42 t.Parallel()
43
44 synctest.Test(t, func(t *testing.T) {
45 t.Helper()
46
47 waiter := sync.Mutex{}
48 waiter.Lock()
49 loadCalled := false
50
51 loadFunc := func() map[string]int {
52 waiter.Lock()
53 defer waiter.Unlock()
54 loadCalled = true
55 return map[string]int{
56 "key1": 1,
57 "key2": 2,
58 }
59 }
60
61 m := NewLazyMap(loadFunc)
62 require.NotNil(t, m)
63
64 waiter.Unlock() // Allow the load function to proceed
65 time.Sleep(100 * time.Millisecond)
66 require.True(t, loadCalled)
67 require.Equal(t, 2, m.Len())
68
69 value, ok := m.Get("key1")
70 require.True(t, ok)
71 require.Equal(t, 1, value)
72 })
73}
74
75func TestMap_Set(t *testing.T) {
76 t.Parallel()
77
78 m := NewMap[string, int]()
79
80 m.Set("key1", 42)
81 value, ok := m.Get("key1")
82 require.True(t, ok)
83 require.Equal(t, 42, value)
84 require.Equal(t, 1, m.Len())
85
86 m.Set("key1", 100)
87 value, ok = m.Get("key1")
88 require.True(t, ok)
89 require.Equal(t, 100, value)
90 require.Equal(t, 1, m.Len())
91}
92
93func TestMap_GetOrSet(t *testing.T) {
94 t.Parallel()
95
96 m := NewMap[string, int]()
97
98 require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
99 require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
100 require.Equal(t, 1, m.Len())
101}
102
103func TestMap_Get(t *testing.T) {
104 t.Parallel()
105
106 m := NewMap[string, int]()
107
108 value, ok := m.Get("nonexistent")
109 require.False(t, ok)
110 require.Equal(t, 0, value)
111
112 m.Set("key1", 42)
113 value, ok = m.Get("key1")
114 require.True(t, ok)
115 require.Equal(t, 42, value)
116}
117
118func TestMap_Del(t *testing.T) {
119 t.Parallel()
120
121 m := NewMap[string, int]()
122 m.Set("key1", 42)
123 m.Set("key2", 100)
124
125 require.Equal(t, 2, m.Len())
126
127 m.Del("key1")
128 _, ok := m.Get("key1")
129 require.False(t, ok)
130 require.Equal(t, 1, m.Len())
131
132 value, ok := m.Get("key2")
133 require.True(t, ok)
134 require.Equal(t, 100, value)
135
136 m.Del("nonexistent")
137 require.Equal(t, 1, m.Len())
138}
139
140func TestMap_Len(t *testing.T) {
141 t.Parallel()
142
143 m := NewMap[string, int]()
144 require.Equal(t, 0, m.Len())
145
146 m.Set("key1", 1)
147 require.Equal(t, 1, m.Len())
148
149 m.Set("key2", 2)
150 require.Equal(t, 2, m.Len())
151
152 m.Del("key1")
153 require.Equal(t, 1, m.Len())
154
155 m.Del("key2")
156 require.Equal(t, 0, m.Len())
157}
158
159func TestMap_Take(t *testing.T) {
160 t.Parallel()
161
162 m := NewMap[string, int]()
163 m.Set("key1", 42)
164 m.Set("key2", 100)
165
166 require.Equal(t, 2, m.Len())
167
168 value, ok := m.Take("key1")
169 require.True(t, ok)
170 require.Equal(t, 42, value)
171 require.Equal(t, 1, m.Len())
172
173 _, exists := m.Get("key1")
174 require.False(t, exists)
175
176 value, ok = m.Get("key2")
177 require.True(t, ok)
178 require.Equal(t, 100, value)
179}
180
181func TestMap_Take_NonexistentKey(t *testing.T) {
182 t.Parallel()
183
184 m := NewMap[string, int]()
185 m.Set("key1", 42)
186
187 value, ok := m.Take("nonexistent")
188 require.False(t, ok)
189 require.Equal(t, 0, value)
190 require.Equal(t, 1, m.Len())
191
192 value, ok = m.Get("key1")
193 require.True(t, ok)
194 require.Equal(t, 42, value)
195}
196
197func TestMap_Take_EmptyMap(t *testing.T) {
198 t.Parallel()
199
200 m := NewMap[string, int]()
201
202 value, ok := m.Take("key1")
203 require.False(t, ok)
204 require.Equal(t, 0, value)
205 require.Equal(t, 0, m.Len())
206}
207
208func TestMap_Take_SameKeyTwice(t *testing.T) {
209 t.Parallel()
210
211 m := NewMap[string, int]()
212 m.Set("key1", 42)
213
214 value, ok := m.Take("key1")
215 require.True(t, ok)
216 require.Equal(t, 42, value)
217 require.Equal(t, 0, m.Len())
218
219 value, ok = m.Take("key1")
220 require.False(t, ok)
221 require.Equal(t, 0, value)
222 require.Equal(t, 0, m.Len())
223}
224
225func TestMap_Seq2(t *testing.T) {
226 t.Parallel()
227
228 m := NewMap[string, int]()
229 m.Set("key1", 1)
230 m.Set("key2", 2)
231 m.Set("key3", 3)
232
233 collected := maps.Collect(m.Seq2())
234
235 require.Equal(t, 3, len(collected))
236 require.Equal(t, 1, collected["key1"])
237 require.Equal(t, 2, collected["key2"])
238 require.Equal(t, 3, collected["key3"])
239}
240
241func TestMap_Seq2_EarlyReturn(t *testing.T) {
242 t.Parallel()
243
244 m := NewMap[string, int]()
245 m.Set("key1", 1)
246 m.Set("key2", 2)
247 m.Set("key3", 3)
248
249 count := 0
250 for range m.Seq2() {
251 count++
252 if count == 2 {
253 break
254 }
255 }
256
257 require.Equal(t, 2, count)
258}
259
260func TestMap_Seq2_EmptyMap(t *testing.T) {
261 t.Parallel()
262
263 m := NewMap[string, int]()
264
265 count := 0
266 for range m.Seq2() {
267 count++
268 }
269
270 require.Equal(t, 0, count)
271}
272
273func TestMap_Seq(t *testing.T) {
274 t.Parallel()
275
276 m := NewMap[string, int]()
277 m.Set("key1", 1)
278 m.Set("key2", 2)
279 m.Set("key3", 3)
280
281 collected := make([]int, 0)
282 for v := range m.Seq() {
283 collected = append(collected, v)
284 }
285
286 require.Equal(t, 3, len(collected))
287 require.Contains(t, collected, 1)
288 require.Contains(t, collected, 2)
289 require.Contains(t, collected, 3)
290}
291
292func TestMap_Seq_EarlyReturn(t *testing.T) {
293 t.Parallel()
294
295 m := NewMap[string, int]()
296 m.Set("key1", 1)
297 m.Set("key2", 2)
298 m.Set("key3", 3)
299
300 count := 0
301 for range m.Seq() {
302 count++
303 if count == 2 {
304 break
305 }
306 }
307
308 require.Equal(t, 2, count)
309}
310
311func TestMap_Seq_EmptyMap(t *testing.T) {
312 t.Parallel()
313
314 m := NewMap[string, int]()
315
316 count := 0
317 for range m.Seq() {
318 count++
319 }
320
321 require.Equal(t, 0, count)
322}
323
324func TestMap_MarshalJSON(t *testing.T) {
325 t.Parallel()
326
327 m := NewMap[string, int]()
328 m.Set("key1", 1)
329 m.Set("key2", 2)
330
331 data, err := json.Marshal(m)
332 require.NoError(t, err)
333
334 result := &Map[string, int]{}
335 err = json.Unmarshal(data, result)
336 require.NoError(t, err)
337 require.Equal(t, 2, result.Len())
338 v1, _ := result.Get("key1")
339 v2, _ := result.Get("key2")
340 require.Equal(t, 1, v1)
341 require.Equal(t, 2, v2)
342}
343
344func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
345 t.Parallel()
346
347 m := NewMap[string, int]()
348
349 data, err := json.Marshal(m)
350 require.NoError(t, err)
351 require.Equal(t, "{}", string(data))
352}
353
354func TestMap_UnmarshalJSON(t *testing.T) {
355 t.Parallel()
356
357 jsonData := `{"key1": 1, "key2": 2}`
358
359 m := NewMap[string, int]()
360 err := json.Unmarshal([]byte(jsonData), m)
361 require.NoError(t, err)
362
363 require.Equal(t, 2, m.Len())
364 value, ok := m.Get("key1")
365 require.True(t, ok)
366 require.Equal(t, 1, value)
367
368 value, ok = m.Get("key2")
369 require.True(t, ok)
370 require.Equal(t, 2, value)
371}
372
373func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
374 t.Parallel()
375
376 jsonData := `{}`
377
378 m := NewMap[string, int]()
379 err := json.Unmarshal([]byte(jsonData), m)
380 require.NoError(t, err)
381 require.Equal(t, 0, m.Len())
382}
383
384func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
385 t.Parallel()
386
387 jsonData := `{"key1": 1, "key2":}`
388
389 m := NewMap[string, int]()
390 err := json.Unmarshal([]byte(jsonData), m)
391 require.Error(t, err)
392}
393
394func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
395 t.Parallel()
396
397 m := NewMap[string, int]()
398 m.Set("existing", 999)
399
400 jsonData := `{"key1": 1, "key2": 2}`
401 err := json.Unmarshal([]byte(jsonData), m)
402 require.NoError(t, err)
403
404 require.Equal(t, 2, m.Len())
405 _, ok := m.Get("existing")
406 require.False(t, ok)
407
408 value, ok := m.Get("key1")
409 require.True(t, ok)
410 require.Equal(t, 1, value)
411}
412
413func TestMap_JSONRoundTrip(t *testing.T) {
414 t.Parallel()
415
416 original := NewMap[string, int]()
417 original.Set("key1", 1)
418 original.Set("key2", 2)
419 original.Set("key3", 3)
420
421 data, err := json.Marshal(original)
422 require.NoError(t, err)
423
424 restored := NewMap[string, int]()
425 err = json.Unmarshal(data, restored)
426 require.NoError(t, err)
427
428 require.Equal(t, original.Len(), restored.Len())
429
430 for k, v := range original.Seq2() {
431 restoredValue, ok := restored.Get(k)
432 require.True(t, ok)
433 require.Equal(t, v, restoredValue)
434 }
435}
436
437func TestMap_ConcurrentAccess(t *testing.T) {
438 t.Parallel()
439
440 m := NewMap[int, int]()
441 const numGoroutines = 100
442 const numOperations = 100
443
444 var wg sync.WaitGroup
445 wg.Add(numGoroutines)
446
447 for i := range numGoroutines {
448 go func(id int) {
449 defer wg.Done()
450 for j := range numOperations {
451 key := id*numOperations + j
452 m.Set(key, key*2)
453 value, ok := m.Get(key)
454 require.True(t, ok)
455 require.Equal(t, key*2, value)
456 }
457 }(i)
458 }
459
460 wg.Wait()
461
462 require.Equal(t, numGoroutines*numOperations, m.Len())
463}
464
465func TestMap_ConcurrentReadWrite(t *testing.T) {
466 t.Parallel()
467
468 m := NewMap[int, int]()
469 const numReaders = 50
470 const numWriters = 50
471 const numOperations = 100
472
473 for i := range 1000 {
474 m.Set(i, i)
475 }
476
477 var wg sync.WaitGroup
478 wg.Add(numReaders + numWriters)
479
480 for range numReaders {
481 go func() {
482 defer wg.Done()
483 for j := range numOperations {
484 key := j % 1000
485 value, ok := m.Get(key)
486 if ok {
487 require.Equal(t, key, value)
488 }
489 _ = m.Len()
490 }
491 }()
492 }
493
494 for i := range numWriters {
495 go func(id int) {
496 defer wg.Done()
497 for j := range numOperations {
498 key := 1000 + id*numOperations + j
499 m.Set(key, key)
500 if j%10 == 0 {
501 m.Del(key)
502 }
503 }
504 }(i)
505 }
506
507 wg.Wait()
508}
509
510func TestMap_ConcurrentSeq2(t *testing.T) {
511 t.Parallel()
512
513 m := NewMap[int, int]()
514 for i := range 100 {
515 m.Set(i, i*2)
516 }
517
518 var wg sync.WaitGroup
519 const numIterators = 10
520
521 wg.Add(numIterators)
522 for range numIterators {
523 go func() {
524 defer wg.Done()
525 count := 0
526 for k, v := range m.Seq2() {
527 require.Equal(t, k*2, v)
528 count++
529 }
530 require.Equal(t, 100, count)
531 }()
532 }
533
534 wg.Wait()
535}
536
537func TestMap_ConcurrentSeq(t *testing.T) {
538 t.Parallel()
539
540 m := NewMap[int, int]()
541 for i := range 100 {
542 m.Set(i, i*2)
543 }
544
545 var wg sync.WaitGroup
546 const numIterators = 10
547
548 wg.Add(numIterators)
549 for range numIterators {
550 go func() {
551 defer wg.Done()
552 count := 0
553 values := make(map[int]bool)
554 for v := range m.Seq() {
555 values[v] = true
556 count++
557 }
558 require.Equal(t, 100, count)
559 for i := range 100 {
560 require.True(t, values[i*2])
561 }
562 }()
563 }
564
565 wg.Wait()
566}
567
568func TestMap_ConcurrentTake(t *testing.T) {
569 t.Parallel()
570
571 m := NewMap[int, int]()
572 const numItems = 1000
573
574 for i := range numItems {
575 m.Set(i, i*2)
576 }
577
578 var wg sync.WaitGroup
579 const numWorkers = 10
580 taken := make([][]int, numWorkers)
581
582 wg.Add(numWorkers)
583 for i := range numWorkers {
584 go func(workerID int) {
585 defer wg.Done()
586 taken[workerID] = make([]int, 0)
587 for j := workerID; j < numItems; j += numWorkers {
588 if value, ok := m.Take(j); ok {
589 taken[workerID] = append(taken[workerID], value)
590 }
591 }
592 }(i)
593 }
594
595 wg.Wait()
596
597 require.Equal(t, 0, m.Len())
598
599 allTaken := make(map[int]bool)
600 for _, workerTaken := range taken {
601 for _, value := range workerTaken {
602 require.False(t, allTaken[value], "Value %d was taken multiple times", value)
603 allTaken[value] = true
604 }
605 }
606
607 require.Equal(t, numItems, len(allTaken))
608 for i := range numItems {
609 require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
610 }
611}
612
613func TestMap_TypeSafety(t *testing.T) {
614 t.Parallel()
615
616 stringIntMap := NewMap[string, int]()
617 stringIntMap.Set("key", 42)
618 value, ok := stringIntMap.Get("key")
619 require.True(t, ok)
620 require.Equal(t, 42, value)
621
622 intStringMap := NewMap[int, string]()
623 intStringMap.Set(42, "value")
624 strValue, ok := intStringMap.Get(42)
625 require.True(t, ok)
626 require.Equal(t, "value", strValue)
627
628 structMap := NewMap[string, struct{ Name string }]()
629 structMap.Set("key", struct{ Name string }{Name: "test"})
630 structValue, ok := structMap.Get("key")
631 require.True(t, ok)
632 require.Equal(t, "test", structValue.Name)
633}
634
635func TestMap_InterfaceCompliance(t *testing.T) {
636 t.Parallel()
637
638 var _ json.Marshaler = &Map[string, any]{}
639 var _ json.Unmarshaler = &Map[string, any]{}
640}
641
642func BenchmarkMap_Set(b *testing.B) {
643 m := NewMap[int, int]()
644
645 for i := 0; b.Loop(); i++ {
646 m.Set(i, i*2)
647 }
648}
649
650func BenchmarkMap_Get(b *testing.B) {
651 m := NewMap[int, int]()
652 for i := range 1000 {
653 m.Set(i, i*2)
654 }
655
656 for i := 0; b.Loop(); i++ {
657 m.Get(i % 1000)
658 }
659}
660
661func BenchmarkMap_Seq2(b *testing.B) {
662 m := NewMap[int, int]()
663 for i := range 1000 {
664 m.Set(i, i*2)
665 }
666
667 for b.Loop() {
668 for range m.Seq2() {
669 }
670 }
671}
672
673func BenchmarkMap_Seq(b *testing.B) {
674 m := NewMap[int, int]()
675 for i := range 1000 {
676 m.Set(i, i*2)
677 }
678
679 for b.Loop() {
680 for range m.Seq() {
681 }
682 }
683}
684
685func BenchmarkMap_Take(b *testing.B) {
686 m := NewMap[int, int]()
687 for i := range 1000 {
688 m.Set(i, i*2)
689 }
690
691 b.ResetTimer()
692 for i := 0; b.Loop(); i++ {
693 key := i % 1000
694 m.Take(key)
695 if i%1000 == 999 {
696 b.StopTimer()
697 for j := range 1000 {
698 m.Set(j, j*2)
699 }
700 b.StartTimer()
701 }
702 }
703}
704
705func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
706 m := NewMap[int, int]()
707 for i := range 1000 {
708 m.Set(i, i*2)
709 }
710
711 b.ResetTimer()
712 b.RunParallel(func(pb *testing.PB) {
713 i := 0
714 for pb.Next() {
715 if i%2 == 0 {
716 m.Get(i % 1000)
717 } else {
718 m.Set(i+1000, i*2)
719 }
720 i++
721 }
722 })
723}