1package csync
2
3import (
4 "encoding/json"
5 "maps"
6 "sync"
7 "testing"
8 "testing/synctest"
9 "time"
10
11 "github.com/stretchr/testify/require"
12)
13
// TestNewMap checks that NewMap returns a non-nil, empty map whose backing
// storage is already allocated.
func TestNewMap(t *testing.T) {
	t.Parallel()

	created := NewMap[string, int]()

	require.NotNil(t, created)
	require.NotNil(t, created.inner, "backing map must be allocated")
	require.Zero(t, created.Len())
}
22
// TestNewMapFrom checks that NewMapFrom seeds the map with exactly the
// entries of the supplied Go map.
func TestNewMapFrom(t *testing.T) {
	t.Parallel()

	seed := map[string]int{"key1": 1, "key2": 2}

	got := NewMapFrom(seed)
	require.NotNil(t, got)
	require.Equal(t, seed, got.inner)
	require.Equal(t, len(seed), got.Len())

	v, found := got.Get("key1")
	require.True(t, found)
	require.Equal(t, 1, v)
}
40
// TestNewLazyMap checks that NewLazyMap returns without waiting for loadFunc,
// and that the entries loadFunc produces become visible once loading is done.
//
// The loader is gated on a channel created inside the synctest bubble rather
// than on a sync.Mutex. A goroutine blocked on such a channel counts as
// "durably blocked", while one blocked on a mutex does not. synctest.Wait is
// used as the synchronization point before reading loadCalled, which the
// loader goroutine writes; this replaces relying on a fake-clock sleep alone.
func TestNewLazyMap(t *testing.T) {
	t.Parallel()

	synctest.Test(t, func(t *testing.T) {
		t.Helper()

		release := make(chan struct{})
		loadCalled := false

		loadFunc := func() map[string]int {
			<-release // hold the loader until the test lets it proceed
			loadCalled = true
			return map[string]int{
				"key1": 1,
				"key2": 2,
			}
		}

		// If NewLazyMap ran loadFunc synchronously this would deadlock, which
		// synctest reports instead of hanging.
		m := NewLazyMap(loadFunc)
		require.NotNil(t, m)

		close(release) // allow the load function to proceed
		// Give the loader the same fake-time budget as before, then wait for
		// every bubble goroutine to finish or durably block so the write to
		// loadCalled is ordered before the read below.
		time.Sleep(100 * time.Millisecond)
		synctest.Wait()

		require.True(t, loadCalled)
		require.Equal(t, 2, m.Len())

		value, ok := m.Get("key1")
		require.True(t, ok)
		require.Equal(t, 1, value)
	})
}
74
// TestMap_Reset checks that Reset replaces the entire contents of the map.
func TestMap_Reset(t *testing.T) {
	t.Parallel()

	before := map[string]int{"a": 10}
	after := map[string]int{"b": 20}

	m := NewMapFrom(before)
	m.Reset(after)

	got, ok := m.Get("b")
	require.True(t, ok)
	require.Equal(t, 20, got)
	// Only the new entry remains; "a" from the initial contents is gone.
	require.Equal(t, 1, m.Len())
}
90
// TestMap_Set checks that Set inserts a new key and, on a repeated key,
// overwrites the value without adding an entry.
func TestMap_Set(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	// First pass inserts, second pass overwrites the same key.
	for _, want := range []int{42, 100} {
		m.Set("key1", want)

		got, ok := m.Get("key1")
		require.True(t, ok)
		require.Equal(t, want, got)
		require.Equal(t, 1, m.Len())
	}
}
108
// TestMap_GetOrSet checks that GetOrSet stores the factory's value on a miss
// and returns the already-stored value, not the new factory's, on a hit.
func TestMap_GetOrSet(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	first := func() int { return 42 }
	second := func() int { return 99999 }

	got := m.GetOrSet("key1", first)
	require.Equal(t, 42, got)

	got = m.GetOrSet("key1", second)
	require.Equal(t, 42, got)

	require.Equal(t, 1, m.Len())
}
118
// TestMap_Get checks both the miss path (zero value, false) and the hit path
// (stored value, true).
func TestMap_Get(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	missing, found := m.Get("nonexistent")
	require.False(t, found)
	require.Zero(t, missing)

	m.Set("key1", 42)
	present, found := m.Get("key1")
	require.True(t, found)
	require.Equal(t, 42, present)
}
133
// TestMap_Del checks that Del removes only the requested key and that
// deleting an absent key leaves the map unchanged.
func TestMap_Del(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	// Removing an existing key.
	m.Del("key1")
	_, stillThere := m.Get("key1")
	require.False(t, stillThere)
	require.Equal(t, 1, m.Len())

	// The other key is untouched.
	kept, ok := m.Get("key2")
	require.True(t, ok)
	require.Equal(t, 100, kept)

	// Removing a missing key changes nothing.
	m.Del("nonexistent")
	require.Equal(t, 1, m.Len())
}
155
// TestMap_Len checks that Len tracks every insertion and deletion.
func TestMap_Len(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.Equal(t, 0, m.Len())

	steps := []struct {
		apply func()
		want  int
	}{
		{func() { m.Set("key1", 1) }, 1},
		{func() { m.Set("key2", 2) }, 2},
		{func() { m.Del("key1") }, 1},
		{func() { m.Del("key2") }, 0},
	}
	for _, step := range steps {
		step.apply()
		require.Equal(t, step.want, m.Len())
	}
}
174
// TestMap_Take checks that Take returns the stored value, removes that key,
// and leaves other entries alone.
func TestMap_Take(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	taken, ok := m.Take("key1")
	require.True(t, ok)
	require.Equal(t, 42, taken)
	require.Equal(t, 1, m.Len())

	_, exists := m.Get("key1")
	require.False(t, exists)

	remaining, ok := m.Get("key2")
	require.True(t, ok)
	require.Equal(t, 100, remaining)
}
196
// TestMap_Take_NonexistentKey checks that taking a missing key reports false
// with the zero value and does not disturb existing entries.
func TestMap_Take_NonexistentKey(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	missing, found := m.Take("nonexistent")
	require.False(t, found)
	require.Zero(t, missing)
	require.Equal(t, 1, m.Len())

	existing, found := m.Get("key1")
	require.True(t, found)
	require.Equal(t, 42, existing)
}
212
// TestMap_Take_EmptyMap checks that Take on an empty map reports a miss and
// leaves the map empty.
func TestMap_Take_EmptyMap(t *testing.T) {
	t.Parallel()

	empty := NewMap[string, int]()

	got, found := empty.Take("key1")
	require.False(t, found)
	require.Zero(t, got)
	require.Zero(t, empty.Len())
}
223
// TestMap_Take_SameKeyTwice checks that a key can be taken only once.
func TestMap_Take_SameKeyTwice(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	// The first Take succeeds and empties the map...
	first, ok := m.Take("key1")
	require.True(t, ok)
	require.Equal(t, 42, first)
	require.Zero(t, m.Len())

	// ...so a second Take of the same key finds nothing.
	second, ok := m.Take("key1")
	require.False(t, ok)
	require.Zero(t, second)
	require.Zero(t, m.Len())
}
240
// TestMap_Seq2 checks that Seq2 yields every key/value pair exactly once.
func TestMap_Seq2(t *testing.T) {
	t.Parallel()

	want := map[string]int{"key1": 1, "key2": 2, "key3": 3}

	m := NewMap[string, int]()
	for k, v := range want {
		m.Set(k, v)
	}

	require.Equal(t, want, maps.Collect(m.Seq2()))
}
256
257func TestMap_Seq2_EarlyReturn(t *testing.T) {
258 t.Parallel()
259
260 m := NewMap[string, int]()
261 m.Set("key1", 1)
262 m.Set("key2", 2)
263 m.Set("key3", 3)
264
265 count := 0
266 for range m.Seq2() {
267 count++
268 if count == 2 {
269 break
270 }
271 }
272
273 require.Equal(t, 2, count)
274}
275
// TestMap_Seq2_EmptyMap checks that Seq2 over an empty map yields nothing.
func TestMap_Seq2_EmptyMap(t *testing.T) {
	t.Parallel()

	yielded := 0
	for range NewMap[string, int]().Seq2() {
		yielded++
	}

	require.Zero(t, yielded)
}
288
289func TestMap_Seq(t *testing.T) {
290 t.Parallel()
291
292 m := NewMap[string, int]()
293 m.Set("key1", 1)
294 m.Set("key2", 2)
295 m.Set("key3", 3)
296
297 collected := make([]int, 0)
298 for v := range m.Seq() {
299 collected = append(collected, v)
300 }
301
302 require.Equal(t, 3, len(collected))
303 require.Contains(t, collected, 1)
304 require.Contains(t, collected, 2)
305 require.Contains(t, collected, 3)
306}
307
// TestMap_Seq_EarlyReturn checks that breaking out of a Seq range loop stops
// iteration after the requested number of values.
func TestMap_Seq_EarlyReturn(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	seen := 0
	for range m.Seq() {
		seen++
		if seen >= 2 {
			break
		}
	}

	require.Equal(t, 2, seen)
}
326
// TestMap_Seq_EmptyMap checks that Seq over an empty map yields nothing.
func TestMap_Seq_EmptyMap(t *testing.T) {
	t.Parallel()

	yielded := 0
	for range NewMap[string, int]().Seq() {
		yielded++
	}

	require.Zero(t, yielded)
}
339
// TestMap_MarshalJSON checks that a marshaled Map can be decoded into a
// zero-value Map (one not built by NewMap) with the same entries.
func TestMap_MarshalJSON(t *testing.T) {
	t.Parallel()

	src := NewMap[string, int]()
	src.Set("key1", 1)
	src.Set("key2", 2)

	encoded, err := json.Marshal(src)
	require.NoError(t, err)

	decoded := &Map[string, int]{}
	require.NoError(t, json.Unmarshal(encoded, decoded))
	require.Equal(t, 2, decoded.Len())

	for key, want := range map[string]int{"key1": 1, "key2": 2} {
		got, _ := decoded.Get(key)
		require.Equal(t, want, got)
	}
}
359
// TestMap_MarshalJSON_EmptyMap checks that an empty Map encodes as the empty
// JSON object rather than null.
func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
	t.Parallel()

	encoded, err := json.Marshal(NewMap[string, int]())
	require.NoError(t, err)
	require.Equal(t, "{}", string(encoded))
}
369
// TestMap_UnmarshalJSON checks that decoding a JSON object populates the map
// with each of its members.
func TestMap_UnmarshalJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.NoError(t, json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m))
	require.Equal(t, 2, m.Len())

	expect := func(key string, want int) {
		t.Helper()
		got, ok := m.Get(key)
		require.True(t, ok)
		require.Equal(t, want, got)
	}
	expect("key1", 1)
	expect("key2", 2)
}
388
// TestMap_UnmarshalJSON_EmptyJSON checks that decoding "{}" succeeds and
// leaves the map empty.
func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.NoError(t, json.Unmarshal([]byte(`{}`), m))
	require.Zero(t, m.Len())
}
399
// TestMap_UnmarshalJSON_InvalidJSON checks that malformed input is rejected
// with an error.
func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
	t.Parallel()

	// The value for "key2" is missing.
	malformed := []byte(`{"key1": 1, "key2":}`)

	require.Error(t, json.Unmarshal(malformed, NewMap[string, int]()))
}
409
// TestMap_UnmarshalJSON_OverwritesExistingData checks that decoding replaces
// the map's previous contents instead of merging into them.
func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("existing", 999)

	require.NoError(t, json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m))
	require.Equal(t, 2, m.Len())

	_, stale := m.Get("existing")
	require.False(t, stale)

	v, ok := m.Get("key1")
	require.True(t, ok)
	require.Equal(t, 1, v)
}
428
// TestMap_JSONRoundTrip checks that Marshal followed by Unmarshal reproduces
// every entry of the source map.
func TestMap_JSONRoundTrip(t *testing.T) {
	t.Parallel()

	original := NewMap[string, int]()
	for i, key := range []string{"key1", "key2", "key3"} {
		original.Set(key, i+1)
	}

	data, err := json.Marshal(original)
	require.NoError(t, err)

	restored := NewMap[string, int]()
	require.NoError(t, json.Unmarshal(data, restored))
	require.Equal(t, original.Len(), restored.Len())

	for key, want := range original.Seq2() {
		got, ok := restored.Get(key)
		require.True(t, ok)
		require.Equal(t, want, got)
	}
}
452
// TestMap_ConcurrentAccess has many goroutines each write a disjoint range of
// keys and immediately read each key back. It then checks that no write was
// lost. Run with -race to also catch unsynchronized access inside Map.
//
// Worker goroutines report failures with t.Errorf. require.* calls
// t.FailNow, which the testing package only allows on the goroutine running
// the test function.
func TestMap_ConcurrentAccess(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numGoroutines = 100
	const numOperations = 100

	var wg sync.WaitGroup
	for id := range numGoroutines {
		wg.Go(func() {
			for j := range numOperations {
				key := id*numOperations + j
				m.Set(key, key*2)
				if value, ok := m.Get(key); !ok || value != key*2 {
					t.Errorf("Get(%d) = (%d, %t), want (%d, true)", key, value, ok, key*2)
					return // one report per worker is enough
				}
			}
		})
	}
	wg.Wait()

	require.Equal(t, numGoroutines*numOperations, m.Len())
}
480
// TestMap_ConcurrentReadWrite mixes readers of a pre-seeded key range with
// writers that insert (and sometimes delete) keys outside that range. It then
// checks the final size. Run with -race to catch unsynchronized access.
//
// Worker goroutines report failures with t.Errorf. require.* calls
// t.FailNow, which must only run on the test goroutine.
func TestMap_ConcurrentReadWrite(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numSeeded = 1000
	const numReaders = 50
	const numWriters = 50
	const numOperations = 100

	for i := range numSeeded {
		m.Set(i, i)
	}

	var wg sync.WaitGroup

	for range numReaders {
		wg.Go(func() {
			for j := range numOperations {
				key := j % numSeeded
				// Writers never touch seeded keys, so every read must hit.
				if value, ok := m.Get(key); !ok || value != key {
					t.Errorf("Get(%d) = (%d, %t), want (%d, true)", key, value, ok, key)
					return
				}
				_ = m.Len()
			}
		})
	}

	for id := range numWriters {
		wg.Go(func() {
			for j := range numOperations {
				key := numSeeded + id*numOperations + j
				m.Set(key, key)
				if j%10 == 0 {
					m.Del(key)
				}
			}
		})
	}

	wg.Wait()

	// Each writer deletes the keys whose j is a multiple of 10.
	const deletedPerWriter = (numOperations + 9) / 10
	require.Equal(t, numSeeded+numWriters*(numOperations-deletedPerWriter), m.Len())
}
525
// TestMap_ConcurrentSeq2 runs several Seq2 iterations in parallel over an
// unchanging map. Each one must see all 100 pairs with consistent values.
//
// Worker goroutines report failures with t.Errorf. require.* calls
// t.FailNow, which must only run on the test goroutine.
func TestMap_ConcurrentSeq2(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	for i := range 100 {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numIterators = 10

	for range numIterators {
		wg.Go(func() {
			count := 0
			for k, v := range m.Seq2() {
				if v != k*2 {
					t.Errorf("Seq2 yielded (%d, %d), want value %d", k, v, k*2)
				}
				count++
			}
			if count != 100 {
				t.Errorf("Seq2 yielded %d pairs, want 100", count)
			}
		})
	}

	wg.Wait()
}
552
// TestMap_ConcurrentSeq runs several Seq iterations in parallel over an
// unchanging map. Each one must see every one of the 100 stored values.
//
// Worker goroutines report failures with t.Errorf. require.* calls
// t.FailNow, which must only run on the test goroutine.
func TestMap_ConcurrentSeq(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	for i := range 100 {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numIterators = 10

	for range numIterators {
		wg.Go(func() {
			count := 0
			values := make(map[int]bool, 100)
			for v := range m.Seq() {
				values[v] = true
				count++
			}
			if count != 100 {
				t.Errorf("Seq yielded %d values, want 100", count)
			}
			for i := range 100 {
				if !values[i*2] {
					t.Errorf("Seq did not yield value %d", i*2)
				}
			}
		})
	}

	wg.Wait()
}
583
// TestMap_ConcurrentTake has several workers race to Take every key. Take
// must hand each stored value to exactly one worker and leave the map empty.
//
// Previously each worker took a disjoint subset of keys, so no key was ever
// contended and the "taken multiple times" check could not fail.
func TestMap_ConcurrentTake(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numItems = 1000

	for i := range numItems {
		m.Set(i, i*2)
	}

	const numWorkers = 10
	// Each worker appends only to its own slot, so the outer slice needs no lock.
	taken := make([][]int, numWorkers)

	var wg sync.WaitGroup
	for workerID := range numWorkers {
		wg.Go(func() {
			for key := range numItems {
				if value, ok := m.Take(key); ok {
					taken[workerID] = append(taken[workerID], value)
				}
			}
		})
	}
	wg.Wait()

	require.Equal(t, 0, m.Len())

	allTaken := make(map[int]bool, numItems)
	for _, workerTaken := range taken {
		for _, value := range workerTaken {
			require.False(t, allTaken[value], "Value %d was taken multiple times", value)
			allTaken[value] = true
		}
	}

	require.Len(t, allTaken, numItems)
	for i := range numItems {
		require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
	}
}
628
629func TestMap_TypeSafety(t *testing.T) {
630 t.Parallel()
631
632 stringIntMap := NewMap[string, int]()
633 stringIntMap.Set("key", 42)
634 value, ok := stringIntMap.Get("key")
635 require.True(t, ok)
636 require.Equal(t, 42, value)
637
638 intStringMap := NewMap[int, string]()
639 intStringMap.Set(42, "value")
640 strValue, ok := intStringMap.Get(42)
641 require.True(t, ok)
642 require.Equal(t, "value", strValue)
643
644 structMap := NewMap[string, struct{ Name string }]()
645 structMap.Set("key", struct{ Name string }{Name: "test"})
646 structValue, ok := structMap.Get("key")
647 require.True(t, ok)
648 require.Equal(t, "test", structValue.Name)
649}
650
// TestMap_InterfaceCompliance holds compile-time assertions: the package's
// tests fail to build if *Map stops implementing either JSON interface.
func TestMap_InterfaceCompliance(t *testing.T) {
	t.Parallel()

	var (
		_ json.Marshaler   = (*Map[string, any])(nil)
		_ json.Unmarshaler = (*Map[string, any])(nil)
	)
}
657
// BenchmarkMap_Set measures Set. Each iteration inserts a fresh key, so the
// map grows for the whole run.
func BenchmarkMap_Set(b *testing.B) {
	m := NewMap[int, int]()

	key := 0
	for b.Loop() {
		m.Set(key, key*2)
		key++
	}
}
665
666func BenchmarkMap_Get(b *testing.B) {
667 m := NewMap[int, int]()
668 for i := range 1000 {
669 m.Set(i, i*2)
670 }
671
672 for i := 0; b.Loop(); i++ {
673 m.Get(i % 1000)
674 }
675}
676
// BenchmarkMap_Seq2 measures a full Seq2 pass over a 1000-entry map.
func BenchmarkMap_Seq2(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, 2*k)
	}

	for b.Loop() {
		// Drain the iterator; the loop body is intentionally empty.
		for range m.Seq2() {
		}
	}
}
688
// BenchmarkMap_Seq measures a full Seq pass over a 1000-entry map.
func BenchmarkMap_Seq(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, 2*k)
	}

	for b.Loop() {
		// Drain the iterator; the loop body is intentionally empty.
		for range m.Seq() {
		}
	}
}
700
701func BenchmarkMap_Take(b *testing.B) {
702 m := NewMap[int, int]()
703 for i := range 1000 {
704 m.Set(i, i*2)
705 }
706
707 b.ResetTimer()
708 for i := 0; b.Loop(); i++ {
709 key := i % 1000
710 m.Take(key)
711 if i%1000 == 999 {
712 b.StopTimer()
713 for j := range 1000 {
714 m.Set(j, j*2)
715 }
716 b.StartTimer()
717 }
718 }
719}
720
// BenchmarkMap_ConcurrentReadWrite measures a 50/50 mix of reads of seeded
// keys and writes of keys >= 1000, spread across parallel goroutines.
func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, k*2)
	}

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for op := 0; pb.Next(); op++ {
			if op%2 != 0 {
				m.Set(op+1000, op*2)
				continue
			}
			m.Get(op % 1000)
		}
	})
}