maps_test.go

  1package csync
  2
  3import (
  4	"encoding/json"
  5	"maps"
  6	"sync"
  7	"testing"
  8	"time"
  9
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestNewMap(t *testing.T) {
 14	t.Parallel()
 15
 16	m := NewMap[string, int]()
 17	require.NotNil(t, m)
 18	require.NotNil(t, m.inner)
 19	require.Equal(t, 0, m.Len())
 20}
 21
 22func TestNewMapFrom(t *testing.T) {
 23	t.Parallel()
 24
 25	original := map[string]int{
 26		"key1": 1,
 27		"key2": 2,
 28	}
 29
 30	m := NewMapFrom(original)
 31	require.NotNil(t, m)
 32	require.Equal(t, original, m.inner)
 33	require.Equal(t, 2, m.Len())
 34
 35	value, ok := m.Get("key1")
 36	require.True(t, ok)
 37	require.Equal(t, 1, value)
 38}
 39
 40func TestNewMapFromSeq(t *testing.T) {
 41	t.Parallel()
 42
 43	original := map[string]int{
 44		"key1": 1,
 45		"key2": 2,
 46	}
 47	seq := func(f func(string, int) bool) {
 48		for k, v := range original {
 49			if !f(k, v) {
 50				break
 51			}
 52		}
 53	}
 54
 55	m := NewMapFromSeq(seq)
 56	require.NotNil(t, m)
 57	require.Equal(t, original, m.inner)
 58	require.Equal(t, 2, m.Len())
 59
 60	value, ok := m.Get("key2")
 61	require.True(t, ok)
 62	require.Equal(t, 2, value)
 63}
 64
 65func TestNewLazyMap(t *testing.T) {
 66	t.Parallel()
 67
 68	waiter := sync.Mutex{}
 69	waiter.Lock()
 70	loadCalled := false
 71
 72	loadFunc := func() map[string]int {
 73		waiter.Lock()
 74		defer waiter.Unlock()
 75		loadCalled = true
 76		return map[string]int{
 77			"key1": 1,
 78			"key2": 2,
 79		}
 80	}
 81
 82	m := NewLazyMap(loadFunc)
 83	require.NotNil(t, m)
 84
 85	waiter.Unlock() // Allow the load function to proceed
 86	time.Sleep(100 * time.Millisecond)
 87	require.True(t, loadCalled)
 88	require.Equal(t, 2, m.Len())
 89
 90	value, ok := m.Get("key1")
 91	require.True(t, ok)
 92	require.Equal(t, 1, value)
 93}
 94
 95func TestMap_Set(t *testing.T) {
 96	t.Parallel()
 97
 98	m := NewMap[string, int]()
 99
100	m.Set("key1", 42)
101	value, ok := m.Get("key1")
102	require.True(t, ok)
103	require.Equal(t, 42, value)
104	require.Equal(t, 1, m.Len())
105
106	m.Set("key1", 100)
107	value, ok = m.Get("key1")
108	require.True(t, ok)
109	require.Equal(t, 100, value)
110	require.Equal(t, 1, m.Len())
111}
112
113func TestMap_GetOrSet(t *testing.T) {
114	t.Parallel()
115
116	m := NewMap[string, int]()
117
118	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
119	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
120	require.Equal(t, 1, m.Len())
121}
122
123func TestMap_Get(t *testing.T) {
124	t.Parallel()
125
126	m := NewMap[string, int]()
127
128	value, ok := m.Get("nonexistent")
129	require.False(t, ok)
130	require.Equal(t, 0, value)
131
132	m.Set("key1", 42)
133	value, ok = m.Get("key1")
134	require.True(t, ok)
135	require.Equal(t, 42, value)
136}
137
138func TestMap_Del(t *testing.T) {
139	t.Parallel()
140
141	m := NewMap[string, int]()
142	m.Set("key1", 42)
143	m.Set("key2", 100)
144
145	require.Equal(t, 2, m.Len())
146
147	m.Del("key1")
148	_, ok := m.Get("key1")
149	require.False(t, ok)
150	require.Equal(t, 1, m.Len())
151
152	value, ok := m.Get("key2")
153	require.True(t, ok)
154	require.Equal(t, 100, value)
155
156	m.Del("nonexistent")
157	require.Equal(t, 1, m.Len())
158}
159
160func TestMap_Len(t *testing.T) {
161	t.Parallel()
162
163	m := NewMap[string, int]()
164	require.Equal(t, 0, m.Len())
165
166	m.Set("key1", 1)
167	require.Equal(t, 1, m.Len())
168
169	m.Set("key2", 2)
170	require.Equal(t, 2, m.Len())
171
172	m.Del("key1")
173	require.Equal(t, 1, m.Len())
174
175	m.Del("key2")
176	require.Equal(t, 0, m.Len())
177}
178
179func TestMap_Take(t *testing.T) {
180	t.Parallel()
181
182	m := NewMap[string, int]()
183	m.Set("key1", 42)
184	m.Set("key2", 100)
185
186	require.Equal(t, 2, m.Len())
187
188	value, ok := m.Take("key1")
189	require.True(t, ok)
190	require.Equal(t, 42, value)
191	require.Equal(t, 1, m.Len())
192
193	_, exists := m.Get("key1")
194	require.False(t, exists)
195
196	value, ok = m.Get("key2")
197	require.True(t, ok)
198	require.Equal(t, 100, value)
199}
200
201func TestMap_Take_NonexistentKey(t *testing.T) {
202	t.Parallel()
203
204	m := NewMap[string, int]()
205	m.Set("key1", 42)
206
207	value, ok := m.Take("nonexistent")
208	require.False(t, ok)
209	require.Equal(t, 0, value)
210	require.Equal(t, 1, m.Len())
211
212	value, ok = m.Get("key1")
213	require.True(t, ok)
214	require.Equal(t, 42, value)
215}
216
217func TestMap_Take_EmptyMap(t *testing.T) {
218	t.Parallel()
219
220	m := NewMap[string, int]()
221
222	value, ok := m.Take("key1")
223	require.False(t, ok)
224	require.Equal(t, 0, value)
225	require.Equal(t, 0, m.Len())
226}
227
228func TestMap_Take_SameKeyTwice(t *testing.T) {
229	t.Parallel()
230
231	m := NewMap[string, int]()
232	m.Set("key1", 42)
233
234	value, ok := m.Take("key1")
235	require.True(t, ok)
236	require.Equal(t, 42, value)
237	require.Equal(t, 0, m.Len())
238
239	value, ok = m.Take("key1")
240	require.False(t, ok)
241	require.Equal(t, 0, value)
242	require.Equal(t, 0, m.Len())
243}
244
245func TestMap_Seq2(t *testing.T) {
246	t.Parallel()
247
248	m := NewMap[string, int]()
249	m.Set("key1", 1)
250	m.Set("key2", 2)
251	m.Set("key3", 3)
252
253	collected := maps.Collect(m.Seq2())
254
255	require.Equal(t, 3, len(collected))
256	require.Equal(t, 1, collected["key1"])
257	require.Equal(t, 2, collected["key2"])
258	require.Equal(t, 3, collected["key3"])
259}
260
261func TestMap_Seq2_EarlyReturn(t *testing.T) {
262	t.Parallel()
263
264	m := NewMap[string, int]()
265	m.Set("key1", 1)
266	m.Set("key2", 2)
267	m.Set("key3", 3)
268
269	count := 0
270	for range m.Seq2() {
271		count++
272		if count == 2 {
273			break
274		}
275	}
276
277	require.Equal(t, 2, count)
278}
279
280func TestMap_Seq2_EmptyMap(t *testing.T) {
281	t.Parallel()
282
283	m := NewMap[string, int]()
284
285	count := 0
286	for range m.Seq2() {
287		count++
288	}
289
290	require.Equal(t, 0, count)
291}
292
293func TestMap_Seq(t *testing.T) {
294	t.Parallel()
295
296	m := NewMap[string, int]()
297	m.Set("key1", 1)
298	m.Set("key2", 2)
299	m.Set("key3", 3)
300
301	collected := make([]int, 0)
302	for v := range m.Seq() {
303		collected = append(collected, v)
304	}
305
306	require.Equal(t, 3, len(collected))
307	require.Contains(t, collected, 1)
308	require.Contains(t, collected, 2)
309	require.Contains(t, collected, 3)
310}
311
312func TestMap_Seq_EarlyReturn(t *testing.T) {
313	t.Parallel()
314
315	m := NewMap[string, int]()
316	m.Set("key1", 1)
317	m.Set("key2", 2)
318	m.Set("key3", 3)
319
320	count := 0
321	for range m.Seq() {
322		count++
323		if count == 2 {
324			break
325		}
326	}
327
328	require.Equal(t, 2, count)
329}
330
331func TestMap_Seq_EmptyMap(t *testing.T) {
332	t.Parallel()
333
334	m := NewMap[string, int]()
335
336	count := 0
337	for range m.Seq() {
338		count++
339	}
340
341	require.Equal(t, 0, count)
342}
343
344func TestMap_MarshalJSON(t *testing.T) {
345	t.Parallel()
346
347	m := NewMap[string, int]()
348	m.Set("key1", 1)
349	m.Set("key2", 2)
350
351	data, err := json.Marshal(m)
352	require.NoError(t, err)
353
354	result := &Map[string, int]{}
355	err = json.Unmarshal(data, result)
356	require.NoError(t, err)
357	require.Equal(t, 2, result.Len())
358	v1, _ := result.Get("key1")
359	v2, _ := result.Get("key2")
360	require.Equal(t, 1, v1)
361	require.Equal(t, 2, v2)
362}
363
364func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
365	t.Parallel()
366
367	m := NewMap[string, int]()
368
369	data, err := json.Marshal(m)
370	require.NoError(t, err)
371	require.Equal(t, "{}", string(data))
372}
373
374func TestMap_UnmarshalJSON(t *testing.T) {
375	t.Parallel()
376
377	jsonData := `{"key1": 1, "key2": 2}`
378
379	m := NewMap[string, int]()
380	err := json.Unmarshal([]byte(jsonData), m)
381	require.NoError(t, err)
382
383	require.Equal(t, 2, m.Len())
384	value, ok := m.Get("key1")
385	require.True(t, ok)
386	require.Equal(t, 1, value)
387
388	value, ok = m.Get("key2")
389	require.True(t, ok)
390	require.Equal(t, 2, value)
391}
392
393func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
394	t.Parallel()
395
396	jsonData := `{}`
397
398	m := NewMap[string, int]()
399	err := json.Unmarshal([]byte(jsonData), m)
400	require.NoError(t, err)
401	require.Equal(t, 0, m.Len())
402}
403
404func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
405	t.Parallel()
406
407	jsonData := `{"key1": 1, "key2":}`
408
409	m := NewMap[string, int]()
410	err := json.Unmarshal([]byte(jsonData), m)
411	require.Error(t, err)
412}
413
414func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
415	t.Parallel()
416
417	m := NewMap[string, int]()
418	m.Set("existing", 999)
419
420	jsonData := `{"key1": 1, "key2": 2}`
421	err := json.Unmarshal([]byte(jsonData), m)
422	require.NoError(t, err)
423
424	require.Equal(t, 2, m.Len())
425	_, ok := m.Get("existing")
426	require.False(t, ok)
427
428	value, ok := m.Get("key1")
429	require.True(t, ok)
430	require.Equal(t, 1, value)
431}
432
433func TestMap_JSONRoundTrip(t *testing.T) {
434	t.Parallel()
435
436	original := NewMap[string, int]()
437	original.Set("key1", 1)
438	original.Set("key2", 2)
439	original.Set("key3", 3)
440
441	data, err := json.Marshal(original)
442	require.NoError(t, err)
443
444	restored := NewMap[string, int]()
445	err = json.Unmarshal(data, restored)
446	require.NoError(t, err)
447
448	require.Equal(t, original.Len(), restored.Len())
449
450	for k, v := range original.Seq2() {
451		restoredValue, ok := restored.Get(k)
452		require.True(t, ok)
453		require.Equal(t, v, restoredValue)
454	}
455}
456
457func TestMap_ConcurrentAccess(t *testing.T) {
458	t.Parallel()
459
460	m := NewMap[int, int]()
461	const numGoroutines = 100
462	const numOperations = 100
463
464	var wg sync.WaitGroup
465	wg.Add(numGoroutines)
466
467	for i := range numGoroutines {
468		go func(id int) {
469			defer wg.Done()
470			for j := range numOperations {
471				key := id*numOperations + j
472				m.Set(key, key*2)
473				value, ok := m.Get(key)
474				require.True(t, ok)
475				require.Equal(t, key*2, value)
476			}
477		}(i)
478	}
479
480	wg.Wait()
481
482	require.Equal(t, numGoroutines*numOperations, m.Len())
483}
484
485func TestMap_ConcurrentReadWrite(t *testing.T) {
486	t.Parallel()
487
488	m := NewMap[int, int]()
489	const numReaders = 50
490	const numWriters = 50
491	const numOperations = 100
492
493	for i := range 1000 {
494		m.Set(i, i)
495	}
496
497	var wg sync.WaitGroup
498	wg.Add(numReaders + numWriters)
499
500	for range numReaders {
501		go func() {
502			defer wg.Done()
503			for j := range numOperations {
504				key := j % 1000
505				value, ok := m.Get(key)
506				if ok {
507					require.Equal(t, key, value)
508				}
509				_ = m.Len()
510			}
511		}()
512	}
513
514	for i := range numWriters {
515		go func(id int) {
516			defer wg.Done()
517			for j := range numOperations {
518				key := 1000 + id*numOperations + j
519				m.Set(key, key)
520				if j%10 == 0 {
521					m.Del(key)
522				}
523			}
524		}(i)
525	}
526
527	wg.Wait()
528}
529
530func TestMap_ConcurrentSeq2(t *testing.T) {
531	t.Parallel()
532
533	m := NewMap[int, int]()
534	for i := range 100 {
535		m.Set(i, i*2)
536	}
537
538	var wg sync.WaitGroup
539	const numIterators = 10
540
541	wg.Add(numIterators)
542	for range numIterators {
543		go func() {
544			defer wg.Done()
545			count := 0
546			for k, v := range m.Seq2() {
547				require.Equal(t, k*2, v)
548				count++
549			}
550			require.Equal(t, 100, count)
551		}()
552	}
553
554	wg.Wait()
555}
556
557func TestMap_ConcurrentSeq(t *testing.T) {
558	t.Parallel()
559
560	m := NewMap[int, int]()
561	for i := range 100 {
562		m.Set(i, i*2)
563	}
564
565	var wg sync.WaitGroup
566	const numIterators = 10
567
568	wg.Add(numIterators)
569	for range numIterators {
570		go func() {
571			defer wg.Done()
572			count := 0
573			values := make(map[int]bool)
574			for v := range m.Seq() {
575				values[v] = true
576				count++
577			}
578			require.Equal(t, 100, count)
579			for i := range 100 {
580				require.True(t, values[i*2])
581			}
582		}()
583	}
584
585	wg.Wait()
586}
587
588func TestMap_ConcurrentTake(t *testing.T) {
589	t.Parallel()
590
591	m := NewMap[int, int]()
592	const numItems = 1000
593
594	for i := range numItems {
595		m.Set(i, i*2)
596	}
597
598	var wg sync.WaitGroup
599	const numWorkers = 10
600	taken := make([][]int, numWorkers)
601
602	wg.Add(numWorkers)
603	for i := range numWorkers {
604		go func(workerID int) {
605			defer wg.Done()
606			taken[workerID] = make([]int, 0)
607			for j := workerID; j < numItems; j += numWorkers {
608				if value, ok := m.Take(j); ok {
609					taken[workerID] = append(taken[workerID], value)
610				}
611			}
612		}(i)
613	}
614
615	wg.Wait()
616
617	require.Equal(t, 0, m.Len())
618
619	allTaken := make(map[int]bool)
620	for _, workerTaken := range taken {
621		for _, value := range workerTaken {
622			require.False(t, allTaken[value], "Value %d was taken multiple times", value)
623			allTaken[value] = true
624		}
625	}
626
627	require.Equal(t, numItems, len(allTaken))
628	for i := range numItems {
629		require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
630	}
631}
632
633func TestMap_TypeSafety(t *testing.T) {
634	t.Parallel()
635
636	stringIntMap := NewMap[string, int]()
637	stringIntMap.Set("key", 42)
638	value, ok := stringIntMap.Get("key")
639	require.True(t, ok)
640	require.Equal(t, 42, value)
641
642	intStringMap := NewMap[int, string]()
643	intStringMap.Set(42, "value")
644	strValue, ok := intStringMap.Get(42)
645	require.True(t, ok)
646	require.Equal(t, "value", strValue)
647
648	structMap := NewMap[string, struct{ Name string }]()
649	structMap.Set("key", struct{ Name string }{Name: "test"})
650	structValue, ok := structMap.Get("key")
651	require.True(t, ok)
652	require.Equal(t, "test", structValue.Name)
653}
654
655func TestMap_InterfaceCompliance(t *testing.T) {
656	t.Parallel()
657
658	var _ json.Marshaler = &Map[string, any]{}
659	var _ json.Unmarshaler = &Map[string, any]{}
660}
661
662func BenchmarkMap_Set(b *testing.B) {
663	m := NewMap[int, int]()
664
665	for i := 0; b.Loop(); i++ {
666		m.Set(i, i*2)
667	}
668}
669
670func BenchmarkMap_Get(b *testing.B) {
671	m := NewMap[int, int]()
672	for i := range 1000 {
673		m.Set(i, i*2)
674	}
675
676	for i := 0; b.Loop(); i++ {
677		m.Get(i % 1000)
678	}
679}
680
681func BenchmarkMap_Seq2(b *testing.B) {
682	m := NewMap[int, int]()
683	for i := range 1000 {
684		m.Set(i, i*2)
685	}
686
687	for b.Loop() {
688		for range m.Seq2() {
689		}
690	}
691}
692
693func BenchmarkMap_Seq(b *testing.B) {
694	m := NewMap[int, int]()
695	for i := range 1000 {
696		m.Set(i, i*2)
697	}
698
699	for b.Loop() {
700		for range m.Seq() {
701		}
702	}
703}
704
705func BenchmarkMap_Take(b *testing.B) {
706	m := NewMap[int, int]()
707	for i := range 1000 {
708		m.Set(i, i*2)
709	}
710
711	b.ResetTimer()
712	for i := 0; b.Loop(); i++ {
713		key := i % 1000
714		m.Take(key)
715		if i%1000 == 999 {
716			b.StopTimer()
717			for j := range 1000 {
718				m.Set(j, j*2)
719			}
720			b.StartTimer()
721		}
722	}
723}
724
725func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
726	m := NewMap[int, int]()
727	for i := range 1000 {
728		m.Set(i, i*2)
729	}
730
731	b.ResetTimer()
732	b.RunParallel(func(pb *testing.PB) {
733		i := 0
734		for pb.Next() {
735			if i%2 == 0 {
736				m.Get(i % 1000)
737			} else {
738				m.Set(i+1000, i*2)
739			}
740			i++
741		}
742	})
743}