maps_test.go

  1package csync
  2
  3import (
  4	"encoding/json"
  5	"maps"
  6	"sync"
  7	"sync/atomic"
  8	"testing"
  9	"testing/synctest"
 10	"time"
 11
 12	"github.com/stretchr/testify/require"
 13)
 14
 15func TestNewMap(t *testing.T) {
 16	t.Parallel()
 17
 18	m := NewMap[string, int]()
 19	require.NotNil(t, m)
 20	require.NotNil(t, m.inner)
 21	require.Equal(t, 0, m.Len())
 22}
 23
 24func TestNewMapFrom(t *testing.T) {
 25	t.Parallel()
 26
 27	original := map[string]int{
 28		"key1": 1,
 29		"key2": 2,
 30	}
 31
 32	m := NewMapFrom(original)
 33	require.NotNil(t, m)
 34	require.Equal(t, original, m.inner)
 35	require.Equal(t, 2, m.Len())
 36
 37	value, ok := m.Get("key1")
 38	require.True(t, ok)
 39	require.Equal(t, 1, value)
 40}
 41
 42func TestNewLazyMap(t *testing.T) {
 43	t.Parallel()
 44
 45	synctest.Test(t, func(t *testing.T) {
 46		t.Helper()
 47
 48		waiter := sync.Mutex{}
 49		waiter.Lock()
 50		var loadCalled atomic.Bool
 51
 52		loadFunc := func() map[string]int {
 53			waiter.Lock()
 54			defer waiter.Unlock()
 55			loadCalled.Store(true)
 56			return map[string]int{
 57				"key1": 1,
 58				"key2": 2,
 59			}
 60		}
 61
 62		m := NewLazyMap(loadFunc)
 63		require.NotNil(t, m)
 64
 65		waiter.Unlock() // Allow the load function to proceed
 66		time.Sleep(100 * time.Millisecond)
 67		require.True(t, loadCalled.Load())
 68		require.Equal(t, 2, m.Len())
 69
 70		value, ok := m.Get("key1")
 71		require.True(t, ok)
 72		require.Equal(t, 1, value)
 73	})
 74}
 75
 76func TestMap_Reset(t *testing.T) {
 77	t.Parallel()
 78
 79	m := NewMapFrom(map[string]int{
 80		"a": 10,
 81	})
 82
 83	m.Reset(map[string]int{
 84		"b": 20,
 85	})
 86	value, ok := m.Get("b")
 87	require.True(t, ok)
 88	require.Equal(t, 20, value)
 89	require.Equal(t, 1, m.Len())
 90}
 91
 92func TestMap_Set(t *testing.T) {
 93	t.Parallel()
 94
 95	m := NewMap[string, int]()
 96
 97	m.Set("key1", 42)
 98	value, ok := m.Get("key1")
 99	require.True(t, ok)
100	require.Equal(t, 42, value)
101	require.Equal(t, 1, m.Len())
102
103	m.Set("key1", 100)
104	value, ok = m.Get("key1")
105	require.True(t, ok)
106	require.Equal(t, 100, value)
107	require.Equal(t, 1, m.Len())
108}
109
110func TestMap_GetOrSet(t *testing.T) {
111	t.Parallel()
112
113	m := NewMap[string, int]()
114
115	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
116	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
117	require.Equal(t, 1, m.Len())
118}
119
120func TestMap_Get(t *testing.T) {
121	t.Parallel()
122
123	m := NewMap[string, int]()
124
125	value, ok := m.Get("nonexistent")
126	require.False(t, ok)
127	require.Equal(t, 0, value)
128
129	m.Set("key1", 42)
130	value, ok = m.Get("key1")
131	require.True(t, ok)
132	require.Equal(t, 42, value)
133}
134
135func TestMap_Del(t *testing.T) {
136	t.Parallel()
137
138	m := NewMap[string, int]()
139	m.Set("key1", 42)
140	m.Set("key2", 100)
141
142	require.Equal(t, 2, m.Len())
143
144	m.Del("key1")
145	_, ok := m.Get("key1")
146	require.False(t, ok)
147	require.Equal(t, 1, m.Len())
148
149	value, ok := m.Get("key2")
150	require.True(t, ok)
151	require.Equal(t, 100, value)
152
153	m.Del("nonexistent")
154	require.Equal(t, 1, m.Len())
155}
156
157func TestMap_Len(t *testing.T) {
158	t.Parallel()
159
160	m := NewMap[string, int]()
161	require.Equal(t, 0, m.Len())
162
163	m.Set("key1", 1)
164	require.Equal(t, 1, m.Len())
165
166	m.Set("key2", 2)
167	require.Equal(t, 2, m.Len())
168
169	m.Del("key1")
170	require.Equal(t, 1, m.Len())
171
172	m.Del("key2")
173	require.Equal(t, 0, m.Len())
174}
175
176func TestMap_Take(t *testing.T) {
177	t.Parallel()
178
179	m := NewMap[string, int]()
180	m.Set("key1", 42)
181	m.Set("key2", 100)
182
183	require.Equal(t, 2, m.Len())
184
185	value, ok := m.Take("key1")
186	require.True(t, ok)
187	require.Equal(t, 42, value)
188	require.Equal(t, 1, m.Len())
189
190	_, exists := m.Get("key1")
191	require.False(t, exists)
192
193	value, ok = m.Get("key2")
194	require.True(t, ok)
195	require.Equal(t, 100, value)
196}
197
198func TestMap_Take_NonexistentKey(t *testing.T) {
199	t.Parallel()
200
201	m := NewMap[string, int]()
202	m.Set("key1", 42)
203
204	value, ok := m.Take("nonexistent")
205	require.False(t, ok)
206	require.Equal(t, 0, value)
207	require.Equal(t, 1, m.Len())
208
209	value, ok = m.Get("key1")
210	require.True(t, ok)
211	require.Equal(t, 42, value)
212}
213
214func TestMap_Take_EmptyMap(t *testing.T) {
215	t.Parallel()
216
217	m := NewMap[string, int]()
218
219	value, ok := m.Take("key1")
220	require.False(t, ok)
221	require.Equal(t, 0, value)
222	require.Equal(t, 0, m.Len())
223}
224
225func TestMap_Take_SameKeyTwice(t *testing.T) {
226	t.Parallel()
227
228	m := NewMap[string, int]()
229	m.Set("key1", 42)
230
231	value, ok := m.Take("key1")
232	require.True(t, ok)
233	require.Equal(t, 42, value)
234	require.Equal(t, 0, m.Len())
235
236	value, ok = m.Take("key1")
237	require.False(t, ok)
238	require.Equal(t, 0, value)
239	require.Equal(t, 0, m.Len())
240}
241
242func TestMap_Seq2(t *testing.T) {
243	t.Parallel()
244
245	m := NewMap[string, int]()
246	m.Set("key1", 1)
247	m.Set("key2", 2)
248	m.Set("key3", 3)
249
250	collected := maps.Collect(m.Seq2())
251
252	require.Equal(t, 3, len(collected))
253	require.Equal(t, 1, collected["key1"])
254	require.Equal(t, 2, collected["key2"])
255	require.Equal(t, 3, collected["key3"])
256}
257
258func TestMap_Seq2_EarlyReturn(t *testing.T) {
259	t.Parallel()
260
261	m := NewMap[string, int]()
262	m.Set("key1", 1)
263	m.Set("key2", 2)
264	m.Set("key3", 3)
265
266	count := 0
267	for range m.Seq2() {
268		count++
269		if count == 2 {
270			break
271		}
272	}
273
274	require.Equal(t, 2, count)
275}
276
277func TestMap_Seq2_EmptyMap(t *testing.T) {
278	t.Parallel()
279
280	m := NewMap[string, int]()
281
282	count := 0
283	for range m.Seq2() {
284		count++
285	}
286
287	require.Equal(t, 0, count)
288}
289
290func TestMap_Seq(t *testing.T) {
291	t.Parallel()
292
293	m := NewMap[string, int]()
294	m.Set("key1", 1)
295	m.Set("key2", 2)
296	m.Set("key3", 3)
297
298	collected := make([]int, 0)
299	for v := range m.Seq() {
300		collected = append(collected, v)
301	}
302
303	require.Equal(t, 3, len(collected))
304	require.Contains(t, collected, 1)
305	require.Contains(t, collected, 2)
306	require.Contains(t, collected, 3)
307}
308
309func TestMap_Seq_EarlyReturn(t *testing.T) {
310	t.Parallel()
311
312	m := NewMap[string, int]()
313	m.Set("key1", 1)
314	m.Set("key2", 2)
315	m.Set("key3", 3)
316
317	count := 0
318	for range m.Seq() {
319		count++
320		if count == 2 {
321			break
322		}
323	}
324
325	require.Equal(t, 2, count)
326}
327
328func TestMap_Seq_EmptyMap(t *testing.T) {
329	t.Parallel()
330
331	m := NewMap[string, int]()
332
333	count := 0
334	for range m.Seq() {
335		count++
336	}
337
338	require.Equal(t, 0, count)
339}
340
341func TestMap_MarshalJSON(t *testing.T) {
342	t.Parallel()
343
344	m := NewMap[string, int]()
345	m.Set("key1", 1)
346	m.Set("key2", 2)
347
348	data, err := json.Marshal(m)
349	require.NoError(t, err)
350
351	result := &Map[string, int]{}
352	err = json.Unmarshal(data, result)
353	require.NoError(t, err)
354	require.Equal(t, 2, result.Len())
355	v1, _ := result.Get("key1")
356	v2, _ := result.Get("key2")
357	require.Equal(t, 1, v1)
358	require.Equal(t, 2, v2)
359}
360
361func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
362	t.Parallel()
363
364	m := NewMap[string, int]()
365
366	data, err := json.Marshal(m)
367	require.NoError(t, err)
368	require.Equal(t, "{}", string(data))
369}
370
371func TestMap_UnmarshalJSON(t *testing.T) {
372	t.Parallel()
373
374	jsonData := `{"key1": 1, "key2": 2}`
375
376	m := NewMap[string, int]()
377	err := json.Unmarshal([]byte(jsonData), m)
378	require.NoError(t, err)
379
380	require.Equal(t, 2, m.Len())
381	value, ok := m.Get("key1")
382	require.True(t, ok)
383	require.Equal(t, 1, value)
384
385	value, ok = m.Get("key2")
386	require.True(t, ok)
387	require.Equal(t, 2, value)
388}
389
390func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
391	t.Parallel()
392
393	jsonData := `{}`
394
395	m := NewMap[string, int]()
396	err := json.Unmarshal([]byte(jsonData), m)
397	require.NoError(t, err)
398	require.Equal(t, 0, m.Len())
399}
400
401func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
402	t.Parallel()
403
404	jsonData := `{"key1": 1, "key2":}`
405
406	m := NewMap[string, int]()
407	err := json.Unmarshal([]byte(jsonData), m)
408	require.Error(t, err)
409}
410
411func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
412	t.Parallel()
413
414	m := NewMap[string, int]()
415	m.Set("existing", 999)
416
417	jsonData := `{"key1": 1, "key2": 2}`
418	err := json.Unmarshal([]byte(jsonData), m)
419	require.NoError(t, err)
420
421	require.Equal(t, 2, m.Len())
422	_, ok := m.Get("existing")
423	require.False(t, ok)
424
425	value, ok := m.Get("key1")
426	require.True(t, ok)
427	require.Equal(t, 1, value)
428}
429
430func TestMap_JSONRoundTrip(t *testing.T) {
431	t.Parallel()
432
433	original := NewMap[string, int]()
434	original.Set("key1", 1)
435	original.Set("key2", 2)
436	original.Set("key3", 3)
437
438	data, err := json.Marshal(original)
439	require.NoError(t, err)
440
441	restored := NewMap[string, int]()
442	err = json.Unmarshal(data, restored)
443	require.NoError(t, err)
444
445	require.Equal(t, original.Len(), restored.Len())
446
447	for k, v := range original.Seq2() {
448		restoredValue, ok := restored.Get(k)
449		require.True(t, ok)
450		require.Equal(t, v, restoredValue)
451	}
452}
453
454func TestMap_ConcurrentAccess(t *testing.T) {
455	t.Parallel()
456
457	m := NewMap[int, int]()
458	const numGoroutines = 100
459	const numOperations = 100
460
461	var wg sync.WaitGroup
462	wg.Add(numGoroutines)
463
464	for i := range numGoroutines {
465		go func(id int) {
466			defer wg.Done()
467			for j := range numOperations {
468				key := id*numOperations + j
469				m.Set(key, key*2)
470				value, ok := m.Get(key)
471				require.True(t, ok)
472				require.Equal(t, key*2, value)
473			}
474		}(i)
475	}
476
477	wg.Wait()
478
479	require.Equal(t, numGoroutines*numOperations, m.Len())
480}
481
482func TestMap_ConcurrentReadWrite(t *testing.T) {
483	t.Parallel()
484
485	m := NewMap[int, int]()
486	const numReaders = 50
487	const numWriters = 50
488	const numOperations = 100
489
490	for i := range 1000 {
491		m.Set(i, i)
492	}
493
494	var wg sync.WaitGroup
495	wg.Add(numReaders + numWriters)
496
497	for range numReaders {
498		go func() {
499			defer wg.Done()
500			for j := range numOperations {
501				key := j % 1000
502				value, ok := m.Get(key)
503				if ok {
504					require.Equal(t, key, value)
505				}
506				_ = m.Len()
507			}
508		}()
509	}
510
511	for i := range numWriters {
512		go func(id int) {
513			defer wg.Done()
514			for j := range numOperations {
515				key := 1000 + id*numOperations + j
516				m.Set(key, key)
517				if j%10 == 0 {
518					m.Del(key)
519				}
520			}
521		}(i)
522	}
523
524	wg.Wait()
525}
526
527func TestMap_ConcurrentSeq2(t *testing.T) {
528	t.Parallel()
529
530	m := NewMap[int, int]()
531	for i := range 100 {
532		m.Set(i, i*2)
533	}
534
535	var wg sync.WaitGroup
536	const numIterators = 10
537
538	wg.Add(numIterators)
539	for range numIterators {
540		go func() {
541			defer wg.Done()
542			count := 0
543			for k, v := range m.Seq2() {
544				require.Equal(t, k*2, v)
545				count++
546			}
547			require.Equal(t, 100, count)
548		}()
549	}
550
551	wg.Wait()
552}
553
554func TestMap_ConcurrentSeq(t *testing.T) {
555	t.Parallel()
556
557	m := NewMap[int, int]()
558	for i := range 100 {
559		m.Set(i, i*2)
560	}
561
562	var wg sync.WaitGroup
563	const numIterators = 10
564
565	wg.Add(numIterators)
566	for range numIterators {
567		go func() {
568			defer wg.Done()
569			count := 0
570			values := make(map[int]bool)
571			for v := range m.Seq() {
572				values[v] = true
573				count++
574			}
575			require.Equal(t, 100, count)
576			for i := range 100 {
577				require.True(t, values[i*2])
578			}
579		}()
580	}
581
582	wg.Wait()
583}
584
585func TestMap_ConcurrentTake(t *testing.T) {
586	t.Parallel()
587
588	m := NewMap[int, int]()
589	const numItems = 1000
590
591	for i := range numItems {
592		m.Set(i, i*2)
593	}
594
595	var wg sync.WaitGroup
596	const numWorkers = 10
597	taken := make([][]int, numWorkers)
598
599	wg.Add(numWorkers)
600	for i := range numWorkers {
601		go func(workerID int) {
602			defer wg.Done()
603			taken[workerID] = make([]int, 0)
604			for j := workerID; j < numItems; j += numWorkers {
605				if value, ok := m.Take(j); ok {
606					taken[workerID] = append(taken[workerID], value)
607				}
608			}
609		}(i)
610	}
611
612	wg.Wait()
613
614	require.Equal(t, 0, m.Len())
615
616	allTaken := make(map[int]bool)
617	for _, workerTaken := range taken {
618		for _, value := range workerTaken {
619			require.False(t, allTaken[value], "Value %d was taken multiple times", value)
620			allTaken[value] = true
621		}
622	}
623
624	require.Equal(t, numItems, len(allTaken))
625	for i := range numItems {
626		require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
627	}
628}
629
630func TestMap_TypeSafety(t *testing.T) {
631	t.Parallel()
632
633	stringIntMap := NewMap[string, int]()
634	stringIntMap.Set("key", 42)
635	value, ok := stringIntMap.Get("key")
636	require.True(t, ok)
637	require.Equal(t, 42, value)
638
639	intStringMap := NewMap[int, string]()
640	intStringMap.Set(42, "value")
641	strValue, ok := intStringMap.Get(42)
642	require.True(t, ok)
643	require.Equal(t, "value", strValue)
644
645	structMap := NewMap[string, struct{ Name string }]()
646	structMap.Set("key", struct{ Name string }{Name: "test"})
647	structValue, ok := structMap.Get("key")
648	require.True(t, ok)
649	require.Equal(t, "test", structValue.Name)
650}
651
652func TestMap_InterfaceCompliance(t *testing.T) {
653	t.Parallel()
654
655	var _ json.Marshaler = &Map[string, any]{}
656	var _ json.Unmarshaler = &Map[string, any]{}
657}
658
659func BenchmarkMap_Set(b *testing.B) {
660	m := NewMap[int, int]()
661
662	for i := 0; b.Loop(); i++ {
663		m.Set(i, i*2)
664	}
665}
666
667func BenchmarkMap_Get(b *testing.B) {
668	m := NewMap[int, int]()
669	for i := range 1000 {
670		m.Set(i, i*2)
671	}
672
673	for i := 0; b.Loop(); i++ {
674		m.Get(i % 1000)
675	}
676}
677
678func BenchmarkMap_Seq2(b *testing.B) {
679	m := NewMap[int, int]()
680	for i := range 1000 {
681		m.Set(i, i*2)
682	}
683
684	for b.Loop() {
685		for range m.Seq2() {
686		}
687	}
688}
689
690func BenchmarkMap_Seq(b *testing.B) {
691	m := NewMap[int, int]()
692	for i := range 1000 {
693		m.Set(i, i*2)
694	}
695
696	for b.Loop() {
697		for range m.Seq() {
698		}
699	}
700}
701
702func BenchmarkMap_Take(b *testing.B) {
703	m := NewMap[int, int]()
704	for i := range 1000 {
705		m.Set(i, i*2)
706	}
707
708	b.ResetTimer()
709	for i := 0; b.Loop(); i++ {
710		key := i % 1000
711		m.Take(key)
712		if i%1000 == 999 {
713			b.StopTimer()
714			for j := range 1000 {
715				m.Set(j, j*2)
716			}
717			b.StartTimer()
718		}
719	}
720}
721
722func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
723	m := NewMap[int, int]()
724	for i := range 1000 {
725		m.Set(i, i*2)
726	}
727
728	b.ResetTimer()
729	b.RunParallel(func(pb *testing.PB) {
730		i := 0
731		for pb.Next() {
732			if i%2 == 0 {
733				m.Get(i % 1000)
734			} else {
735				m.Set(i+1000, i*2)
736			}
737			i++
738		}
739	})
740}