1package tools
2
3import (
4 "context"
5 "testing"
6)
7
// Key types used only by the tests in this file. Each is a distinct named
// type, so a key of one type never compares equal to a key of another type
// (or to a plain string), even when the underlying text is the same.

// testStringKey is the key type for testKey and missingKey.
type testStringKey string

// testBoolKey is the key type for boolTestKey.
type testBoolKey string

// testIntKey is the key type for intTestKey.
type testIntKey string
14
15const (
16 testKey testStringKey = "testKey"
17 missingKey testStringKey = "missingKey"
18 boolTestKey testBoolKey = "boolKey"
19 intTestKey testIntKey = "intKey"
20)
21
22func TestGetContextValue(t *testing.T) {
23 tests := []struct {
24 name string
25 setup func(ctx context.Context) context.Context
26 key any
27 defaultValue any
28 want any
29 }{
30 {
31 name: "returns string value",
32 setup: func(ctx context.Context) context.Context {
33 return context.WithValue(ctx, testKey, "testValue")
34 },
35 key: testKey,
36 defaultValue: "",
37 want: "testValue",
38 },
39 {
40 name: "returns default when key not found",
41 setup: func(ctx context.Context) context.Context {
42 return ctx
43 },
44 key: missingKey,
45 defaultValue: "default",
46 want: "default",
47 },
48 {
49 name: "returns default when type mismatch",
50 setup: func(ctx context.Context) context.Context {
51 return context.WithValue(ctx, testKey, 123) // int, not string
52 },
53 key: testKey,
54 defaultValue: "default",
55 want: "default",
56 },
57 {
58 name: "returns bool value",
59 setup: func(ctx context.Context) context.Context {
60 return context.WithValue(ctx, boolTestKey, true)
61 },
62 key: boolTestKey,
63 defaultValue: false,
64 want: true,
65 },
66 {
67 name: "returns int value",
68 setup: func(ctx context.Context) context.Context {
69 return context.WithValue(ctx, intTestKey, 42)
70 },
71 key: intTestKey,
72 defaultValue: 0,
73 want: 42,
74 },
75 }
76
77 for _, tt := range tests {
78 t.Run(tt.name, func(t *testing.T) {
79 ctx := tt.setup(context.Background())
80
81 var got any
82 switch tt.defaultValue.(type) {
83 case string:
84 got = getContextValue(ctx, tt.key, tt.defaultValue.(string))
85 case bool:
86 got = getContextValue(ctx, tt.key, tt.defaultValue.(bool))
87 case int:
88 got = getContextValue(ctx, tt.key, tt.defaultValue.(int))
89 }
90
91 if got != tt.want {
92 t.Errorf("getContextValue() = %v, want %v", got, tt.want)
93 }
94 })
95 }
96}
97
98func TestGetSessionFromContext(t *testing.T) {
99 tests := []struct {
100 name string
101 ctx context.Context
102 want string
103 }{
104 {
105 name: "returns session ID when present",
106 ctx: context.WithValue(context.Background(), SessionIDContextKey, "session-123"),
107 want: "session-123",
108 },
109 {
110 name: "returns empty string when not present",
111 ctx: context.Background(),
112 want: "",
113 },
114 {
115 name: "returns empty string when wrong type",
116 ctx: context.WithValue(context.Background(), SessionIDContextKey, 123),
117 want: "",
118 },
119 }
120
121 for _, tt := range tests {
122 t.Run(tt.name, func(t *testing.T) {
123 got := GetSessionFromContext(tt.ctx)
124 if got != tt.want {
125 t.Errorf("GetSessionFromContext() = %v, want %v", got, tt.want)
126 }
127 })
128 }
129}
130
131func TestGetMessageFromContext(t *testing.T) {
132 tests := []struct {
133 name string
134 ctx context.Context
135 want string
136 }{
137 {
138 name: "returns message ID when present",
139 ctx: context.WithValue(context.Background(), MessageIDContextKey, "msg-456"),
140 want: "msg-456",
141 },
142 {
143 name: "returns empty string when not present",
144 ctx: context.Background(),
145 want: "",
146 },
147 {
148 name: "returns empty string when wrong type",
149 ctx: context.WithValue(context.Background(), MessageIDContextKey, 456),
150 want: "",
151 },
152 }
153
154 for _, tt := range tests {
155 t.Run(tt.name, func(t *testing.T) {
156 got := GetMessageFromContext(tt.ctx)
157 if got != tt.want {
158 t.Errorf("GetMessageFromContext() = %v, want %v", got, tt.want)
159 }
160 })
161 }
162}
163
164func TestGetSupportsImagesFromContext(t *testing.T) {
165 tests := []struct {
166 name string
167 ctx context.Context
168 want bool
169 }{
170 {
171 name: "returns true when present and true",
172 ctx: context.WithValue(context.Background(), SupportsImagesContextKey, true),
173 want: true,
174 },
175 {
176 name: "returns false when present and false",
177 ctx: context.WithValue(context.Background(), SupportsImagesContextKey, false),
178 want: false,
179 },
180 {
181 name: "returns false when not present",
182 ctx: context.Background(),
183 want: false,
184 },
185 {
186 name: "returns false when wrong type",
187 ctx: context.WithValue(context.Background(), SupportsImagesContextKey, "true"),
188 want: false,
189 },
190 }
191
192 for _, tt := range tests {
193 t.Run(tt.name, func(t *testing.T) {
194 got := GetSupportsImagesFromContext(tt.ctx)
195 if got != tt.want {
196 t.Errorf("GetSupportsImagesFromContext() = %v, want %v", got, tt.want)
197 }
198 })
199 }
200}
201
202func TestGetModelNameFromContext(t *testing.T) {
203 tests := []struct {
204 name string
205 ctx context.Context
206 want string
207 }{
208 {
209 name: "returns model name when present",
210 ctx: context.WithValue(context.Background(), ModelNameContextKey, "claude-opus-4"),
211 want: "claude-opus-4",
212 },
213 {
214 name: "returns empty string when not present",
215 ctx: context.Background(),
216 want: "",
217 },
218 {
219 name: "returns empty string when wrong type",
220 ctx: context.WithValue(context.Background(), ModelNameContextKey, 789),
221 want: "",
222 },
223 }
224
225 for _, tt := range tests {
226 t.Run(tt.name, func(t *testing.T) {
227 got := GetModelNameFromContext(tt.ctx)
228 if got != tt.want {
229 t.Errorf("GetModelNameFromContext() = %v, want %v", got, tt.want)
230 }
231 })
232 }
233}