1package csync
  2
  3import (
  4	"encoding/json"
  5	"maps"
  6	"sync"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
 12func TestNewMap(t *testing.T) {
 13	t.Parallel()
 14
 15	m := NewMap[string, int]()
 16	require.NotNil(t, m)
 17	require.NotNil(t, m.inner)
 18	require.Equal(t, 0, m.Len())
 19}
 20
 21func TestNewMapFrom(t *testing.T) {
 22	t.Parallel()
 23
 24	original := map[string]int{
 25		"key1": 1,
 26		"key2": 2,
 27	}
 28
 29	m := NewMapFrom(original)
 30	require.NotNil(t, m)
 31	require.Equal(t, original, m.inner)
 32	require.Equal(t, 2, m.Len())
 33
 34	value, ok := m.Get("key1")
 35	require.True(t, ok)
 36	require.Equal(t, 1, value)
 37}
 38
 39func TestMap_Set(t *testing.T) {
 40	t.Parallel()
 41
 42	m := NewMap[string, int]()
 43
 44	m.Set("key1", 42)
 45	value, ok := m.Get("key1")
 46	require.True(t, ok)
 47	require.Equal(t, 42, value)
 48	require.Equal(t, 1, m.Len())
 49
 50	m.Set("key1", 100)
 51	value, ok = m.Get("key1")
 52	require.True(t, ok)
 53	require.Equal(t, 100, value)
 54	require.Equal(t, 1, m.Len())
 55}
 56
 57func TestMap_GetOrSet(t *testing.T) {
 58	t.Parallel()
 59
 60	m := NewMap[string, int]()
 61
 62	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
 63	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
 64	require.Equal(t, 1, m.Len())
 65}
 66
 67func TestMap_Get(t *testing.T) {
 68	t.Parallel()
 69
 70	m := NewMap[string, int]()
 71
 72	value, ok := m.Get("nonexistent")
 73	require.False(t, ok)
 74	require.Equal(t, 0, value)
 75
 76	m.Set("key1", 42)
 77	value, ok = m.Get("key1")
 78	require.True(t, ok)
 79	require.Equal(t, 42, value)
 80}
 81
 82func TestMap_Del(t *testing.T) {
 83	t.Parallel()
 84
 85	m := NewMap[string, int]()
 86	m.Set("key1", 42)
 87	m.Set("key2", 100)
 88
 89	require.Equal(t, 2, m.Len())
 90
 91	m.Del("key1")
 92	_, ok := m.Get("key1")
 93	require.False(t, ok)
 94	require.Equal(t, 1, m.Len())
 95
 96	value, ok := m.Get("key2")
 97	require.True(t, ok)
 98	require.Equal(t, 100, value)
 99
100	m.Del("nonexistent")
101	require.Equal(t, 1, m.Len())
102}
103
104func TestMap_Len(t *testing.T) {
105	t.Parallel()
106
107	m := NewMap[string, int]()
108	require.Equal(t, 0, m.Len())
109
110	m.Set("key1", 1)
111	require.Equal(t, 1, m.Len())
112
113	m.Set("key2", 2)
114	require.Equal(t, 2, m.Len())
115
116	m.Del("key1")
117	require.Equal(t, 1, m.Len())
118
119	m.Del("key2")
120	require.Equal(t, 0, m.Len())
121}
122
123func TestMap_Take(t *testing.T) {
124	t.Parallel()
125
126	m := NewMap[string, int]()
127	m.Set("key1", 42)
128	m.Set("key2", 100)
129
130	require.Equal(t, 2, m.Len())
131
132	value, ok := m.Take("key1")
133	require.True(t, ok)
134	require.Equal(t, 42, value)
135	require.Equal(t, 1, m.Len())
136
137	_, exists := m.Get("key1")
138	require.False(t, exists)
139
140	value, ok = m.Get("key2")
141	require.True(t, ok)
142	require.Equal(t, 100, value)
143}
144
145func TestMap_Take_NonexistentKey(t *testing.T) {
146	t.Parallel()
147
148	m := NewMap[string, int]()
149	m.Set("key1", 42)
150
151	value, ok := m.Take("nonexistent")
152	require.False(t, ok)
153	require.Equal(t, 0, value)
154	require.Equal(t, 1, m.Len())
155
156	value, ok = m.Get("key1")
157	require.True(t, ok)
158	require.Equal(t, 42, value)
159}
160
161func TestMap_Take_EmptyMap(t *testing.T) {
162	t.Parallel()
163
164	m := NewMap[string, int]()
165
166	value, ok := m.Take("key1")
167	require.False(t, ok)
168	require.Equal(t, 0, value)
169	require.Equal(t, 0, m.Len())
170}
171
172func TestMap_Take_SameKeyTwice(t *testing.T) {
173	t.Parallel()
174
175	m := NewMap[string, int]()
176	m.Set("key1", 42)
177
178	value, ok := m.Take("key1")
179	require.True(t, ok)
180	require.Equal(t, 42, value)
181	require.Equal(t, 0, m.Len())
182
183	value, ok = m.Take("key1")
184	require.False(t, ok)
185	require.Equal(t, 0, value)
186	require.Equal(t, 0, m.Len())
187}
188
189func TestMap_Seq2(t *testing.T) {
190	t.Parallel()
191
192	m := NewMap[string, int]()
193	m.Set("key1", 1)
194	m.Set("key2", 2)
195	m.Set("key3", 3)
196
197	collected := maps.Collect(m.Seq2())
198
199	require.Equal(t, 3, len(collected))
200	require.Equal(t, 1, collected["key1"])
201	require.Equal(t, 2, collected["key2"])
202	require.Equal(t, 3, collected["key3"])
203}
204
205func TestMap_Seq2_EarlyReturn(t *testing.T) {
206	t.Parallel()
207
208	m := NewMap[string, int]()
209	m.Set("key1", 1)
210	m.Set("key2", 2)
211	m.Set("key3", 3)
212
213	count := 0
214	for range m.Seq2() {
215		count++
216		if count == 2 {
217			break
218		}
219	}
220
221	require.Equal(t, 2, count)
222}
223
224func TestMap_Seq2_EmptyMap(t *testing.T) {
225	t.Parallel()
226
227	m := NewMap[string, int]()
228
229	count := 0
230	for range m.Seq2() {
231		count++
232	}
233
234	require.Equal(t, 0, count)
235}
236
237func TestMap_Seq(t *testing.T) {
238	t.Parallel()
239
240	m := NewMap[string, int]()
241	m.Set("key1", 1)
242	m.Set("key2", 2)
243	m.Set("key3", 3)
244
245	collected := make([]int, 0)
246	for v := range m.Seq() {
247		collected = append(collected, v)
248	}
249
250	require.Equal(t, 3, len(collected))
251	require.Contains(t, collected, 1)
252	require.Contains(t, collected, 2)
253	require.Contains(t, collected, 3)
254}
255
256func TestMap_Seq_EarlyReturn(t *testing.T) {
257	t.Parallel()
258
259	m := NewMap[string, int]()
260	m.Set("key1", 1)
261	m.Set("key2", 2)
262	m.Set("key3", 3)
263
264	count := 0
265	for range m.Seq() {
266		count++
267		if count == 2 {
268			break
269		}
270	}
271
272	require.Equal(t, 2, count)
273}
274
275func TestMap_Seq_EmptyMap(t *testing.T) {
276	t.Parallel()
277
278	m := NewMap[string, int]()
279
280	count := 0
281	for range m.Seq() {
282		count++
283	}
284
285	require.Equal(t, 0, count)
286}
287
288func TestMap_MarshalJSON(t *testing.T) {
289	t.Parallel()
290
291	m := NewMap[string, int]()
292	m.Set("key1", 1)
293	m.Set("key2", 2)
294
295	data, err := json.Marshal(m)
296	require.NoError(t, err)
297
298	result := &Map[string, int]{}
299	err = json.Unmarshal(data, result)
300	require.NoError(t, err)
301	require.Equal(t, 2, result.Len())
302	v1, _ := result.Get("key1")
303	v2, _ := result.Get("key2")
304	require.Equal(t, 1, v1)
305	require.Equal(t, 2, v2)
306}
307
308func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
309	t.Parallel()
310
311	m := NewMap[string, int]()
312
313	data, err := json.Marshal(m)
314	require.NoError(t, err)
315	require.Equal(t, "{}", string(data))
316}
317
318func TestMap_UnmarshalJSON(t *testing.T) {
319	t.Parallel()
320
321	jsonData := `{"key1": 1, "key2": 2}`
322
323	m := NewMap[string, int]()
324	err := json.Unmarshal([]byte(jsonData), m)
325	require.NoError(t, err)
326
327	require.Equal(t, 2, m.Len())
328	value, ok := m.Get("key1")
329	require.True(t, ok)
330	require.Equal(t, 1, value)
331
332	value, ok = m.Get("key2")
333	require.True(t, ok)
334	require.Equal(t, 2, value)
335}
336
337func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
338	t.Parallel()
339
340	jsonData := `{}`
341
342	m := NewMap[string, int]()
343	err := json.Unmarshal([]byte(jsonData), m)
344	require.NoError(t, err)
345	require.Equal(t, 0, m.Len())
346}
347
348func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
349	t.Parallel()
350
351	jsonData := `{"key1": 1, "key2":}`
352
353	m := NewMap[string, int]()
354	err := json.Unmarshal([]byte(jsonData), m)
355	require.Error(t, err)
356}
357
358func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
359	t.Parallel()
360
361	m := NewMap[string, int]()
362	m.Set("existing", 999)
363
364	jsonData := `{"key1": 1, "key2": 2}`
365	err := json.Unmarshal([]byte(jsonData), m)
366	require.NoError(t, err)
367
368	require.Equal(t, 2, m.Len())
369	_, ok := m.Get("existing")
370	require.False(t, ok)
371
372	value, ok := m.Get("key1")
373	require.True(t, ok)
374	require.Equal(t, 1, value)
375}
376
377func TestMap_JSONRoundTrip(t *testing.T) {
378	t.Parallel()
379
380	original := NewMap[string, int]()
381	original.Set("key1", 1)
382	original.Set("key2", 2)
383	original.Set("key3", 3)
384
385	data, err := json.Marshal(original)
386	require.NoError(t, err)
387
388	restored := NewMap[string, int]()
389	err = json.Unmarshal(data, restored)
390	require.NoError(t, err)
391
392	require.Equal(t, original.Len(), restored.Len())
393
394	for k, v := range original.Seq2() {
395		restoredValue, ok := restored.Get(k)
396		require.True(t, ok)
397		require.Equal(t, v, restoredValue)
398	}
399}
400
401func TestMap_ConcurrentAccess(t *testing.T) {
402	t.Parallel()
403
404	m := NewMap[int, int]()
405	const numGoroutines = 100
406	const numOperations = 100
407
408	var wg sync.WaitGroup
409	wg.Add(numGoroutines)
410
411	for i := range numGoroutines {
412		go func(id int) {
413			defer wg.Done()
414			for j := range numOperations {
415				key := id*numOperations + j
416				m.Set(key, key*2)
417				value, ok := m.Get(key)
418				require.True(t, ok)
419				require.Equal(t, key*2, value)
420			}
421		}(i)
422	}
423
424	wg.Wait()
425
426	require.Equal(t, numGoroutines*numOperations, m.Len())
427}
428
429func TestMap_ConcurrentReadWrite(t *testing.T) {
430	t.Parallel()
431
432	m := NewMap[int, int]()
433	const numReaders = 50
434	const numWriters = 50
435	const numOperations = 100
436
437	for i := range 1000 {
438		m.Set(i, i)
439	}
440
441	var wg sync.WaitGroup
442	wg.Add(numReaders + numWriters)
443
444	for range numReaders {
445		go func() {
446			defer wg.Done()
447			for j := range numOperations {
448				key := j % 1000
449				value, ok := m.Get(key)
450				if ok {
451					require.Equal(t, key, value)
452				}
453				_ = m.Len()
454			}
455		}()
456	}
457
458	for i := range numWriters {
459		go func(id int) {
460			defer wg.Done()
461			for j := range numOperations {
462				key := 1000 + id*numOperations + j
463				m.Set(key, key)
464				if j%10 == 0 {
465					m.Del(key)
466				}
467			}
468		}(i)
469	}
470
471	wg.Wait()
472}
473
474func TestMap_ConcurrentSeq2(t *testing.T) {
475	t.Parallel()
476
477	m := NewMap[int, int]()
478	for i := range 100 {
479		m.Set(i, i*2)
480	}
481
482	var wg sync.WaitGroup
483	const numIterators = 10
484
485	wg.Add(numIterators)
486	for range numIterators {
487		go func() {
488			defer wg.Done()
489			count := 0
490			for k, v := range m.Seq2() {
491				require.Equal(t, k*2, v)
492				count++
493			}
494			require.Equal(t, 100, count)
495		}()
496	}
497
498	wg.Wait()
499}
500
501func TestMap_ConcurrentSeq(t *testing.T) {
502	t.Parallel()
503
504	m := NewMap[int, int]()
505	for i := range 100 {
506		m.Set(i, i*2)
507	}
508
509	var wg sync.WaitGroup
510	const numIterators = 10
511
512	wg.Add(numIterators)
513	for range numIterators {
514		go func() {
515			defer wg.Done()
516			count := 0
517			values := make(map[int]bool)
518			for v := range m.Seq() {
519				values[v] = true
520				count++
521			}
522			require.Equal(t, 100, count)
523			for i := range 100 {
524				require.True(t, values[i*2])
525			}
526		}()
527	}
528
529	wg.Wait()
530}
531
532func TestMap_ConcurrentTake(t *testing.T) {
533	t.Parallel()
534
535	m := NewMap[int, int]()
536	const numItems = 1000
537
538	for i := range numItems {
539		m.Set(i, i*2)
540	}
541
542	var wg sync.WaitGroup
543	const numWorkers = 10
544	taken := make([][]int, numWorkers)
545
546	wg.Add(numWorkers)
547	for i := range numWorkers {
548		go func(workerID int) {
549			defer wg.Done()
550			taken[workerID] = make([]int, 0)
551			for j := workerID; j < numItems; j += numWorkers {
552				if value, ok := m.Take(j); ok {
553					taken[workerID] = append(taken[workerID], value)
554				}
555			}
556		}(i)
557	}
558
559	wg.Wait()
560
561	require.Equal(t, 0, m.Len())
562
563	allTaken := make(map[int]bool)
564	for _, workerTaken := range taken {
565		for _, value := range workerTaken {
566			require.False(t, allTaken[value], "Value %d was taken multiple times", value)
567			allTaken[value] = true
568		}
569	}
570
571	require.Equal(t, numItems, len(allTaken))
572	for i := range numItems {
573		require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
574	}
575}
576
577func TestMap_TypeSafety(t *testing.T) {
578	t.Parallel()
579
580	stringIntMap := NewMap[string, int]()
581	stringIntMap.Set("key", 42)
582	value, ok := stringIntMap.Get("key")
583	require.True(t, ok)
584	require.Equal(t, 42, value)
585
586	intStringMap := NewMap[int, string]()
587	intStringMap.Set(42, "value")
588	strValue, ok := intStringMap.Get(42)
589	require.True(t, ok)
590	require.Equal(t, "value", strValue)
591
592	structMap := NewMap[string, struct{ Name string }]()
593	structMap.Set("key", struct{ Name string }{Name: "test"})
594	structValue, ok := structMap.Get("key")
595	require.True(t, ok)
596	require.Equal(t, "test", structValue.Name)
597}
598
599func TestMap_InterfaceCompliance(t *testing.T) {
600	t.Parallel()
601
602	var _ json.Marshaler = &Map[string, any]{}
603	var _ json.Unmarshaler = &Map[string, any]{}
604}
605
606func BenchmarkMap_Set(b *testing.B) {
607	m := NewMap[int, int]()
608
609	for i := 0; b.Loop(); i++ {
610		m.Set(i, i*2)
611	}
612}
613
614func BenchmarkMap_Get(b *testing.B) {
615	m := NewMap[int, int]()
616	for i := range 1000 {
617		m.Set(i, i*2)
618	}
619
620	for i := 0; b.Loop(); i++ {
621		m.Get(i % 1000)
622	}
623}
624
625func BenchmarkMap_Seq2(b *testing.B) {
626	m := NewMap[int, int]()
627	for i := range 1000 {
628		m.Set(i, i*2)
629	}
630
631	for b.Loop() {
632		for range m.Seq2() {
633		}
634	}
635}
636
637func BenchmarkMap_Seq(b *testing.B) {
638	m := NewMap[int, int]()
639	for i := range 1000 {
640		m.Set(i, i*2)
641	}
642
643	for b.Loop() {
644		for range m.Seq() {
645		}
646	}
647}
648
649func BenchmarkMap_Take(b *testing.B) {
650	m := NewMap[int, int]()
651	for i := range 1000 {
652		m.Set(i, i*2)
653	}
654
655	b.ResetTimer()
656	for i := 0; b.Loop(); i++ {
657		key := i % 1000
658		m.Take(key)
659		if i%1000 == 999 {
660			b.StopTimer()
661			for j := range 1000 {
662				m.Set(j, j*2)
663			}
664			b.StartTimer()
665		}
666	}
667}
668
669func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
670	m := NewMap[int, int]()
671	for i := range 1000 {
672		m.Set(i, i*2)
673	}
674
675	b.ResetTimer()
676	b.RunParallel(func(pb *testing.PB) {
677		i := 0
678		for pb.Next() {
679			if i%2 == 0 {
680				m.Get(i % 1000)
681			} else {
682				m.Set(i+1000, i*2)
683			}
684			i++
685		}
686	})
687}