context_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"testing"
  6)
  7
  8// Test-specific context key types to avoid collisions
  9type (
 10	testStringKey string
 11	testBoolKey   string
 12	testIntKey    string
 13)
 14
 15const (
 16	testKey     testStringKey = "testKey"
 17	missingKey  testStringKey = "missingKey"
 18	boolTestKey testBoolKey   = "boolKey"
 19	intTestKey  testIntKey    = "intKey"
 20)
 21
 22func TestGetContextValue(t *testing.T) {
 23	tests := []struct {
 24		name         string
 25		setup        func(ctx context.Context) context.Context
 26		key          any
 27		defaultValue any
 28		want         any
 29	}{
 30		{
 31			name: "returns string value",
 32			setup: func(ctx context.Context) context.Context {
 33				return context.WithValue(ctx, testKey, "testValue")
 34			},
 35			key:          testKey,
 36			defaultValue: "",
 37			want:         "testValue",
 38		},
 39		{
 40			name: "returns default when key not found",
 41			setup: func(ctx context.Context) context.Context {
 42				return ctx
 43			},
 44			key:          missingKey,
 45			defaultValue: "default",
 46			want:         "default",
 47		},
 48		{
 49			name: "returns default when type mismatch",
 50			setup: func(ctx context.Context) context.Context {
 51				return context.WithValue(ctx, testKey, 123) // int, not string
 52			},
 53			key:          testKey,
 54			defaultValue: "default",
 55			want:         "default",
 56		},
 57		{
 58			name: "returns bool value",
 59			setup: func(ctx context.Context) context.Context {
 60				return context.WithValue(ctx, boolTestKey, true)
 61			},
 62			key:          boolTestKey,
 63			defaultValue: false,
 64			want:         true,
 65		},
 66		{
 67			name: "returns int value",
 68			setup: func(ctx context.Context) context.Context {
 69				return context.WithValue(ctx, intTestKey, 42)
 70			},
 71			key:          intTestKey,
 72			defaultValue: 0,
 73			want:         42,
 74		},
 75	}
 76
 77	for _, tt := range tests {
 78		t.Run(tt.name, func(t *testing.T) {
 79			ctx := tt.setup(context.Background())
 80
 81			var got any
 82			switch tt.defaultValue.(type) {
 83			case string:
 84				got = getContextValue(ctx, tt.key, tt.defaultValue.(string))
 85			case bool:
 86				got = getContextValue(ctx, tt.key, tt.defaultValue.(bool))
 87			case int:
 88				got = getContextValue(ctx, tt.key, tt.defaultValue.(int))
 89			}
 90
 91			if got != tt.want {
 92				t.Errorf("getContextValue() = %v, want %v", got, tt.want)
 93			}
 94		})
 95	}
 96}
 97
 98func TestGetSessionFromContext(t *testing.T) {
 99	tests := []struct {
100		name string
101		ctx  context.Context
102		want string
103	}{
104		{
105			name: "returns session ID when present",
106			ctx:  context.WithValue(context.Background(), SessionIDContextKey, "session-123"),
107			want: "session-123",
108		},
109		{
110			name: "returns empty string when not present",
111			ctx:  context.Background(),
112			want: "",
113		},
114		{
115			name: "returns empty string when wrong type",
116			ctx:  context.WithValue(context.Background(), SessionIDContextKey, 123),
117			want: "",
118		},
119	}
120
121	for _, tt := range tests {
122		t.Run(tt.name, func(t *testing.T) {
123			got := GetSessionFromContext(tt.ctx)
124			if got != tt.want {
125				t.Errorf("GetSessionFromContext() = %v, want %v", got, tt.want)
126			}
127		})
128	}
129}
130
131func TestGetMessageFromContext(t *testing.T) {
132	tests := []struct {
133		name string
134		ctx  context.Context
135		want string
136	}{
137		{
138			name: "returns message ID when present",
139			ctx:  context.WithValue(context.Background(), MessageIDContextKey, "msg-456"),
140			want: "msg-456",
141		},
142		{
143			name: "returns empty string when not present",
144			ctx:  context.Background(),
145			want: "",
146		},
147		{
148			name: "returns empty string when wrong type",
149			ctx:  context.WithValue(context.Background(), MessageIDContextKey, 456),
150			want: "",
151		},
152	}
153
154	for _, tt := range tests {
155		t.Run(tt.name, func(t *testing.T) {
156			got := GetMessageFromContext(tt.ctx)
157			if got != tt.want {
158				t.Errorf("GetMessageFromContext() = %v, want %v", got, tt.want)
159			}
160		})
161	}
162}
163
164func TestGetSupportsImagesFromContext(t *testing.T) {
165	tests := []struct {
166		name string
167		ctx  context.Context
168		want bool
169	}{
170		{
171			name: "returns true when present and true",
172			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, true),
173			want: true,
174		},
175		{
176			name: "returns false when present and false",
177			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, false),
178			want: false,
179		},
180		{
181			name: "returns false when not present",
182			ctx:  context.Background(),
183			want: false,
184		},
185		{
186			name: "returns false when wrong type",
187			ctx:  context.WithValue(context.Background(), SupportsImagesContextKey, "true"),
188			want: false,
189		},
190	}
191
192	for _, tt := range tests {
193		t.Run(tt.name, func(t *testing.T) {
194			got := GetSupportsImagesFromContext(tt.ctx)
195			if got != tt.want {
196				t.Errorf("GetSupportsImagesFromContext() = %v, want %v", got, tt.want)
197			}
198		})
199	}
200}
201
202func TestGetModelNameFromContext(t *testing.T) {
203	tests := []struct {
204		name string
205		ctx  context.Context
206		want string
207	}{
208		{
209			name: "returns model name when present",
210			ctx:  context.WithValue(context.Background(), ModelNameContextKey, "claude-opus-4"),
211			want: "claude-opus-4",
212		},
213		{
214			name: "returns empty string when not present",
215			ctx:  context.Background(),
216			want: "",
217		},
218		{
219			name: "returns empty string when wrong type",
220			ctx:  context.WithValue(context.Background(), ModelNameContextKey, 789),
221			want: "",
222		},
223	}
224
225	for _, tt := range tests {
226		t.Run(tt.name, func(t *testing.T) {
227			got := GetModelNameFromContext(tt.ctx)
228			if got != tt.want {
229				t.Errorf("GetModelNameFromContext() = %v, want %v", got, tt.want)
230			}
231		})
232	}
233}