maps_test.go

  1package csync
  2
  3import (
  4	"encoding/json"
  5	"maps"
  6	"sync"
  7	"testing"
  8	"testing/synctest"
  9	"time"
 10
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestNewMap(t *testing.T) {
 15	t.Parallel()
 16
 17	m := NewMap[string, int]()
 18	require.NotNil(t, m)
 19	require.NotNil(t, m.inner)
 20	require.Equal(t, 0, m.Len())
 21}
 22
 23func TestNewMapFrom(t *testing.T) {
 24	t.Parallel()
 25
 26	original := map[string]int{
 27		"key1": 1,
 28		"key2": 2,
 29	}
 30
 31	m := NewMapFrom(original)
 32	require.NotNil(t, m)
 33	require.Equal(t, original, m.inner)
 34	require.Equal(t, 2, m.Len())
 35
 36	value, ok := m.Get("key1")
 37	require.True(t, ok)
 38	require.Equal(t, 1, value)
 39}
 40
 41func TestNewLazyMap(t *testing.T) {
 42	t.Parallel()
 43
 44	synctest.Test(t, func(t *testing.T) {
 45		t.Helper()
 46
 47		waiter := sync.Mutex{}
 48		waiter.Lock()
 49		loadCalled := false
 50
 51		loadFunc := func() map[string]int {
 52			waiter.Lock()
 53			defer waiter.Unlock()
 54			loadCalled = true
 55			return map[string]int{
 56				"key1": 1,
 57				"key2": 2,
 58			}
 59		}
 60
 61		m := NewLazyMap(loadFunc)
 62		require.NotNil(t, m)
 63
 64		waiter.Unlock() // Allow the load function to proceed
 65		time.Sleep(100 * time.Millisecond)
 66		require.True(t, loadCalled)
 67		require.Equal(t, 2, m.Len())
 68
 69		value, ok := m.Get("key1")
 70		require.True(t, ok)
 71		require.Equal(t, 1, value)
 72	})
 73}
 74
 75func TestMap_Set(t *testing.T) {
 76	t.Parallel()
 77
 78	m := NewMap[string, int]()
 79
 80	m.Set("key1", 42)
 81	value, ok := m.Get("key1")
 82	require.True(t, ok)
 83	require.Equal(t, 42, value)
 84	require.Equal(t, 1, m.Len())
 85
 86	m.Set("key1", 100)
 87	value, ok = m.Get("key1")
 88	require.True(t, ok)
 89	require.Equal(t, 100, value)
 90	require.Equal(t, 1, m.Len())
 91}
 92
 93func TestMap_GetOrSet(t *testing.T) {
 94	t.Parallel()
 95
 96	m := NewMap[string, int]()
 97
 98	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
 99	require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
100	require.Equal(t, 1, m.Len())
101}
102
103func TestMap_Get(t *testing.T) {
104	t.Parallel()
105
106	m := NewMap[string, int]()
107
108	value, ok := m.Get("nonexistent")
109	require.False(t, ok)
110	require.Equal(t, 0, value)
111
112	m.Set("key1", 42)
113	value, ok = m.Get("key1")
114	require.True(t, ok)
115	require.Equal(t, 42, value)
116}
117
118func TestMap_Del(t *testing.T) {
119	t.Parallel()
120
121	m := NewMap[string, int]()
122	m.Set("key1", 42)
123	m.Set("key2", 100)
124
125	require.Equal(t, 2, m.Len())
126
127	m.Del("key1")
128	_, ok := m.Get("key1")
129	require.False(t, ok)
130	require.Equal(t, 1, m.Len())
131
132	value, ok := m.Get("key2")
133	require.True(t, ok)
134	require.Equal(t, 100, value)
135
136	m.Del("nonexistent")
137	require.Equal(t, 1, m.Len())
138}
139
140func TestMap_Len(t *testing.T) {
141	t.Parallel()
142
143	m := NewMap[string, int]()
144	require.Equal(t, 0, m.Len())
145
146	m.Set("key1", 1)
147	require.Equal(t, 1, m.Len())
148
149	m.Set("key2", 2)
150	require.Equal(t, 2, m.Len())
151
152	m.Del("key1")
153	require.Equal(t, 1, m.Len())
154
155	m.Del("key2")
156	require.Equal(t, 0, m.Len())
157}
158
159func TestMap_Take(t *testing.T) {
160	t.Parallel()
161
162	m := NewMap[string, int]()
163	m.Set("key1", 42)
164	m.Set("key2", 100)
165
166	require.Equal(t, 2, m.Len())
167
168	value, ok := m.Take("key1")
169	require.True(t, ok)
170	require.Equal(t, 42, value)
171	require.Equal(t, 1, m.Len())
172
173	_, exists := m.Get("key1")
174	require.False(t, exists)
175
176	value, ok = m.Get("key2")
177	require.True(t, ok)
178	require.Equal(t, 100, value)
179}
180
181func TestMap_Take_NonexistentKey(t *testing.T) {
182	t.Parallel()
183
184	m := NewMap[string, int]()
185	m.Set("key1", 42)
186
187	value, ok := m.Take("nonexistent")
188	require.False(t, ok)
189	require.Equal(t, 0, value)
190	require.Equal(t, 1, m.Len())
191
192	value, ok = m.Get("key1")
193	require.True(t, ok)
194	require.Equal(t, 42, value)
195}
196
197func TestMap_Take_EmptyMap(t *testing.T) {
198	t.Parallel()
199
200	m := NewMap[string, int]()
201
202	value, ok := m.Take("key1")
203	require.False(t, ok)
204	require.Equal(t, 0, value)
205	require.Equal(t, 0, m.Len())
206}
207
208func TestMap_Take_SameKeyTwice(t *testing.T) {
209	t.Parallel()
210
211	m := NewMap[string, int]()
212	m.Set("key1", 42)
213
214	value, ok := m.Take("key1")
215	require.True(t, ok)
216	require.Equal(t, 42, value)
217	require.Equal(t, 0, m.Len())
218
219	value, ok = m.Take("key1")
220	require.False(t, ok)
221	require.Equal(t, 0, value)
222	require.Equal(t, 0, m.Len())
223}
224
225func TestMap_Seq2(t *testing.T) {
226	t.Parallel()
227
228	m := NewMap[string, int]()
229	m.Set("key1", 1)
230	m.Set("key2", 2)
231	m.Set("key3", 3)
232
233	collected := maps.Collect(m.Seq2())
234
235	require.Equal(t, 3, len(collected))
236	require.Equal(t, 1, collected["key1"])
237	require.Equal(t, 2, collected["key2"])
238	require.Equal(t, 3, collected["key3"])
239}
240
241func TestMap_Seq2_EarlyReturn(t *testing.T) {
242	t.Parallel()
243
244	m := NewMap[string, int]()
245	m.Set("key1", 1)
246	m.Set("key2", 2)
247	m.Set("key3", 3)
248
249	count := 0
250	for range m.Seq2() {
251		count++
252		if count == 2 {
253			break
254		}
255	}
256
257	require.Equal(t, 2, count)
258}
259
260func TestMap_Seq2_EmptyMap(t *testing.T) {
261	t.Parallel()
262
263	m := NewMap[string, int]()
264
265	count := 0
266	for range m.Seq2() {
267		count++
268	}
269
270	require.Equal(t, 0, count)
271}
272
273func TestMap_Seq(t *testing.T) {
274	t.Parallel()
275
276	m := NewMap[string, int]()
277	m.Set("key1", 1)
278	m.Set("key2", 2)
279	m.Set("key3", 3)
280
281	collected := make([]int, 0)
282	for v := range m.Seq() {
283		collected = append(collected, v)
284	}
285
286	require.Equal(t, 3, len(collected))
287	require.Contains(t, collected, 1)
288	require.Contains(t, collected, 2)
289	require.Contains(t, collected, 3)
290}
291
292func TestMap_Seq_EarlyReturn(t *testing.T) {
293	t.Parallel()
294
295	m := NewMap[string, int]()
296	m.Set("key1", 1)
297	m.Set("key2", 2)
298	m.Set("key3", 3)
299
300	count := 0
301	for range m.Seq() {
302		count++
303		if count == 2 {
304			break
305		}
306	}
307
308	require.Equal(t, 2, count)
309}
310
311func TestMap_Seq_EmptyMap(t *testing.T) {
312	t.Parallel()
313
314	m := NewMap[string, int]()
315
316	count := 0
317	for range m.Seq() {
318		count++
319	}
320
321	require.Equal(t, 0, count)
322}
323
324func TestMap_MarshalJSON(t *testing.T) {
325	t.Parallel()
326
327	m := NewMap[string, int]()
328	m.Set("key1", 1)
329	m.Set("key2", 2)
330
331	data, err := json.Marshal(m)
332	require.NoError(t, err)
333
334	result := &Map[string, int]{}
335	err = json.Unmarshal(data, result)
336	require.NoError(t, err)
337	require.Equal(t, 2, result.Len())
338	v1, _ := result.Get("key1")
339	v2, _ := result.Get("key2")
340	require.Equal(t, 1, v1)
341	require.Equal(t, 2, v2)
342}
343
344func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
345	t.Parallel()
346
347	m := NewMap[string, int]()
348
349	data, err := json.Marshal(m)
350	require.NoError(t, err)
351	require.Equal(t, "{}", string(data))
352}
353
354func TestMap_UnmarshalJSON(t *testing.T) {
355	t.Parallel()
356
357	jsonData := `{"key1": 1, "key2": 2}`
358
359	m := NewMap[string, int]()
360	err := json.Unmarshal([]byte(jsonData), m)
361	require.NoError(t, err)
362
363	require.Equal(t, 2, m.Len())
364	value, ok := m.Get("key1")
365	require.True(t, ok)
366	require.Equal(t, 1, value)
367
368	value, ok = m.Get("key2")
369	require.True(t, ok)
370	require.Equal(t, 2, value)
371}
372
373func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
374	t.Parallel()
375
376	jsonData := `{}`
377
378	m := NewMap[string, int]()
379	err := json.Unmarshal([]byte(jsonData), m)
380	require.NoError(t, err)
381	require.Equal(t, 0, m.Len())
382}
383
384func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
385	t.Parallel()
386
387	jsonData := `{"key1": 1, "key2":}`
388
389	m := NewMap[string, int]()
390	err := json.Unmarshal([]byte(jsonData), m)
391	require.Error(t, err)
392}
393
394func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
395	t.Parallel()
396
397	m := NewMap[string, int]()
398	m.Set("existing", 999)
399
400	jsonData := `{"key1": 1, "key2": 2}`
401	err := json.Unmarshal([]byte(jsonData), m)
402	require.NoError(t, err)
403
404	require.Equal(t, 2, m.Len())
405	_, ok := m.Get("existing")
406	require.False(t, ok)
407
408	value, ok := m.Get("key1")
409	require.True(t, ok)
410	require.Equal(t, 1, value)
411}
412
413func TestMap_JSONRoundTrip(t *testing.T) {
414	t.Parallel()
415
416	original := NewMap[string, int]()
417	original.Set("key1", 1)
418	original.Set("key2", 2)
419	original.Set("key3", 3)
420
421	data, err := json.Marshal(original)
422	require.NoError(t, err)
423
424	restored := NewMap[string, int]()
425	err = json.Unmarshal(data, restored)
426	require.NoError(t, err)
427
428	require.Equal(t, original.Len(), restored.Len())
429
430	for k, v := range original.Seq2() {
431		restoredValue, ok := restored.Get(k)
432		require.True(t, ok)
433		require.Equal(t, v, restoredValue)
434	}
435}
436
437func TestMap_ConcurrentAccess(t *testing.T) {
438	t.Parallel()
439
440	m := NewMap[int, int]()
441	const numGoroutines = 100
442	const numOperations = 100
443
444	var wg sync.WaitGroup
445	wg.Add(numGoroutines)
446
447	for i := range numGoroutines {
448		go func(id int) {
449			defer wg.Done()
450			for j := range numOperations {
451				key := id*numOperations + j
452				m.Set(key, key*2)
453				value, ok := m.Get(key)
454				require.True(t, ok)
455				require.Equal(t, key*2, value)
456			}
457		}(i)
458	}
459
460	wg.Wait()
461
462	require.Equal(t, numGoroutines*numOperations, m.Len())
463}
464
465func TestMap_ConcurrentReadWrite(t *testing.T) {
466	t.Parallel()
467
468	m := NewMap[int, int]()
469	const numReaders = 50
470	const numWriters = 50
471	const numOperations = 100
472
473	for i := range 1000 {
474		m.Set(i, i)
475	}
476
477	var wg sync.WaitGroup
478	wg.Add(numReaders + numWriters)
479
480	for range numReaders {
481		go func() {
482			defer wg.Done()
483			for j := range numOperations {
484				key := j % 1000
485				value, ok := m.Get(key)
486				if ok {
487					require.Equal(t, key, value)
488				}
489				_ = m.Len()
490			}
491		}()
492	}
493
494	for i := range numWriters {
495		go func(id int) {
496			defer wg.Done()
497			for j := range numOperations {
498				key := 1000 + id*numOperations + j
499				m.Set(key, key)
500				if j%10 == 0 {
501					m.Del(key)
502				}
503			}
504		}(i)
505	}
506
507	wg.Wait()
508}
509
510func TestMap_ConcurrentSeq2(t *testing.T) {
511	t.Parallel()
512
513	m := NewMap[int, int]()
514	for i := range 100 {
515		m.Set(i, i*2)
516	}
517
518	var wg sync.WaitGroup
519	const numIterators = 10
520
521	wg.Add(numIterators)
522	for range numIterators {
523		go func() {
524			defer wg.Done()
525			count := 0
526			for k, v := range m.Seq2() {
527				require.Equal(t, k*2, v)
528				count++
529			}
530			require.Equal(t, 100, count)
531		}()
532	}
533
534	wg.Wait()
535}
536
537func TestMap_ConcurrentSeq(t *testing.T) {
538	t.Parallel()
539
540	m := NewMap[int, int]()
541	for i := range 100 {
542		m.Set(i, i*2)
543	}
544
545	var wg sync.WaitGroup
546	const numIterators = 10
547
548	wg.Add(numIterators)
549	for range numIterators {
550		go func() {
551			defer wg.Done()
552			count := 0
553			values := make(map[int]bool)
554			for v := range m.Seq() {
555				values[v] = true
556				count++
557			}
558			require.Equal(t, 100, count)
559			for i := range 100 {
560				require.True(t, values[i*2])
561			}
562		}()
563	}
564
565	wg.Wait()
566}
567
568func TestMap_ConcurrentTake(t *testing.T) {
569	t.Parallel()
570
571	m := NewMap[int, int]()
572	const numItems = 1000
573
574	for i := range numItems {
575		m.Set(i, i*2)
576	}
577
578	var wg sync.WaitGroup
579	const numWorkers = 10
580	taken := make([][]int, numWorkers)
581
582	wg.Add(numWorkers)
583	for i := range numWorkers {
584		go func(workerID int) {
585			defer wg.Done()
586			taken[workerID] = make([]int, 0)
587			for j := workerID; j < numItems; j += numWorkers {
588				if value, ok := m.Take(j); ok {
589					taken[workerID] = append(taken[workerID], value)
590				}
591			}
592		}(i)
593	}
594
595	wg.Wait()
596
597	require.Equal(t, 0, m.Len())
598
599	allTaken := make(map[int]bool)
600	for _, workerTaken := range taken {
601		for _, value := range workerTaken {
602			require.False(t, allTaken[value], "Value %d was taken multiple times", value)
603			allTaken[value] = true
604		}
605	}
606
607	require.Equal(t, numItems, len(allTaken))
608	for i := range numItems {
609		require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
610	}
611}
612
613func TestMap_TypeSafety(t *testing.T) {
614	t.Parallel()
615
616	stringIntMap := NewMap[string, int]()
617	stringIntMap.Set("key", 42)
618	value, ok := stringIntMap.Get("key")
619	require.True(t, ok)
620	require.Equal(t, 42, value)
621
622	intStringMap := NewMap[int, string]()
623	intStringMap.Set(42, "value")
624	strValue, ok := intStringMap.Get(42)
625	require.True(t, ok)
626	require.Equal(t, "value", strValue)
627
628	structMap := NewMap[string, struct{ Name string }]()
629	structMap.Set("key", struct{ Name string }{Name: "test"})
630	structValue, ok := structMap.Get("key")
631	require.True(t, ok)
632	require.Equal(t, "test", structValue.Name)
633}
634
635func TestMap_InterfaceCompliance(t *testing.T) {
636	t.Parallel()
637
638	var _ json.Marshaler = &Map[string, any]{}
639	var _ json.Unmarshaler = &Map[string, any]{}
640}
641
642func BenchmarkMap_Set(b *testing.B) {
643	m := NewMap[int, int]()
644
645	for i := 0; b.Loop(); i++ {
646		m.Set(i, i*2)
647	}
648}
649
650func BenchmarkMap_Get(b *testing.B) {
651	m := NewMap[int, int]()
652	for i := range 1000 {
653		m.Set(i, i*2)
654	}
655
656	for i := 0; b.Loop(); i++ {
657		m.Get(i % 1000)
658	}
659}
660
661func BenchmarkMap_Seq2(b *testing.B) {
662	m := NewMap[int, int]()
663	for i := range 1000 {
664		m.Set(i, i*2)
665	}
666
667	for b.Loop() {
668		for range m.Seq2() {
669		}
670	}
671}
672
673func BenchmarkMap_Seq(b *testing.B) {
674	m := NewMap[int, int]()
675	for i := range 1000 {
676		m.Set(i, i*2)
677	}
678
679	for b.Loop() {
680		for range m.Seq() {
681		}
682	}
683}
684
685func BenchmarkMap_Take(b *testing.B) {
686	m := NewMap[int, int]()
687	for i := range 1000 {
688		m.Set(i, i*2)
689	}
690
691	b.ResetTimer()
692	for i := 0; b.Loop(); i++ {
693		key := i % 1000
694		m.Take(key)
695		if i%1000 == 999 {
696			b.StopTimer()
697			for j := range 1000 {
698				m.Set(j, j*2)
699			}
700			b.StartTimer()
701		}
702	}
703}
704
705func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
706	m := NewMap[int, int]()
707	for i := range 1000 {
708		m.Set(i, i*2)
709	}
710
711	b.ResetTimer()
712	b.RunParallel(func(pb *testing.PB) {
713		i := 0
714		for pb.Next() {
715			if i%2 == 0 {
716				m.Get(i % 1000)
717			} else {
718				m.Set(i+1000, i*2)
719			}
720			i++
721		}
722	})
723}