1package csync
2
3import (
4 "encoding/json"
5 "maps"
6 "sync"
7 "sync/atomic"
8 "testing"
9 "testing/synctest"
10 "time"
11
12 "github.com/stretchr/testify/require"
13)
14
15func TestNewMap(t *testing.T) {
16 t.Parallel()
17
18 m := NewMap[string, int]()
19 require.NotNil(t, m)
20 require.NotNil(t, m.inner)
21 require.Equal(t, 0, m.Len())
22}
23
24func TestNewMapFrom(t *testing.T) {
25 t.Parallel()
26
27 original := map[string]int{
28 "key1": 1,
29 "key2": 2,
30 }
31
32 m := NewMapFrom(original)
33 require.NotNil(t, m)
34 require.Equal(t, original, m.inner)
35 require.Equal(t, 2, m.Len())
36
37 value, ok := m.Get("key1")
38 require.True(t, ok)
39 require.Equal(t, 1, value)
40}
41
42func TestNewLazyMap(t *testing.T) {
43 t.Parallel()
44
45 synctest.Test(t, func(t *testing.T) {
46 t.Helper()
47
48 waiter := sync.Mutex{}
49 waiter.Lock()
50 var loadCalled atomic.Bool
51
52 loadFunc := func() map[string]int {
53 waiter.Lock()
54 defer waiter.Unlock()
55 loadCalled.Store(true)
56 return map[string]int{
57 "key1": 1,
58 "key2": 2,
59 }
60 }
61
62 m := NewLazyMap(loadFunc)
63 require.NotNil(t, m)
64
65 waiter.Unlock() // Allow the load function to proceed
66 time.Sleep(100 * time.Millisecond)
67 require.True(t, loadCalled.Load())
68 require.Equal(t, 2, m.Len())
69
70 value, ok := m.Get("key1")
71 require.True(t, ok)
72 require.Equal(t, 1, value)
73 })
74}
75
76func TestMap_Reset(t *testing.T) {
77 t.Parallel()
78
79 m := NewMapFrom(map[string]int{
80 "a": 10,
81 })
82
83 m.Reset(map[string]int{
84 "b": 20,
85 })
86 value, ok := m.Get("b")
87 require.True(t, ok)
88 require.Equal(t, 20, value)
89 require.Equal(t, 1, m.Len())
90}
91
92func TestMap_Set(t *testing.T) {
93 t.Parallel()
94
95 m := NewMap[string, int]()
96
97 m.Set("key1", 42)
98 value, ok := m.Get("key1")
99 require.True(t, ok)
100 require.Equal(t, 42, value)
101 require.Equal(t, 1, m.Len())
102
103 m.Set("key1", 100)
104 value, ok = m.Get("key1")
105 require.True(t, ok)
106 require.Equal(t, 100, value)
107 require.Equal(t, 1, m.Len())
108}
109
110func TestMap_GetOrSet(t *testing.T) {
111 t.Parallel()
112
113 m := NewMap[string, int]()
114
115 require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
116 require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
117 require.Equal(t, 1, m.Len())
118}
119
120func TestMap_Get(t *testing.T) {
121 t.Parallel()
122
123 m := NewMap[string, int]()
124
125 value, ok := m.Get("nonexistent")
126 require.False(t, ok)
127 require.Equal(t, 0, value)
128
129 m.Set("key1", 42)
130 value, ok = m.Get("key1")
131 require.True(t, ok)
132 require.Equal(t, 42, value)
133}
134
135func TestMap_Del(t *testing.T) {
136 t.Parallel()
137
138 m := NewMap[string, int]()
139 m.Set("key1", 42)
140 m.Set("key2", 100)
141
142 require.Equal(t, 2, m.Len())
143
144 m.Del("key1")
145 _, ok := m.Get("key1")
146 require.False(t, ok)
147 require.Equal(t, 1, m.Len())
148
149 value, ok := m.Get("key2")
150 require.True(t, ok)
151 require.Equal(t, 100, value)
152
153 m.Del("nonexistent")
154 require.Equal(t, 1, m.Len())
155}
156
157func TestMap_Len(t *testing.T) {
158 t.Parallel()
159
160 m := NewMap[string, int]()
161 require.Equal(t, 0, m.Len())
162
163 m.Set("key1", 1)
164 require.Equal(t, 1, m.Len())
165
166 m.Set("key2", 2)
167 require.Equal(t, 2, m.Len())
168
169 m.Del("key1")
170 require.Equal(t, 1, m.Len())
171
172 m.Del("key2")
173 require.Equal(t, 0, m.Len())
174}
175
176func TestMap_Take(t *testing.T) {
177 t.Parallel()
178
179 m := NewMap[string, int]()
180 m.Set("key1", 42)
181 m.Set("key2", 100)
182
183 require.Equal(t, 2, m.Len())
184
185 value, ok := m.Take("key1")
186 require.True(t, ok)
187 require.Equal(t, 42, value)
188 require.Equal(t, 1, m.Len())
189
190 _, exists := m.Get("key1")
191 require.False(t, exists)
192
193 value, ok = m.Get("key2")
194 require.True(t, ok)
195 require.Equal(t, 100, value)
196}
197
198func TestMap_Take_NonexistentKey(t *testing.T) {
199 t.Parallel()
200
201 m := NewMap[string, int]()
202 m.Set("key1", 42)
203
204 value, ok := m.Take("nonexistent")
205 require.False(t, ok)
206 require.Equal(t, 0, value)
207 require.Equal(t, 1, m.Len())
208
209 value, ok = m.Get("key1")
210 require.True(t, ok)
211 require.Equal(t, 42, value)
212}
213
214func TestMap_Take_EmptyMap(t *testing.T) {
215 t.Parallel()
216
217 m := NewMap[string, int]()
218
219 value, ok := m.Take("key1")
220 require.False(t, ok)
221 require.Equal(t, 0, value)
222 require.Equal(t, 0, m.Len())
223}
224
225func TestMap_Take_SameKeyTwice(t *testing.T) {
226 t.Parallel()
227
228 m := NewMap[string, int]()
229 m.Set("key1", 42)
230
231 value, ok := m.Take("key1")
232 require.True(t, ok)
233 require.Equal(t, 42, value)
234 require.Equal(t, 0, m.Len())
235
236 value, ok = m.Take("key1")
237 require.False(t, ok)
238 require.Equal(t, 0, value)
239 require.Equal(t, 0, m.Len())
240}
241
242func TestMap_Seq2(t *testing.T) {
243 t.Parallel()
244
245 m := NewMap[string, int]()
246 m.Set("key1", 1)
247 m.Set("key2", 2)
248 m.Set("key3", 3)
249
250 collected := maps.Collect(m.Seq2())
251
252 require.Equal(t, 3, len(collected))
253 require.Equal(t, 1, collected["key1"])
254 require.Equal(t, 2, collected["key2"])
255 require.Equal(t, 3, collected["key3"])
256}
257
258func TestMap_Seq2_EarlyReturn(t *testing.T) {
259 t.Parallel()
260
261 m := NewMap[string, int]()
262 m.Set("key1", 1)
263 m.Set("key2", 2)
264 m.Set("key3", 3)
265
266 count := 0
267 for range m.Seq2() {
268 count++
269 if count == 2 {
270 break
271 }
272 }
273
274 require.Equal(t, 2, count)
275}
276
277func TestMap_Seq2_EmptyMap(t *testing.T) {
278 t.Parallel()
279
280 m := NewMap[string, int]()
281
282 count := 0
283 for range m.Seq2() {
284 count++
285 }
286
287 require.Equal(t, 0, count)
288}
289
290func TestMap_Seq(t *testing.T) {
291 t.Parallel()
292
293 m := NewMap[string, int]()
294 m.Set("key1", 1)
295 m.Set("key2", 2)
296 m.Set("key3", 3)
297
298 collected := make([]int, 0)
299 for v := range m.Seq() {
300 collected = append(collected, v)
301 }
302
303 require.Equal(t, 3, len(collected))
304 require.Contains(t, collected, 1)
305 require.Contains(t, collected, 2)
306 require.Contains(t, collected, 3)
307}
308
309func TestMap_Seq_EarlyReturn(t *testing.T) {
310 t.Parallel()
311
312 m := NewMap[string, int]()
313 m.Set("key1", 1)
314 m.Set("key2", 2)
315 m.Set("key3", 3)
316
317 count := 0
318 for range m.Seq() {
319 count++
320 if count == 2 {
321 break
322 }
323 }
324
325 require.Equal(t, 2, count)
326}
327
328func TestMap_Seq_EmptyMap(t *testing.T) {
329 t.Parallel()
330
331 m := NewMap[string, int]()
332
333 count := 0
334 for range m.Seq() {
335 count++
336 }
337
338 require.Equal(t, 0, count)
339}
340
341func TestMap_MarshalJSON(t *testing.T) {
342 t.Parallel()
343
344 m := NewMap[string, int]()
345 m.Set("key1", 1)
346 m.Set("key2", 2)
347
348 data, err := json.Marshal(m)
349 require.NoError(t, err)
350
351 result := &Map[string, int]{}
352 err = json.Unmarshal(data, result)
353 require.NoError(t, err)
354 require.Equal(t, 2, result.Len())
355 v1, _ := result.Get("key1")
356 v2, _ := result.Get("key2")
357 require.Equal(t, 1, v1)
358 require.Equal(t, 2, v2)
359}
360
361func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
362 t.Parallel()
363
364 m := NewMap[string, int]()
365
366 data, err := json.Marshal(m)
367 require.NoError(t, err)
368 require.Equal(t, "{}", string(data))
369}
370
371func TestMap_UnmarshalJSON(t *testing.T) {
372 t.Parallel()
373
374 jsonData := `{"key1": 1, "key2": 2}`
375
376 m := NewMap[string, int]()
377 err := json.Unmarshal([]byte(jsonData), m)
378 require.NoError(t, err)
379
380 require.Equal(t, 2, m.Len())
381 value, ok := m.Get("key1")
382 require.True(t, ok)
383 require.Equal(t, 1, value)
384
385 value, ok = m.Get("key2")
386 require.True(t, ok)
387 require.Equal(t, 2, value)
388}
389
390func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
391 t.Parallel()
392
393 jsonData := `{}`
394
395 m := NewMap[string, int]()
396 err := json.Unmarshal([]byte(jsonData), m)
397 require.NoError(t, err)
398 require.Equal(t, 0, m.Len())
399}
400
401func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
402 t.Parallel()
403
404 jsonData := `{"key1": 1, "key2":}`
405
406 m := NewMap[string, int]()
407 err := json.Unmarshal([]byte(jsonData), m)
408 require.Error(t, err)
409}
410
411func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
412 t.Parallel()
413
414 m := NewMap[string, int]()
415 m.Set("existing", 999)
416
417 jsonData := `{"key1": 1, "key2": 2}`
418 err := json.Unmarshal([]byte(jsonData), m)
419 require.NoError(t, err)
420
421 require.Equal(t, 2, m.Len())
422 _, ok := m.Get("existing")
423 require.False(t, ok)
424
425 value, ok := m.Get("key1")
426 require.True(t, ok)
427 require.Equal(t, 1, value)
428}
429
430func TestMap_JSONRoundTrip(t *testing.T) {
431 t.Parallel()
432
433 original := NewMap[string, int]()
434 original.Set("key1", 1)
435 original.Set("key2", 2)
436 original.Set("key3", 3)
437
438 data, err := json.Marshal(original)
439 require.NoError(t, err)
440
441 restored := NewMap[string, int]()
442 err = json.Unmarshal(data, restored)
443 require.NoError(t, err)
444
445 require.Equal(t, original.Len(), restored.Len())
446
447 for k, v := range original.Seq2() {
448 restoredValue, ok := restored.Get(k)
449 require.True(t, ok)
450 require.Equal(t, v, restoredValue)
451 }
452}
453
454func TestMap_ConcurrentAccess(t *testing.T) {
455 t.Parallel()
456
457 m := NewMap[int, int]()
458 const numGoroutines = 100
459 const numOperations = 100
460
461 var wg sync.WaitGroup
462 wg.Add(numGoroutines)
463
464 for i := range numGoroutines {
465 go func(id int) {
466 defer wg.Done()
467 for j := range numOperations {
468 key := id*numOperations + j
469 m.Set(key, key*2)
470 value, ok := m.Get(key)
471 require.True(t, ok)
472 require.Equal(t, key*2, value)
473 }
474 }(i)
475 }
476
477 wg.Wait()
478
479 require.Equal(t, numGoroutines*numOperations, m.Len())
480}
481
482func TestMap_ConcurrentReadWrite(t *testing.T) {
483 t.Parallel()
484
485 m := NewMap[int, int]()
486 const numReaders = 50
487 const numWriters = 50
488 const numOperations = 100
489
490 for i := range 1000 {
491 m.Set(i, i)
492 }
493
494 var wg sync.WaitGroup
495 wg.Add(numReaders + numWriters)
496
497 for range numReaders {
498 go func() {
499 defer wg.Done()
500 for j := range numOperations {
501 key := j % 1000
502 value, ok := m.Get(key)
503 if ok {
504 require.Equal(t, key, value)
505 }
506 _ = m.Len()
507 }
508 }()
509 }
510
511 for i := range numWriters {
512 go func(id int) {
513 defer wg.Done()
514 for j := range numOperations {
515 key := 1000 + id*numOperations + j
516 m.Set(key, key)
517 if j%10 == 0 {
518 m.Del(key)
519 }
520 }
521 }(i)
522 }
523
524 wg.Wait()
525}
526
527func TestMap_ConcurrentSeq2(t *testing.T) {
528 t.Parallel()
529
530 m := NewMap[int, int]()
531 for i := range 100 {
532 m.Set(i, i*2)
533 }
534
535 var wg sync.WaitGroup
536 const numIterators = 10
537
538 wg.Add(numIterators)
539 for range numIterators {
540 go func() {
541 defer wg.Done()
542 count := 0
543 for k, v := range m.Seq2() {
544 require.Equal(t, k*2, v)
545 count++
546 }
547 require.Equal(t, 100, count)
548 }()
549 }
550
551 wg.Wait()
552}
553
554func TestMap_ConcurrentSeq(t *testing.T) {
555 t.Parallel()
556
557 m := NewMap[int, int]()
558 for i := range 100 {
559 m.Set(i, i*2)
560 }
561
562 var wg sync.WaitGroup
563 const numIterators = 10
564
565 wg.Add(numIterators)
566 for range numIterators {
567 go func() {
568 defer wg.Done()
569 count := 0
570 values := make(map[int]bool)
571 for v := range m.Seq() {
572 values[v] = true
573 count++
574 }
575 require.Equal(t, 100, count)
576 for i := range 100 {
577 require.True(t, values[i*2])
578 }
579 }()
580 }
581
582 wg.Wait()
583}
584
585func TestMap_ConcurrentTake(t *testing.T) {
586 t.Parallel()
587
588 m := NewMap[int, int]()
589 const numItems = 1000
590
591 for i := range numItems {
592 m.Set(i, i*2)
593 }
594
595 var wg sync.WaitGroup
596 const numWorkers = 10
597 taken := make([][]int, numWorkers)
598
599 wg.Add(numWorkers)
600 for i := range numWorkers {
601 go func(workerID int) {
602 defer wg.Done()
603 taken[workerID] = make([]int, 0)
604 for j := workerID; j < numItems; j += numWorkers {
605 if value, ok := m.Take(j); ok {
606 taken[workerID] = append(taken[workerID], value)
607 }
608 }
609 }(i)
610 }
611
612 wg.Wait()
613
614 require.Equal(t, 0, m.Len())
615
616 allTaken := make(map[int]bool)
617 for _, workerTaken := range taken {
618 for _, value := range workerTaken {
619 require.False(t, allTaken[value], "Value %d was taken multiple times", value)
620 allTaken[value] = true
621 }
622 }
623
624 require.Equal(t, numItems, len(allTaken))
625 for i := range numItems {
626 require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
627 }
628}
629
630func TestMap_TypeSafety(t *testing.T) {
631 t.Parallel()
632
633 stringIntMap := NewMap[string, int]()
634 stringIntMap.Set("key", 42)
635 value, ok := stringIntMap.Get("key")
636 require.True(t, ok)
637 require.Equal(t, 42, value)
638
639 intStringMap := NewMap[int, string]()
640 intStringMap.Set(42, "value")
641 strValue, ok := intStringMap.Get(42)
642 require.True(t, ok)
643 require.Equal(t, "value", strValue)
644
645 structMap := NewMap[string, struct{ Name string }]()
646 structMap.Set("key", struct{ Name string }{Name: "test"})
647 structValue, ok := structMap.Get("key")
648 require.True(t, ok)
649 require.Equal(t, "test", structValue.Name)
650}
651
652func TestMap_InterfaceCompliance(t *testing.T) {
653 t.Parallel()
654
655 var _ json.Marshaler = &Map[string, any]{}
656 var _ json.Unmarshaler = &Map[string, any]{}
657}
658
659func BenchmarkMap_Set(b *testing.B) {
660 m := NewMap[int, int]()
661
662 for i := 0; b.Loop(); i++ {
663 m.Set(i, i*2)
664 }
665}
666
667func BenchmarkMap_Get(b *testing.B) {
668 m := NewMap[int, int]()
669 for i := range 1000 {
670 m.Set(i, i*2)
671 }
672
673 for i := 0; b.Loop(); i++ {
674 m.Get(i % 1000)
675 }
676}
677
678func BenchmarkMap_Seq2(b *testing.B) {
679 m := NewMap[int, int]()
680 for i := range 1000 {
681 m.Set(i, i*2)
682 }
683
684 for b.Loop() {
685 for range m.Seq2() {
686 }
687 }
688}
689
690func BenchmarkMap_Seq(b *testing.B) {
691 m := NewMap[int, int]()
692 for i := range 1000 {
693 m.Set(i, i*2)
694 }
695
696 for b.Loop() {
697 for range m.Seq() {
698 }
699 }
700}
701
702func BenchmarkMap_Take(b *testing.B) {
703 m := NewMap[int, int]()
704 for i := range 1000 {
705 m.Set(i, i*2)
706 }
707
708 b.ResetTimer()
709 for i := 0; b.Loop(); i++ {
710 key := i % 1000
711 m.Take(key)
712 if i%1000 == 999 {
713 b.StopTimer()
714 for j := range 1000 {
715 m.Set(j, j*2)
716 }
717 b.StartTimer()
718 }
719 }
720}
721
722func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
723 m := NewMap[int, int]()
724 for i := range 1000 {
725 m.Set(i, i*2)
726 }
727
728 b.ResetTimer()
729 b.RunParallel(func(pb *testing.PB) {
730 i := 0
731 for pb.Next() {
732 if i%2 == 0 {
733 m.Get(i % 1000)
734 } else {
735 m.Set(i+1000, i*2)
736 }
737 i++
738 }
739 })
740}