1package csync
2
3import (
4 "encoding/json"
5 "maps"
6 "sync"
7 "testing"
8 "time"
9
10 "github.com/stretchr/testify/require"
11)
12
// TestNewMap checks that NewMap returns a non-nil map with an allocated
// backing store and no entries.
func TestNewMap(t *testing.T) {
	t.Parallel()

	fresh := NewMap[string, int]()

	require.NotNil(t, fresh)
	require.NotNil(t, fresh.inner)
	require.Zero(t, fresh.Len())
}
21
// TestNewMapFrom checks that NewMapFrom exposes exactly the entries of the
// map it was built from, both via the backing store and via Get.
func TestNewMapFrom(t *testing.T) {
	t.Parallel()

	src := map[string]int{"key1": 1, "key2": 2}

	got := NewMapFrom(src)
	require.NotNil(t, got)
	require.Equal(t, src, got.inner)
	require.Equal(t, len(src), got.Len())

	v, found := got.Get("key1")
	require.True(t, found)
	require.Equal(t, 1, v)
}
39
// TestNewLazyMap verifies that NewLazyMap eventually runs its load function
// and that the map then serves the loaded entries.
//
// The load function blocks on waiter until the test unlocks it, so the test
// controls when loading may proceed. loadCalled is written by loadFunc while
// waiter is held and is only read back under that same lock. Without that,
// the write (in the loader's goroutine) and the read (in the test goroutine)
// race.
//
// Polling with require.Eventually replaces a fixed 100ms sleep. The sleep
// wasted time when loading was fast and could fail when the machine was busy.
func TestNewLazyMap(t *testing.T) {
	t.Parallel()

	waiter := sync.Mutex{}
	waiter.Lock()
	loadCalled := false

	loadFunc := func() map[string]int {
		waiter.Lock()
		defer waiter.Unlock()
		loadCalled = true
		return map[string]int{
			"key1": 1,
			"key2": 2,
		}
	}

	m := NewLazyMap(loadFunc)
	require.NotNil(t, m)

	waiter.Unlock() // Allow the load function to proceed
	require.Eventually(t, func() bool {
		waiter.Lock()
		defer waiter.Unlock()
		return loadCalled
	}, 2*time.Second, 10*time.Millisecond)

	// NOTE(review): assumes Len observes the loaded contents once loadFunc
	// has returned (e.g. it waits for loading to finish) — confirm against
	// NewLazyMap's implementation.
	require.Equal(t, 2, m.Len())

	value, ok := m.Get("key1")
	require.True(t, ok)
	require.Equal(t, 1, value)
}
69
// TestMap_Set checks that Set inserts a new key and that setting the same
// key again overwrites the value without growing the map.
func TestMap_Set(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	for _, want := range []int{42, 100} {
		m.Set("key1", want)

		got, ok := m.Get("key1")
		require.True(t, ok)
		require.Equal(t, want, got)
		require.Equal(t, 1, m.Len())
	}
}
87
// TestMap_GetOrSet checks that the first GetOrSet stores the factory's
// result and a later call returns that stored value instead of the new one.
func TestMap_GetOrSet(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	first := m.GetOrSet("key1", func() int { return 42 })
	require.Equal(t, 42, first)

	second := m.GetOrSet("key1", func() int { return 99999 })
	require.Equal(t, 42, second)

	require.Equal(t, 1, m.Len())
}
97
// TestMap_Get checks that looking up an absent key yields the zero value
// and false, and that a key that was just set is found.
func TestMap_Get(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	missing, found := m.Get("nonexistent")
	require.False(t, found)
	require.Zero(t, missing)

	m.Set("key1", 42)

	present, found := m.Get("key1")
	require.True(t, found)
	require.Equal(t, 42, present)
}
112
// TestMap_Del checks that Del removes only the named key and that deleting
// a key that is not present leaves the map unchanged.
func TestMap_Del(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	m.Del("key1")
	_, stillThere := m.Get("key1")
	require.False(t, stillThere)
	require.Equal(t, 1, m.Len())

	kept, ok := m.Get("key2")
	require.True(t, ok)
	require.Equal(t, 100, kept)

	// Deleting a key that was never present must be a no-op.
	m.Del("nonexistent")
	require.Equal(t, 1, m.Len())
}
134
// TestMap_Len checks that Len tracks every insertion and deletion, applied
// in order as a small script of steps.
func TestMap_Len(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.Equal(t, 0, m.Len())

	steps := []struct {
		apply   func()
		wantLen int
	}{
		{func() { m.Set("key1", 1) }, 1},
		{func() { m.Set("key2", 2) }, 2},
		{func() { m.Del("key1") }, 1},
		{func() { m.Del("key2") }, 0},
	}
	for _, step := range steps {
		step.apply()
		require.Equal(t, step.wantLen, m.Len())
	}
}
153
// TestMap_Take checks that Take returns the stored value, removes that key,
// and leaves other keys untouched.
func TestMap_Take(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	taken, ok := m.Take("key1")
	require.True(t, ok)
	require.Equal(t, 42, taken)
	require.Equal(t, 1, m.Len())

	// The taken key is gone...
	_, ok = m.Get("key1")
	require.False(t, ok)

	// ...while the other one survives with its original value.
	remaining, ok := m.Get("key2")
	require.True(t, ok)
	require.Equal(t, 100, remaining)
}
175
// TestMap_Take_NonexistentKey checks that taking an absent key reports
// (zero, false) and does not disturb the existing entries.
func TestMap_Take_NonexistentKey(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	got, found := m.Take("nonexistent")
	require.False(t, found)
	require.Zero(t, got)
	require.Equal(t, 1, m.Len())

	// The pre-existing entry is still intact.
	got, found = m.Get("key1")
	require.True(t, found)
	require.Equal(t, 42, got)
}
191
// TestMap_Take_EmptyMap checks that Take on an empty map reports
// (zero, false) and the map stays empty.
func TestMap_Take_EmptyMap(t *testing.T) {
	t.Parallel()

	empty := NewMap[string, int]()

	v, found := empty.Take("key1")
	require.False(t, found)
	require.Zero(t, v)
	require.Zero(t, empty.Len())
}
202
// TestMap_Take_SameKeyTwice checks that only the first Take of a key yields
// its value; a second Take of the same key finds nothing.
func TestMap_Take_SameKeyTwice(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	expectations := []struct {
		value int
		ok    bool
	}{
		{value: 42, ok: true},
		{value: 0, ok: false},
	}
	for _, want := range expectations {
		got, ok := m.Take("key1")
		require.Equal(t, want.ok, ok)
		require.Equal(t, want.value, got)
		require.Equal(t, 0, m.Len())
	}
}
219
// TestMap_Seq2 checks that collecting Seq2 reproduces every key/value pair
// stored in the map.
func TestMap_Seq2(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	got := maps.Collect(m.Seq2())

	require.Len(t, got, 3)
	for key, want := range map[string]int{"key1": 1, "key2": 2, "key3": 3} {
		require.Equal(t, want, got[key])
	}
}
235
// TestMap_Seq2_EarlyReturn checks that breaking out of a range over Seq2
// stops the iteration. Go panics at runtime if an iterator keeps calling
// yield after yield has returned false, so a misbehaving Seq2 fails here.
func TestMap_Seq2_EarlyReturn(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	seen := 0
	for range m.Seq2() {
		if seen++; seen == 2 {
			break
		}
	}

	require.Equal(t, 2, seen)
}
254
// TestMap_Seq2_EmptyMap checks that Seq2 over an empty map yields nothing.
func TestMap_Seq2_EmptyMap(t *testing.T) {
	t.Parallel()

	iterations := 0
	for range NewMap[string, int]().Seq2() {
		iterations++
	}

	require.Zero(t, iterations)
}
267
// TestMap_Seq checks that Seq yields every stored value exactly as many
// times as there are entries (order is not asserted).
func TestMap_Seq(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	values := make([]int, 0)
	for v := range m.Seq() {
		values = append(values, v)
	}

	require.Len(t, values, 3)
	for _, want := range []int{1, 2, 3} {
		require.Contains(t, values, want)
	}
}
286
287func TestMap_Seq_EarlyReturn(t *testing.T) {
288 t.Parallel()
289
290 m := NewMap[string, int]()
291 m.Set("key1", 1)
292 m.Set("key2", 2)
293 m.Set("key3", 3)
294
295 count := 0
296 for range m.Seq() {
297 count++
298 if count == 2 {
299 break
300 }
301 }
302
303 require.Equal(t, 2, count)
304}
305
// TestMap_Seq_EmptyMap checks that Seq over an empty map yields nothing.
func TestMap_Seq_EmptyMap(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	yielded := 0
	for range m.Seq() {
		yielded++
	}

	require.Zero(t, yielded)
}
318
// TestMap_MarshalJSON checks that a marshaled map can be decoded into a
// zero-value Map (one not built by NewMap) with all entries intact.
func TestMap_MarshalJSON(t *testing.T) {
	t.Parallel()

	src := NewMap[string, int]()
	src.Set("key1", 1)
	src.Set("key2", 2)

	encoded, err := json.Marshal(src)
	require.NoError(t, err)

	decoded := new(Map[string, int])
	require.NoError(t, json.Unmarshal(encoded, decoded))
	require.Equal(t, 2, decoded.Len())

	for key, want := range map[string]int{"key1": 1, "key2": 2} {
		got, _ := decoded.Get(key)
		require.Equal(t, want, got)
	}
}
338
// TestMap_MarshalJSON_EmptyMap checks that an empty map encodes as an empty
// JSON object.
func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
	t.Parallel()

	encoded, err := json.Marshal(NewMap[string, int]())
	require.NoError(t, err)
	require.Equal(t, "{}", string(encoded))
}
348
// TestMap_UnmarshalJSON checks that decoding a JSON object populates the
// map with each of its members.
func TestMap_UnmarshalJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.NoError(t, json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m))
	require.Equal(t, 2, m.Len())

	cases := []struct {
		key  string
		want int
	}{
		{"key1", 1},
		{"key2", 2},
	}
	for _, tc := range cases {
		got, ok := m.Get(tc.key)
		require.True(t, ok)
		require.Equal(t, tc.want, got)
	}
}
367
// TestMap_UnmarshalJSON_EmptyJSON checks that decoding `{}` succeeds and
// leaves the map empty.
func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	err := json.Unmarshal([]byte(`{}`), m)
	require.NoError(t, err)
	require.Zero(t, m.Len())
}
378
379func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
380 t.Parallel()
381
382 jsonData := `{"key1": 1, "key2":}`
383
384 m := NewMap[string, int]()
385 err := json.Unmarshal([]byte(jsonData), m)
386 require.Error(t, err)
387}
388
// TestMap_UnmarshalJSON_OverwritesExistingData checks that decoding into a
// non-empty map replaces its contents rather than merging into them.
func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("existing", 999)

	require.NoError(t, json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m))
	require.Equal(t, 2, m.Len())

	_, hasOld := m.Get("existing")
	require.False(t, hasOld)

	v, hasNew := m.Get("key1")
	require.True(t, hasNew)
	require.Equal(t, 1, v)
}
407
// TestMap_JSONRoundTrip checks that marshaling a map and unmarshaling the
// result into a fresh map reproduces every entry.
func TestMap_JSONRoundTrip(t *testing.T) {
	t.Parallel()

	original := NewMap[string, int]()
	for i, key := range []string{"key1", "key2", "key3"} {
		original.Set(key, i+1)
	}

	data, err := json.Marshal(original)
	require.NoError(t, err)

	restored := NewMap[string, int]()
	require.NoError(t, json.Unmarshal(data, restored))
	require.Equal(t, original.Len(), restored.Len())

	for key, want := range original.Seq2() {
		got, ok := restored.Get(key)
		require.True(t, ok)
		require.Equal(t, want, got)
	}
}
431
// TestMap_ConcurrentAccess has many goroutines concurrently Set and then
// immediately Get their own disjoint range of keys. Afterwards it checks that
// every key landed in the map.
//
// Worker goroutines report failures with t.Errorf and return, instead of
// calling require. require calls t.FailNow, which must only be called from
// the goroutine running the test.
func TestMap_ConcurrentAccess(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numGoroutines = 100
	const numOperations = 100

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	for i := range numGoroutines {
		go func(id int) {
			defer wg.Done()
			for j := range numOperations {
				// Keys are disjoint per goroutine: [id*numOperations, (id+1)*numOperations).
				key := id*numOperations + j
				m.Set(key, key*2)
				value, ok := m.Get(key)
				if !ok {
					t.Errorf("key %d missing immediately after Set", key)
					return
				}
				if value != key*2 {
					t.Errorf("key %d: got %d, want %d", key, value, key*2)
					return
				}
			}
		}(i)
	}

	wg.Wait()

	require.Equal(t, numGoroutines*numOperations, m.Len())
}
459
// TestMap_ConcurrentReadWrite runs readers over the pre-populated keys
// 0..999 while writers insert keys from 1000 upward and delete every tenth
// one. Apart from what the race detector catches, the only value check is
// that a reader never sees a pre-populated key mapped to anything but itself.
//
// Readers report mismatches with t.Errorf and return, instead of calling
// require. require calls t.FailNow, which must only be called from the
// goroutine running the test.
func TestMap_ConcurrentReadWrite(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numReaders = 50
	const numWriters = 50
	const numOperations = 100

	for i := range 1000 {
		m.Set(i, i)
	}

	var wg sync.WaitGroup
	wg.Add(numReaders + numWriters)

	for range numReaders {
		go func() {
			defer wg.Done()
			for j := range numOperations {
				key := j % 1000
				value, ok := m.Get(key)
				if ok && value != key {
					t.Errorf("key %d: got %d, want %d", key, value, key)
					return
				}
				_ = m.Len()
			}
		}()
	}

	for i := range numWriters {
		go func(id int) {
			defer wg.Done()
			for j := range numOperations {
				// Writer keys start at 1000, disjoint from the readers' keys.
				key := 1000 + id*numOperations + j
				m.Set(key, key)
				if j%10 == 0 {
					m.Del(key)
				}
			}
		}(i)
	}

	wg.Wait()
}
504
// TestMap_ConcurrentSeq2 runs several Seq2 iterations in parallel over a
// fixed map and checks that each one sees every entry with value key*2.
//
// Goroutines report failures with t.Errorf rather than require. require calls
// t.FailNow, which must only be called from the goroutine running the test.
func TestMap_ConcurrentSeq2(t *testing.T) {
	t.Parallel()

	const numEntries = 100
	const numIterators = 10

	m := NewMap[int, int]()
	for i := range numEntries {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	wg.Add(numIterators)
	for range numIterators {
		go func() {
			defer wg.Done()
			count := 0
			for k, v := range m.Seq2() {
				if v != k*2 {
					t.Errorf("key %d: got %d, want %d", k, v, k*2)
				}
				count++
			}
			if count != numEntries {
				t.Errorf("Seq2 yielded %d entries, want %d", count, numEntries)
			}
		}()
	}

	wg.Wait()
}
531
// TestMap_ConcurrentSeq runs several Seq iterations in parallel over a fixed
// map and checks that each one yields exactly the expected set of values
// (i*2 for every stored i).
//
// Goroutines report failures with t.Errorf rather than require. require calls
// t.FailNow, which must only be called from the goroutine running the test.
func TestMap_ConcurrentSeq(t *testing.T) {
	t.Parallel()

	const numEntries = 100
	const numIterators = 10

	m := NewMap[int, int]()
	for i := range numEntries {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	wg.Add(numIterators)
	for range numIterators {
		go func() {
			defer wg.Done()
			count := 0
			values := make(map[int]bool, numEntries)
			for v := range m.Seq() {
				values[v] = true
				count++
			}
			if count != numEntries {
				t.Errorf("Seq yielded %d values, want %d", count, numEntries)
			}
			for i := range numEntries {
				if !values[i*2] {
					t.Errorf("value %d was not yielded by Seq", i*2)
				}
			}
		}()
	}

	wg.Wait()
}
562
// TestMap_ConcurrentTake has workers Take interleaved, disjoint subsets of
// the keys in parallel. It then checks that the map ended up empty and that
// every value was taken exactly once across all workers.
func TestMap_ConcurrentTake(t *testing.T) {
	t.Parallel()

	const (
		numItems   = 1000
		numWorkers = 10
	)

	m := NewMap[int, int]()
	for i := range numItems {
		m.Set(i, i*2)
	}

	// Each worker writes only its own slot, so taken needs no locking.
	taken := make([][]int, numWorkers)

	var wg sync.WaitGroup
	wg.Add(numWorkers)
	for w := range numWorkers {
		go func() {
			defer wg.Done()
			mine := make([]int, 0)
			for key := w; key < numItems; key += numWorkers {
				if v, ok := m.Take(key); ok {
					mine = append(mine, v)
				}
			}
			taken[w] = mine
		}()
	}
	wg.Wait()

	require.Zero(t, m.Len())

	seen := make(map[int]bool, numItems)
	for _, batch := range taken {
		for _, value := range batch {
			require.False(t, seen[value], "Value %d was taken multiple times", value)
			seen[value] = true
		}
	}

	require.Len(t, seen, numItems)
	for i := range numItems {
		require.True(t, seen[i*2], "Expected value %d to be taken", i*2)
	}
}
607
608func TestMap_TypeSafety(t *testing.T) {
609 t.Parallel()
610
611 stringIntMap := NewMap[string, int]()
612 stringIntMap.Set("key", 42)
613 value, ok := stringIntMap.Get("key")
614 require.True(t, ok)
615 require.Equal(t, 42, value)
616
617 intStringMap := NewMap[int, string]()
618 intStringMap.Set(42, "value")
619 strValue, ok := intStringMap.Get(42)
620 require.True(t, ok)
621 require.Equal(t, "value", strValue)
622
623 structMap := NewMap[string, struct{ Name string }]()
624 structMap.Set("key", struct{ Name string }{Name: "test"})
625 structValue, ok := structMap.Get("key")
626 require.True(t, ok)
627 require.Equal(t, "test", structValue.Name)
628}
629
// TestMap_InterfaceCompliance is a compile-time check that *Map implements
// json.Marshaler and json.Unmarshaler; it performs no runtime assertions.
func TestMap_InterfaceCompliance(t *testing.T) {
	t.Parallel()

	var (
		_ json.Marshaler   = (*Map[string, any])(nil)
		_ json.Unmarshaler = (*Map[string, any])(nil)
	)
}
636
// BenchmarkMap_Set measures Set with an ever-increasing key, so the map
// keeps growing for the whole run.
func BenchmarkMap_Set(b *testing.B) {
	m := NewMap[int, int]()

	key := 0
	for b.Loop() {
		m.Set(key, key*2)
		key++
	}
}
644
645func BenchmarkMap_Get(b *testing.B) {
646 m := NewMap[int, int]()
647 for i := range 1000 {
648 m.Set(i, i*2)
649 }
650
651 for i := 0; b.Loop(); i++ {
652 m.Get(i % 1000)
653 }
654}
655
// BenchmarkMap_Seq2 measures a full Seq2 pass over a 1000-entry map.
func BenchmarkMap_Seq2(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, 2*k)
	}

	for b.Loop() {
		// Drain the iterator without any per-entry work.
		for range m.Seq2() {
		}
	}
}
667
// BenchmarkMap_Seq measures a full Seq pass over a 1000-entry map.
func BenchmarkMap_Seq(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, 2*k)
	}

	for b.Loop() {
		// Drain the iterator without any per-value work.
		for range m.Seq() {
		}
	}
}
679
680func BenchmarkMap_Take(b *testing.B) {
681 m := NewMap[int, int]()
682 for i := range 1000 {
683 m.Set(i, i*2)
684 }
685
686 b.ResetTimer()
687 for i := 0; b.Loop(); i++ {
688 key := i % 1000
689 m.Take(key)
690 if i%1000 == 999 {
691 b.StopTimer()
692 for j := range 1000 {
693 m.Set(j, j*2)
694 }
695 b.StartTimer()
696 }
697 }
698}
699
700func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
701 m := NewMap[int, int]()
702 for i := range 1000 {
703 m.Set(i, i*2)
704 }
705
706 b.ResetTimer()
707 b.RunParallel(func(pb *testing.PB) {
708 i := 0
709 for pb.Next() {
710 if i%2 == 0 {
711 m.Get(i % 1000)
712 } else {
713 m.Set(i+1000, i*2)
714 }
715 i++
716 }
717 })
718}