1package csync
  2
  3import (
  4	"encoding/json"
  5	"maps"
  6	"sync"
  7	"testing"
  8	"testing/synctest"
  9	"time"
 10
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestNewMap(t *testing.T) {
 15	t.Parallel()
 16
 17	m := NewMap[string, int]()
 18	require.NotNil(t, m)
 19	require.NotNil(t, m.inner)
 20	require.Equal(t, 0, m.Len())
 21}
 22
 23func TestNewMapFrom(t *testing.T) {
 24	t.Parallel()
 25
 26	original := map[string]int{
 27		"key1": 1,
 28		"key2": 2,
 29	}
 30
 31	m := NewMapFrom(original)
 32	require.NotNil(t, m)
 33	require.Equal(t, original, m.inner)
 34	require.Equal(t, 2, m.Len())
 35
 36	value, ok := m.Get("key1")
 37	require.True(t, ok)
 38	require.Equal(t, 1, value)
 39}
 40
 41func TestNewLazyMap(t *testing.T) {
 42	t.Parallel()
 43
 44	synctest.Test(t, func(t *testing.T) {
 45		t.Helper()
 46
 47		waiter := sync.Mutex{}
 48		waiter.Lock()
 49		loadCalled := false
 50
 51		loadFunc := func() map[string]int {
 52			waiter.Lock()
 53			defer waiter.Unlock()
 54			loadCalled = true
 55			return map[string]int{
 56				"key1": 1,
 57				"key2": 2,
 58			}
 59		}
 60
 61		m := NewLazyMap(loadFunc)
 62		require.NotNil(t, m)
 63
 64		waiter.Unlock() // Allow the load function to proceed
 65		time.Sleep(100 * time.Millisecond)
 66		require.True(t, loadCalled)
 67		require.Equal(t, 2, m.Len())
 68
 69		value, ok := m.Get("key1")
 70		require.True(t, ok)
 71		require.Equal(t, 1, value)
 72	})
 73}
 74
 75func TestMap_Reset(t *testing.T) {
 76	t.Parallel()
 77
 78	m := NewMapFrom(map[string]int{
 79		"a": 10,
 80	})
 81
 82	m.Reset(map[string]int{
 83		"b": 20,
 84	})
 85	value, ok := m.Get("b")
 86	require.True(t, ok)
 87	require.Equal(t, 20, value)
 88	require.Equal(t, 1, m.Len())
 89}
 90
 91func TestMap_Set(t *testing.T) {
 92	t.Parallel()
 93
 94	m := NewMap[string, int]()
 95
 96	m.Set("key1", 42)
 97	value, ok := m.Get("key1")
 98	require.True(t, ok)
 99	require.Equal(t, 42, value)
100	require.Equal(t, 1, m.Len())
101
102	m.Set("key1", 100)
103	value, ok = m.Get("key1")
104	require.True(t, ok)
105	require.Equal(t, 100, value)
106	require.Equal(t, 1, m.Len())
107}
108
// TestMap_GetOrSet checks that GetOrSet stores the factory's value on a miss
// and returns the already-stored value on a later hit.
func TestMap_GetOrSet(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	first := m.GetOrSet("key1", func() int { return 42 })
	require.Equal(t, 42, first)

	// The key now exists, so the stored 42 wins over this factory's value.
	second := m.GetOrSet("key1", func() int { return 99999 })
	require.Equal(t, 42, second)

	require.Equal(t, 1, m.Len())
}
118
// TestMap_Get checks that Get reports a miss with the zero value and false,
// and a hit with the stored value and true.
func TestMap_Get(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	missing, found := m.Get("nonexistent")
	require.False(t, found)
	require.Zero(t, missing)

	m.Set("key1", 42)
	hit, found := m.Get("key1")
	require.True(t, found)
	require.Equal(t, 42, hit)
}
133
// TestMap_Del checks that Del removes only the named key and that deleting
// an absent key leaves the map unchanged.
func TestMap_Del(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	m.Del("key1")
	_, present := m.Get("key1")
	require.False(t, present)
	require.Equal(t, 1, m.Len())

	// The sibling entry must be untouched.
	kept, present := m.Get("key2")
	require.True(t, present)
	require.Equal(t, 100, kept)

	// Deleting a key that was never set is a no-op.
	m.Del("nonexistent")
	require.Equal(t, 1, m.Len())
}
155
// TestMap_Len checks that Len tracks the entry count across a sequence of
// inserts and deletes.
func TestMap_Len(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.Equal(t, 0, m.Len())

	steps := []struct {
		apply func()
		want  int
	}{
		{func() { m.Set("key1", 1) }, 1},
		{func() { m.Set("key2", 2) }, 2},
		{func() { m.Del("key1") }, 1},
		{func() { m.Del("key2") }, 0},
	}
	for _, step := range steps {
		step.apply()
		require.Equal(t, step.want, m.Len())
	}
}
174
// TestMap_Take checks that Take returns the stored value, removes that entry,
// and leaves other entries in place.
func TestMap_Take(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)
	require.Equal(t, 2, m.Len())

	taken, ok := m.Take("key1")
	require.True(t, ok)
	require.Equal(t, 42, taken)
	require.Equal(t, 1, m.Len())

	_, stillThere := m.Get("key1")
	require.False(t, stillThere)

	other, ok := m.Get("key2")
	require.True(t, ok)
	require.Equal(t, 100, other)
}
196
// TestMap_Take_NonexistentKey checks that taking a missing key reports false
// with the zero value and does not disturb existing entries.
func TestMap_Take_NonexistentKey(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	got, ok := m.Take("nonexistent")
	require.False(t, ok)
	require.Zero(t, got)
	require.Equal(t, 1, m.Len())

	got, ok = m.Get("key1")
	require.True(t, ok)
	require.Equal(t, 42, got)
}
212
// TestMap_Take_EmptyMap checks that Take on an empty map reports false with
// the zero value and leaves the map empty.
func TestMap_Take_EmptyMap(t *testing.T) {
	t.Parallel()

	empty := NewMap[string, int]()

	got, ok := empty.Take("key1")
	require.False(t, ok)
	require.Zero(t, got)
	require.Zero(t, empty.Len())
}
223
// TestMap_Take_SameKeyTwice checks that the first Take of a key succeeds and
// a second Take of the same key finds nothing.
func TestMap_Take_SameKeyTwice(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	expectations := []struct {
		value int
		ok    bool
	}{
		{42, true},
		{0, false},
	}
	for _, want := range expectations {
		got, ok := m.Take("key1")
		require.Equal(t, want.ok, ok)
		require.Equal(t, want.value, got)
		require.Equal(t, 0, m.Len())
	}
}
240
// TestMap_Seq2 checks that Seq2 yields every key/value pair in the map.
func TestMap_Seq2(t *testing.T) {
	t.Parallel()

	want := map[string]int{"key1": 1, "key2": 2, "key3": 3}
	m := NewMap[string, int]()
	for k, v := range want {
		m.Set(k, v)
	}

	got := maps.Collect(m.Seq2())
	require.Len(t, got, 3)
	for k, v := range want {
		require.Equal(t, v, got[k])
	}
}
256
// TestMap_Seq2_EarlyReturn checks that breaking out of a range over Seq2
// stops the iteration. An iterator that kept calling yield after it returned
// false would make the range statement panic.
func TestMap_Seq2_EarlyReturn(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	for i, k := range []string{"key1", "key2", "key3"} {
		m.Set(k, i+1)
	}

	seen := 0
	for range m.Seq2() {
		if seen++; seen == 2 {
			break
		}
	}

	require.Equal(t, 2, seen)
}
275
// TestMap_Seq2_EmptyMap checks that Seq2 yields nothing for an empty map.
func TestMap_Seq2_EmptyMap(t *testing.T) {
	t.Parallel()

	empty := NewMap[string, int]()

	visited := 0
	for range empty.Seq2() {
		visited++
	}
	require.Zero(t, visited)
}
288
289func TestMap_Seq(t *testing.T) {
290	t.Parallel()
291
292	m := NewMap[string, int]()
293	m.Set("key1", 1)
294	m.Set("key2", 2)
295	m.Set("key3", 3)
296
297	collected := make([]int, 0)
298	for v := range m.Seq() {
299		collected = append(collected, v)
300	}
301
302	require.Equal(t, 3, len(collected))
303	require.Contains(t, collected, 1)
304	require.Contains(t, collected, 2)
305	require.Contains(t, collected, 3)
306}
307
308func TestMap_Seq_EarlyReturn(t *testing.T) {
309	t.Parallel()
310
311	m := NewMap[string, int]()
312	m.Set("key1", 1)
313	m.Set("key2", 2)
314	m.Set("key3", 3)
315
316	count := 0
317	for range m.Seq() {
318		count++
319		if count == 2 {
320			break
321		}
322	}
323
324	require.Equal(t, 2, count)
325}
326
327func TestMap_Seq_EmptyMap(t *testing.T) {
328	t.Parallel()
329
330	m := NewMap[string, int]()
331
332	count := 0
333	for range m.Seq() {
334		count++
335	}
336
337	require.Equal(t, 0, count)
338}
339
// TestMap_MarshalJSON checks that a populated map marshals to JSON that
// decodes back, via UnmarshalJSON on a zero-value Map, into the same entries.
func TestMap_MarshalJSON(t *testing.T) {
	t.Parallel()

	src := NewMap[string, int]()
	src.Set("key1", 1)
	src.Set("key2", 2)

	encoded, err := json.Marshal(src)
	require.NoError(t, err)

	decoded := &Map[string, int]{}
	err = json.Unmarshal(encoded, decoded)
	require.NoError(t, err)
	require.Equal(t, 2, decoded.Len())

	// Expected values are non-zero, so a missing key would still fail here.
	for key, want := range map[string]int{"key1": 1, "key2": 2} {
		got, _ := decoded.Get(key)
		require.Equal(t, want, got)
	}
}
359
// TestMap_MarshalJSON_EmptyMap checks that an empty map encodes as an empty
// JSON object.
func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
	t.Parallel()

	raw, err := json.Marshal(NewMap[string, int]())
	require.NoError(t, err)
	require.Equal(t, "{}", string(raw))
}
369
// TestMap_UnmarshalJSON checks that decoding a JSON object populates the map
// with its key/value pairs.
func TestMap_UnmarshalJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.NoError(t, json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m))
	require.Equal(t, 2, m.Len())

	for _, tc := range []struct {
		key  string
		want int
	}{
		{"key1", 1},
		{"key2", 2},
	} {
		got, ok := m.Get(tc.key)
		require.True(t, ok)
		require.Equal(t, tc.want, got)
	}
}
388
// TestMap_UnmarshalJSON_EmptyJSON checks that decoding "{}" succeeds and
// leaves the map empty.
func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	require.NoError(t, json.Unmarshal([]byte(`{}`), m))
	require.Zero(t, m.Len())
}
399
// TestMap_UnmarshalJSON_InvalidJSON checks that malformed input (here, a
// missing value after "key2") is rejected with an error.
func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
	t.Parallel()

	malformed := []byte(`{"key1": 1, "key2":}`)

	require.Error(t, json.Unmarshal(malformed, NewMap[string, int]()))
}
409
// TestMap_UnmarshalJSON_OverwritesExistingData checks that unmarshalling
// replaces the map's previous contents instead of merging into them.
func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("existing", 999)

	err := json.Unmarshal([]byte(`{"key1": 1, "key2": 2}`), m)
	require.NoError(t, err)
	require.Equal(t, 2, m.Len())

	// The entry set before decoding must be gone.
	_, kept := m.Get("existing")
	require.False(t, kept)

	v, ok := m.Get("key1")
	require.True(t, ok)
	require.Equal(t, 1, v)
}
428
// TestMap_JSONRoundTrip checks that marshalling a map and unmarshalling the
// result into a fresh map reproduces every entry.
func TestMap_JSONRoundTrip(t *testing.T) {
	t.Parallel()

	original := NewMap[string, int]()
	for i, k := range []string{"key1", "key2", "key3"} {
		original.Set(k, i+1)
	}

	encoded, err := json.Marshal(original)
	require.NoError(t, err)

	restored := NewMap[string, int]()
	require.NoError(t, json.Unmarshal(encoded, restored))
	require.Equal(t, original.Len(), restored.Len())

	for key, want := range original.Seq2() {
		got, ok := restored.Get(key)
		require.True(t, ok)
		require.Equal(t, want, got)
	}
}
452
// TestMap_ConcurrentAccess runs many goroutines at once. Each one writes to
// its own disjoint key range and immediately reads every key back. The final
// entry count must equal the total number of writes.
//
// Worker goroutines report failures with t.Errorf rather than require.*.
// require calls t.FailNow, which the testing package only allows on the
// goroutine running the test function.
func TestMap_ConcurrentAccess(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numGoroutines = 100
	const numOperations = 100

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	for i := range numGoroutines {
		go func(id int) {
			defer wg.Done()
			for j := range numOperations {
				key := id*numOperations + j
				m.Set(key, key*2)
				value, ok := m.Get(key)
				if !ok || value != key*2 {
					t.Errorf("Get(%d) = (%d, %t), want (%d, true)", key, value, ok, key*2)
					return
				}
			}
		}(i)
	}

	wg.Wait()

	// Back on the test goroutine, so require is safe again.
	require.Equal(t, numGoroutines*numOperations, m.Len())
}
480
// TestMap_ConcurrentReadWrite runs two groups of goroutines at once:
//   - readers check the preloaded keys 0..999;
//   - writers insert, and sometimes delete, keys from 1000 upward.
//
// Apart from the readers' value checks there is no final assertion, so the
// test is most informative when run with -race.
//
// Readers report failures with t.Errorf because require.* calls t.FailNow,
// which must only be called from the test goroutine.
func TestMap_ConcurrentReadWrite(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numReaders = 50
	const numWriters = 50
	const numOperations = 100

	for i := range 1000 {
		m.Set(i, i)
	}

	var wg sync.WaitGroup
	wg.Add(numReaders + numWriters)

	for range numReaders {
		go func() {
			defer wg.Done()
			for j := range numOperations {
				key := j % 1000
				if value, ok := m.Get(key); ok && value != key {
					t.Errorf("Get(%d) = %d, want %d", key, value, key)
					return
				}
				_ = m.Len()
			}
		}()
	}

	for i := range numWriters {
		go func(id int) {
			defer wg.Done()
			// Writers use disjoint keys >= 1000, so they never modify the
			// preloaded entries the readers check.
			for j := range numOperations {
				key := 1000 + id*numOperations + j
				m.Set(key, key)
				if j%10 == 0 {
					m.Del(key)
				}
			}
		}(i)
	}

	wg.Wait()
}
525
// TestMap_ConcurrentSeq2 runs several Seq2 iterations in parallel over the
// same map. Each iteration must see all 100 pairs with value == 2*key.
//
// Checks inside the goroutines use t.Errorf, not require.*: require calls
// t.FailNow, which the testing package only permits on the test goroutine.
func TestMap_ConcurrentSeq2(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	for i := range 100 {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numIterators = 10

	wg.Add(numIterators)
	for range numIterators {
		go func() {
			defer wg.Done()
			count := 0
			for k, v := range m.Seq2() {
				if v != k*2 {
					t.Errorf("Seq2 yielded (%d, %d), want value %d", k, v, k*2)
					return
				}
				count++
			}
			if count != 100 {
				t.Errorf("Seq2 yielded %d pairs, want 100", count)
			}
		}()
	}

	wg.Wait()
}
552
// TestMap_ConcurrentSeq runs several Seq iterations in parallel over the same
// map. Each iteration must yield exactly 100 values covering 0, 2, ..., 198.
//
// Checks inside the goroutines use t.Errorf, not require.*: require calls
// t.FailNow, which the testing package only permits on the test goroutine.
func TestMap_ConcurrentSeq(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	for i := range 100 {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numIterators = 10

	wg.Add(numIterators)
	for range numIterators {
		go func() {
			defer wg.Done()
			count := 0
			values := make(map[int]bool)
			for v := range m.Seq() {
				values[v] = true
				count++
			}
			if count != 100 {
				t.Errorf("Seq yielded %d values, want 100", count)
				return
			}
			for i := range 100 {
				if !values[i*2] {
					t.Errorf("Seq did not yield value %d", i*2)
					return
				}
			}
		}()
	}

	wg.Wait()
}
583
// TestMap_ConcurrentTake drains the map with several workers calling Take in
// parallel. Afterwards the map must be empty and every stored value must have
// been taken exactly once.
func TestMap_ConcurrentTake(t *testing.T) {
	t.Parallel()

	const numItems = 1000
	const numWorkers = 10

	m := NewMap[int, int]()
	for i := range numItems {
		m.Set(i, i*2)
	}

	// Each worker writes only to its own slot, so the slice needs no lock.
	taken := make([][]int, numWorkers)

	var wg sync.WaitGroup
	wg.Add(numWorkers)
	for w := range numWorkers {
		go func(worker int) {
			defer wg.Done()
			got := make([]int, 0)
			// Worker w handles keys w, w+numWorkers, w+2*numWorkers, ...
			for key := worker; key < numItems; key += numWorkers {
				if v, ok := m.Take(key); ok {
					got = append(got, v)
				}
			}
			taken[worker] = got
		}(w)
	}
	wg.Wait()

	require.Equal(t, 0, m.Len())

	seen := make(map[int]bool)
	for _, batch := range taken {
		for _, v := range batch {
			require.False(t, seen[v], "Value %d was taken multiple times", v)
			seen[v] = true
		}
	}

	require.Equal(t, numItems, len(seen))
	for i := range numItems {
		require.True(t, seen[i*2], "Expected value %d to be taken", i*2)
	}
}
628
629func TestMap_TypeSafety(t *testing.T) {
630	t.Parallel()
631
632	stringIntMap := NewMap[string, int]()
633	stringIntMap.Set("key", 42)
634	value, ok := stringIntMap.Get("key")
635	require.True(t, ok)
636	require.Equal(t, 42, value)
637
638	intStringMap := NewMap[int, string]()
639	intStringMap.Set(42, "value")
640	strValue, ok := intStringMap.Get(42)
641	require.True(t, ok)
642	require.Equal(t, "value", strValue)
643
644	structMap := NewMap[string, struct{ Name string }]()
645	structMap.Set("key", struct{ Name string }{Name: "test"})
646	structValue, ok := structMap.Get("key")
647	require.True(t, ok)
648	require.Equal(t, "test", structValue.Name)
649}
650
// TestMap_InterfaceCompliance asserts at compile time that *Map implements
// json.Marshaler and json.Unmarshaler.
func TestMap_InterfaceCompliance(t *testing.T) {
	t.Parallel()

	var (
		_ json.Marshaler   = (*Map[string, any])(nil)
		_ json.Unmarshaler = (*Map[string, any])(nil)
	)
}
657
// BenchmarkMap_Set measures inserting a new key into a growing map on every
// iteration.
func BenchmarkMap_Set(b *testing.B) {
	m := NewMap[int, int]()

	i := 0
	for b.Loop() {
		m.Set(i, i*2)
		i++
	}
}
665
666func BenchmarkMap_Get(b *testing.B) {
667	m := NewMap[int, int]()
668	for i := range 1000 {
669		m.Set(i, i*2)
670	}
671
672	for i := 0; b.Loop(); i++ {
673		m.Get(i % 1000)
674	}
675}
676
677func BenchmarkMap_Seq2(b *testing.B) {
678	m := NewMap[int, int]()
679	for i := range 1000 {
680		m.Set(i, i*2)
681	}
682
683	for b.Loop() {
684		for range m.Seq2() {
685		}
686	}
687}
688
689func BenchmarkMap_Seq(b *testing.B) {
690	m := NewMap[int, int]()
691	for i := range 1000 {
692		m.Set(i, i*2)
693	}
694
695	for b.Loop() {
696		for range m.Seq() {
697		}
698	}
699}
700
// BenchmarkMap_Take measures Take on a 1000-entry map. Once every key has
// been taken, the map is refilled with the timer stopped.
func BenchmarkMap_Take(b *testing.B) {
	m := NewMap[int, int]()
	for i := range 1000 {
		m.Set(i, i*2)
	}

	// No b.ResetTimer: b.Loop resets the timer on its first call, so the
	// setup above is already excluded from the measurement.
	for i := 0; b.Loop(); i++ {
		key := i % 1000
		m.Take(key)
		if i%1000 == 999 {
			// All keys have been taken; repopulate without timing it.
			b.StopTimer()
			for j := range 1000 {
				m.Set(j, j*2)
			}
			b.StartTimer()
		}
	}
}
720
// BenchmarkMap_ConcurrentReadWrite measures a 50/50 mix of reads and writes
// issued in parallel. Reads target the 1000 preloaded keys; writes target
// keys above that range.
func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
	m := NewMap[int, int]()
	for k := range 1000 {
		m.Set(k, k*2)
	}

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for n := 0; pb.Next(); n++ {
			if n%2 != 0 {
				m.Set(n+1000, n*2)
				continue
			}
			m.Get(n % 1000)
		}
	})
}