1package hooks
2
3import (
4 "encoding/base64"
5 "testing"
6
7 "github.com/stretchr/testify/assert"
8 "github.com/stretchr/testify/require"
9)
10
11func TestParseShellEnv(t *testing.T) {
12 t.Run("parses basic fields", func(t *testing.T) {
13 env := []string{
14 "PATH=/usr/bin",
15 "CRUSH_CONTINUE=false",
16 "CRUSH_PERMISSION=approve",
17 "CRUSH_MESSAGE=test message",
18 "HOME=/home/user",
19 }
20
21 result := parseShellEnv(env)
22
23 assert.False(t, result.Continue)
24 assert.Equal(t, "approve", result.Permission)
25 assert.Equal(t, "test message", result.Message)
26 })
27
28 t.Run("parses modified prompt", func(t *testing.T) {
29 env := []string{
30 "CRUSH_MODIFIED_PROMPT=new prompt text",
31 }
32
33 result := parseShellEnv(env)
34
35 require.NotNil(t, result.ModifiedPrompt)
36 assert.Equal(t, "new prompt text", *result.ModifiedPrompt)
37 })
38
39 t.Run("parses context content", func(t *testing.T) {
40 env := []string{
41 "CRUSH_CONTEXT_CONTENT=some context",
42 }
43
44 result := parseShellEnv(env)
45
46 assert.Equal(t, "some context", result.ContextContent)
47 })
48
49 t.Run("parses base64 context content", func(t *testing.T) {
50 text := "multiline\ncontext\nhere"
51 encoded := base64.StdEncoding.EncodeToString([]byte(text))
52
53 env := []string{
54 "CRUSH_CONTEXT_CONTENT=" + encoded,
55 }
56
57 result := parseShellEnv(env)
58
59 assert.Equal(t, text, result.ContextContent)
60 })
61
62 t.Run("parses context files", func(t *testing.T) {
63 env := []string{
64 "CRUSH_CONTEXT_FILES=file1.md:file2.txt:file3.go",
65 }
66
67 result := parseShellEnv(env)
68
69 assert.Equal(t, []string{"file1.md", "file2.txt", "file3.go"}, result.ContextFiles)
70 })
71
72 t.Run("defaults to continue=true", func(t *testing.T) {
73 env := []string{}
74
75 result := parseShellEnv(env)
76
77 assert.True(t, result.Continue)
78 })
79
80 t.Run("ignores non-CRUSH env vars", func(t *testing.T) {
81 env := []string{
82 "PATH=/usr/bin",
83 "HOME=/home/user",
84 "CRUSH_MESSAGE=test",
85 }
86
87 result := parseShellEnv(env)
88
89 assert.Equal(t, "test", result.Message)
90 })
91
92 t.Run("falls back to raw value for invalid base64", func(t *testing.T) {
93 // Invalid base64 string should be used as-is.
94 env := []string{
95 "CRUSH_CONTEXT_CONTENT=this is not base64!@#$",
96 }
97
98 result := parseShellEnv(env)
99
100 assert.Equal(t, "this is not base64!@#$", result.ContextContent)
101 })
102
103 t.Run("parses modified input", func(t *testing.T) {
104 env := []string{
105 "CRUSH_MODIFIED_INPUT=command=ls -la:working_dir=/tmp",
106 }
107
108 result := parseShellEnv(env)
109
110 require.NotNil(t, result.ModifiedInput)
111 assert.Equal(t, "ls -la", result.ModifiedInput["command"])
112 assert.Equal(t, "/tmp", result.ModifiedInput["working_dir"])
113 })
114
115 t.Run("parses modified output", func(t *testing.T) {
116 env := []string{
117 "CRUSH_MODIFIED_OUTPUT=status=redacted:data=[REDACTED]",
118 }
119
120 result := parseShellEnv(env)
121
122 require.NotNil(t, result.ModifiedOutput)
123 assert.Equal(t, "redacted", result.ModifiedOutput["status"])
124 assert.Equal(t, "[REDACTED]", result.ModifiedOutput["data"])
125 })
126
127 t.Run("parses modified input with JSON types", func(t *testing.T) {
128 env := []string{
129 `CRUSH_MODIFIED_INPUT=offset=100:limit=50:run_in_background=true:ignore=["*.log","*.tmp"]`,
130 }
131
132 result := parseShellEnv(env)
133
134 require.NotNil(t, result.ModifiedInput)
135 assert.Equal(t, float64(100), result.ModifiedInput["offset"]) // JSON numbers are float64
136 assert.Equal(t, float64(50), result.ModifiedInput["limit"])
137 assert.Equal(t, true, result.ModifiedInput["run_in_background"])
138 assert.Equal(t, []any{"*.log", "*.tmp"}, result.ModifiedInput["ignore"])
139 })
140
141 t.Run("parses modified input with strings containing colons", func(t *testing.T) {
142 // Colons in file paths should work if the value doesn't contain '='
143 env := []string{
144 `CRUSH_MODIFIED_INPUT=path=/usr/local/bin:name=test`,
145 }
146
147 result := parseShellEnv(env)
148
149 require.NotNil(t, result.ModifiedInput)
150 // First pair: path=/usr/local/bin
151 // Second pair: name=test
152 // Note: This splits on first '=' in each pair
153 assert.Equal(t, "/usr/local/bin", result.ModifiedInput["path"])
154 assert.Equal(t, "test", result.ModifiedInput["name"])
155 })
156}
157
158func TestParseJSONResult(t *testing.T) {
159 t.Run("parses basic fields", func(t *testing.T) {
160 json := []byte(`{
161 "continue": false,
162 "permission": "deny",
163 "message": "blocked"
164 }`)
165
166 result, err := parseJSONResult(json)
167
168 require.NoError(t, err)
169 assert.False(t, result.Continue)
170 assert.Equal(t, "deny", result.Permission)
171 assert.Equal(t, "blocked", result.Message)
172 })
173
174 t.Run("parses modified_input", func(t *testing.T) {
175 json := []byte(`{
176 "modified_input": {
177 "command": "ls -la",
178 "working_dir": "/tmp"
179 }
180 }`)
181
182 result, err := parseJSONResult(json)
183
184 require.NoError(t, err)
185 assert.Equal(t, map[string]any{
186 "command": "ls -la",
187 "working_dir": "/tmp",
188 }, result.ModifiedInput)
189 })
190
191 t.Run("parses modified_output", func(t *testing.T) {
192 json := []byte(`{
193 "modified_output": {
194 "content": "filtered output"
195 }
196 }`)
197
198 result, err := parseJSONResult(json)
199
200 require.NoError(t, err)
201 assert.Equal(t, map[string]any{
202 "content": "filtered output",
203 }, result.ModifiedOutput)
204 })
205
206 t.Run("parses context_files array", func(t *testing.T) {
207 json := []byte(`{
208 "context_files": ["file1.md", "file2.txt"]
209 }`)
210
211 result, err := parseJSONResult(json)
212
213 require.NoError(t, err)
214 assert.Equal(t, []string{"file1.md", "file2.txt"}, result.ContextFiles)
215 })
216
217 t.Run("returns error on invalid JSON", func(t *testing.T) {
218 json := []byte(`{invalid}`)
219
220 _, err := parseJSONResult(json)
221
222 assert.Error(t, err)
223 })
224
225 t.Run("defaults to continue=true", func(t *testing.T) {
226 json := []byte(`{"message": "test"}`)
227
228 result, err := parseJSONResult(json)
229
230 require.NoError(t, err)
231 assert.True(t, result.Continue)
232 })
233
234 t.Run("handles wrong type for modified_input", func(t *testing.T) {
235 // modified_input should be a map, but here it's a string.
236 json := []byte(`{
237 "modified_input": "not a map"
238 }`)
239
240 result, err := parseJSONResult(json)
241
242 require.NoError(t, err)
243 // Should be nil/empty since type assertion failed.
244 assert.Nil(t, result.ModifiedInput)
245 })
246
247 t.Run("handles wrong type for modified_output", func(t *testing.T) {
248 // modified_output should be a map, but here it's an array.
249 json := []byte(`{
250 "modified_output": ["not", "a", "map"]
251 }`)
252
253 result, err := parseJSONResult(json)
254
255 require.NoError(t, err)
256 assert.Nil(t, result.ModifiedOutput)
257 })
258
259 t.Run("handles non-string elements in context_files", func(t *testing.T) {
260 // context_files should be array of strings, but has numbers.
261 json := []byte(`{
262 "context_files": ["file1.md", 123, "file2.md", null]
263 }`)
264
265 result, err := parseJSONResult(json)
266
267 require.NoError(t, err)
268 // Should only include valid strings.
269 assert.Equal(t, []string{"file1.md", "file2.md"}, result.ContextFiles)
270 })
271
272 t.Run("handles wrong type for context_files", func(t *testing.T) {
273 // context_files should be an array, but here it's a string.
274 json := []byte(`{
275 "context_files": "not an array"
276 }`)
277
278 result, err := parseJSONResult(json)
279
280 require.NoError(t, err)
281 // Should be empty since type assertion failed.
282 assert.Empty(t, result.ContextFiles)
283 })
284}
285
286func TestMergeJSONResult(t *testing.T) {
287 t.Run("merges continue flag", func(t *testing.T) {
288 base := &HookResult{Continue: true}
289 json := &HookResult{Continue: false}
290
291 mergeJSONResult(base, json)
292
293 assert.False(t, base.Continue)
294 })
295
296 t.Run("merges permission", func(t *testing.T) {
297 base := &HookResult{}
298 json := &HookResult{Permission: "approve"}
299
300 mergeJSONResult(base, json)
301
302 assert.Equal(t, "approve", base.Permission)
303 })
304
305 t.Run("appends messages", func(t *testing.T) {
306 base := &HookResult{Message: "first"}
307 json := &HookResult{Message: "second"}
308
309 mergeJSONResult(base, json)
310
311 assert.Equal(t, "first; second", base.Message)
312 })
313
314 t.Run("merges modified_input maps", func(t *testing.T) {
315 base := &HookResult{
316 ModifiedInput: map[string]any{
317 "field1": "value1",
318 },
319 }
320 json := &HookResult{
321 ModifiedInput: map[string]any{
322 "field2": "value2",
323 },
324 }
325
326 mergeJSONResult(base, json)
327
328 assert.Equal(t, map[string]any{
329 "field1": "value1",
330 "field2": "value2",
331 }, base.ModifiedInput)
332 })
333
334 t.Run("overwrites conflicting modified_input fields", func(t *testing.T) {
335 base := &HookResult{
336 ModifiedInput: map[string]any{
337 "field": "old",
338 },
339 }
340 json := &HookResult{
341 ModifiedInput: map[string]any{
342 "field": "new",
343 },
344 }
345
346 mergeJSONResult(base, json)
347
348 assert.Equal(t, "new", base.ModifiedInput["field"])
349 })
350
351 t.Run("appends context content", func(t *testing.T) {
352 base := &HookResult{ContextContent: "first"}
353 json := &HookResult{ContextContent: "second"}
354
355 mergeJSONResult(base, json)
356
357 assert.Equal(t, "first\n\nsecond", base.ContextContent)
358 })
359
360 t.Run("appends context files", func(t *testing.T) {
361 base := &HookResult{ContextFiles: []string{"file1.md"}}
362 json := &HookResult{ContextFiles: []string{"file2.md", "file3.md"}}
363
364 mergeJSONResult(base, json)
365
366 assert.Equal(t, []string{"file1.md", "file2.md", "file3.md"}, base.ContextFiles)
367 })
368
369 t.Run("initializes ModifiedInput when nil", func(t *testing.T) {
370 // Base has nil ModifiedInput.
371 base := &HookResult{}
372 json := &HookResult{
373 ModifiedInput: map[string]any{
374 "field": "value",
375 },
376 }
377
378 mergeJSONResult(base, json)
379
380 require.NotNil(t, base.ModifiedInput)
381 assert.Equal(t, "value", base.ModifiedInput["field"])
382 })
383
384 t.Run("initializes ModifiedOutput when nil", func(t *testing.T) {
385 // Base has nil ModifiedOutput.
386 base := &HookResult{}
387 json := &HookResult{
388 ModifiedOutput: map[string]any{
389 "filtered": true,
390 },
391 }
392
393 mergeJSONResult(base, json)
394
395 require.NotNil(t, base.ModifiedOutput)
396 assert.Equal(t, true, base.ModifiedOutput["filtered"])
397 })
398
399 t.Run("sets context content when base is empty", func(t *testing.T) {
400 base := &HookResult{}
401 json := &HookResult{ContextContent: "new content"}
402
403 mergeJSONResult(base, json)
404
405 assert.Equal(t, "new content", base.ContextContent)
406 })
407
408 t.Run("sets message when base is empty", func(t *testing.T) {
409 base := &HookResult{}
410 json := &HookResult{Message: "new message"}
411
412 mergeJSONResult(base, json)
413
414 assert.Equal(t, "new message", base.Message)
415 })
416}