examples_test.go

  1package hooks
  2
  3import (
  4	"context"
  5	"os"
  6	"path/filepath"
  7	"strings"
  8	"testing"
  9
 10	"github.com/stretchr/testify/assert"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14// TestReadmeExamples tests that all examples from the README work as documented.
 15func TestReadmeExamples(t *testing.T) {
 16	t.Parallel()
 17
 18	t.Run("block dangerous commands", func(t *testing.T) {
 19		t.Parallel()
 20		tempDir := t.TempDir()
 21		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
 22		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
 23
 24		hookScript := `#!/bin/bash
 25if [ "$CRUSH_TOOL_NAME" = "bash" ]; then
 26  COMMAND=$(crush_get_tool_input command)
 27  if [[ "$COMMAND" =~ "rm -rf /" ]]; then
 28    crush_deny "Blocked dangerous command"
 29  fi
 30fi
 31`
 32		hookPath := filepath.Join(hooksDir, "01-block-dangerous.sh")
 33		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
 34
 35		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
 36
 37		// Test: Should block "rm -rf /"
 38		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
 39			SessionID:  "test",
 40			WorkingDir: tempDir,
 41			ToolName:   "bash",
 42			ToolCallID: "call-1",
 43			Data: map[string]any{
 44				"tool_input": map[string]any{
 45					"command": "rm -rf /",
 46				},
 47			},
 48		})
 49
 50		require.NoError(t, err)
 51		assert.False(t, result.Continue, "Should stop execution for dangerous command")
 52		assert.Equal(t, "deny", result.Permission)
 53		assert.Contains(t, result.Message, "Blocked dangerous command")
 54
 55		// Test: Should allow safe commands
 56		result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
 57			SessionID:  "test",
 58			WorkingDir: tempDir,
 59			ToolName:   "bash",
 60			ToolCallID: "call-2",
 61			Data: map[string]any{
 62				"tool_input": map[string]any{
 63					"command": "ls -la",
 64				},
 65			},
 66		})
 67
 68		require.NoError(t, err)
 69		assert.True(t, result2.Continue, "Should allow safe commands")
 70	})
 71
 72	t.Run("auto-approve read-only tools", func(t *testing.T) {
 73		t.Parallel()
 74		tempDir := t.TempDir()
 75		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
 76		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
 77
 78		hookScript := `#!/bin/bash
 79case "$CRUSH_TOOL_NAME" in
 80  view|ls|grep|glob)
 81    crush_approve "Auto-approved read-only tool"
 82    ;;
 83  bash)
 84    COMMAND=$(crush_get_tool_input command)
 85    if [[ "$COMMAND" =~ ^(ls|cat|grep) ]]; then
 86      crush_approve "Auto-approved safe bash command"
 87    fi
 88    ;;
 89esac
 90`
 91		hookPath := filepath.Join(hooksDir, "01-auto-approve.sh")
 92		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
 93
 94		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
 95
 96		// Test: Should auto-approve view tool
 97		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
 98			SessionID:  "test",
 99			WorkingDir: tempDir,
100			ToolName:   "view",
101			ToolCallID: "call-1",
102			Data:       map[string]any{},
103		})
104
105		require.NoError(t, err)
106		assert.True(t, result.Continue)
107		assert.Equal(t, "approve", result.Permission)
108		assert.Contains(t, result.Message, "Auto-approved read-only tool")
109
110		// Test: Should auto-approve safe bash commands
111		result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
112			SessionID:  "test",
113			WorkingDir: tempDir,
114			ToolName:   "bash",
115			ToolCallID: "call-2",
116			Data: map[string]any{
117				"tool_input": map[string]any{
118					"command": "ls -la",
119				},
120			},
121		})
122
123		require.NoError(t, err)
124		assert.True(t, result2.Continue)
125		assert.Equal(t, "approve", result2.Permission)
126		assert.Contains(t, result2.Message, "Auto-approved safe bash command")
127	})
128
129	t.Run("add git context", func(t *testing.T) {
130		t.Parallel()
131		tempDir := t.TempDir()
132		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit")
133		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
134
135		// Initialize git repo with a branch
136		gitDir := filepath.Join(tempDir, ".git")
137		require.NoError(t, os.MkdirAll(gitDir, 0o755))
138		require.NoError(t, os.WriteFile(filepath.Join(gitDir, "HEAD"), []byte("ref: refs/heads/main\n"), 0o644))
139
140		hookScript := `#!/bin/bash
141BRANCH=$(git branch --show-current 2>/dev/null)
142if [ -n "$BRANCH" ]; then
143  crush_add_context "Current branch: $BRANCH"
144fi
145
146if [ -f "README.md" ]; then
147  crush_add_context_file "README.md"
148fi
149`
150		hookPath := filepath.Join(hooksDir, "01-add-context.sh")
151		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
152
153		// Create README.md
154		readmePath := filepath.Join(tempDir, "README.md")
155		require.NoError(t, os.WriteFile(readmePath, []byte("# Test Project\n"), 0o644))
156
157		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
158
159		result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
160			SessionID:  "test",
161			WorkingDir: tempDir,
162			Data: map[string]any{
163				"prompt": "help me",
164			},
165		})
166
167		require.NoError(t, err)
168		assert.True(t, result.Continue)
169		// Should add context file (using relative path)
170		require.Len(t, result.ContextFiles, 1)
171		assert.Equal(t, "README.md", result.ContextFiles[0])
172	})
173
174	t.Run("audit logging", func(t *testing.T) {
175		t.Parallel()
176		tempDir := t.TempDir()
177		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "post-tool-use")
178		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
179
180		auditFile := filepath.Join(tempDir, "audit.log")
181		hookScript := `#!/bin/bash
182AUDIT_FILE="` + auditFile + `"
183TIMESTAMP=$(date -Iseconds)
184echo "$TIMESTAMP|$CRUSH_TOOL_NAME|$CRUSH_TOOL_CALL_ID" >> "$AUDIT_FILE"
185`
186		hookPath := filepath.Join(hooksDir, "01-audit.sh")
187		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
188
189		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
190
191		result, err := manager.ExecuteHooks(context.Background(), HookPostToolUse, HookContext{
192			SessionID:  "test",
193			WorkingDir: tempDir,
194			ToolName:   "bash",
195			ToolCallID: "call-123",
196			Data:       map[string]any{},
197		})
198
199		require.NoError(t, err)
200		assert.True(t, result.Continue)
201
202		// Verify audit log was written
203		content, err := os.ReadFile(auditFile)
204		require.NoError(t, err)
205		assert.Contains(t, string(content), "bash|call-123")
206	})
207
208	t.Run("catch-all hook", func(t *testing.T) {
209		t.Parallel()
210		tempDir := t.TempDir()
211		hooksDir := filepath.Join(tempDir, ".crush", "hooks")
212		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
213
214		logFile := filepath.Join(tempDir, "global.log")
215		hookScript := `#!/bin/bash
216echo "Hook: $CRUSH_HOOK_TYPE" >> "` + logFile + `"
217`
218		hookPath := filepath.Join(hooksDir, "00-global-log.sh")
219		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
220
221		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
222
223		// Test with different hook types
224		_, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
225			SessionID:  "test",
226			WorkingDir: tempDir,
227			Data:       map[string]any{},
228		})
229		require.NoError(t, err)
230
231		_, err = manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
232			SessionID:  "test",
233			WorkingDir: tempDir,
234			Data:       map[string]any{},
235		})
236		require.NoError(t, err)
237
238		// Verify both hook types were logged
239		content, err := os.ReadFile(logFile)
240		require.NoError(t, err)
241		assert.Contains(t, string(content), "Hook: pre-tool-use")
242		assert.Contains(t, string(content), "Hook: user-prompt-submit")
243	})
244
245	t.Run("rate limiting", func(t *testing.T) {
246		t.Parallel()
247		tempDir := t.TempDir()
248		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
249		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
250
251		usageLog := filepath.Join(tempDir, "usage.log")
252		// Pre-populate with entries
253		today := "2024-01-15" // Fixed date for testing
254		for i := 0; i < 5; i++ {
255			f, err := os.OpenFile(usageLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
256			require.NoError(t, err)
257			_, err = f.WriteString(today + "\n")
258			require.NoError(t, err)
259			f.Close()
260		}
261
262		hookScript := `#!/bin/bash
263COUNT=$(grep -c "2024-01-15" "` + usageLog + `" 2>/dev/null || echo "0")
264if [ "$COUNT" -ge 3 ]; then
265  export CRUSH_CONTINUE=false
266  export CRUSH_MESSAGE="Rate limit exceeded"
267fi
268`
269		hookPath := filepath.Join(hooksDir, "01-rate-limit.sh")
270		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
271
272		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
273
274		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
275			SessionID:  "test",
276			WorkingDir: tempDir,
277			Data:       map[string]any{},
278		})
279
280		require.NoError(t, err)
281		assert.False(t, result.Continue, "Should stop execution when rate limit exceeded")
282		assert.Contains(t, result.Message, "Rate limit exceeded")
283	})
284
285	t.Run("conditional context", func(t *testing.T) {
286		t.Parallel()
287		tempDir := t.TempDir()
288		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit")
289		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
290
291		// Create package.json
292		packageJSON := filepath.Join(tempDir, "package.json")
293		require.NoError(t, os.WriteFile(packageJSON, []byte(`{"name": "test"}`), 0o644))
294
295		hookScript := `#!/bin/bash
296if [ -f "package.json" ]; then
297  crush_add_context_file "package.json"
298fi
299`
300		hookPath := filepath.Join(hooksDir, "01-conditional.sh")
301		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
302
303		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
304
305		result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
306			SessionID:  "test",
307			WorkingDir: tempDir,
308			Data:       map[string]any{},
309		})
310
311		require.NoError(t, err)
312		assert.True(t, result.Continue)
313		require.Len(t, result.ContextFiles, 1)
314		assert.Equal(t, "package.json", result.ContextFiles[0])
315	})
316
317	t.Run("JSON output example", func(t *testing.T) {
318		t.Parallel()
319		tempDir := t.TempDir()
320		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
321		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
322
323		hookScript := `#!/bin/bash
324COMMAND=$(crush_get_tool_input command)
325SAFE_CMD=$(echo "$COMMAND" | sed 's/--force//')
326echo "{\"modified_input\": {\"command\": \"$SAFE_CMD\"}}"
327`
328		hookPath := filepath.Join(hooksDir, "01-modify.sh")
329		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
330
331		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
332
333		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
334			SessionID:  "test",
335			WorkingDir: tempDir,
336			ToolName:   "bash",
337			ToolCallID: "call-1",
338			Data: map[string]any{
339				"tool_input": map[string]any{
340					"command": "rm --force file.txt",
341				},
342			},
343		})
344
345		require.NoError(t, err)
346		assert.True(t, result.Continue)
347		require.NotNil(t, result.ModifiedInput)
348		assert.Equal(t, "rm  file.txt", result.ModifiedInput["command"])
349	})
350
351	t.Run("environment variables example", func(t *testing.T) {
352		t.Parallel()
353		tempDir := t.TempDir()
354		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
355		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
356
357		hookScript := `#!/bin/bash
358export CRUSH_PERMISSION=approve
359export CRUSH_MESSAGE="Auto-approved"
360`
361		hookPath := filepath.Join(hooksDir, "01-env-vars.sh")
362		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
363
364		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
365
366		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
367			SessionID:  "test",
368			WorkingDir: tempDir,
369			Data:       map[string]any{},
370		})
371
372		require.NoError(t, err)
373		assert.True(t, result.Continue)
374		assert.Equal(t, "approve", result.Permission)
375		assert.Equal(t, "Auto-approved", result.Message)
376	})
377
378	t.Run("exit codes example", func(t *testing.T) {
379		t.Parallel()
380		tempDir := t.TempDir()
381		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
382		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
383
384		usageLog := filepath.Join(tempDir, "usage.log")
385		// Create usage log with entries
386		for i := 0; i < 150; i++ {
387			f, err := os.OpenFile(usageLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
388			require.NoError(t, err)
389			_, err = f.WriteString("2024-01-15\n")
390			require.NoError(t, err)
391			f.Close()
392		}
393
394		hookScript := `#!/bin/bash
395COUNT=$(grep -c "2024-01-15" "` + usageLog + `")
396if [ "$COUNT" -gt 100 ]; then
397  echo "Rate limit exceeded" >&2
398  exit 2  # Stops execution
399fi
400`
401		hookPath := filepath.Join(hooksDir, "01-exit-code.sh")
402		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
403
404		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
405
406		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
407			SessionID:  "test",
408			WorkingDir: tempDir,
409			Data:       map[string]any{},
410		})
411
412		require.NoError(t, err)
413		assert.False(t, result.Continue, "Exit code 2 should stop execution")
414	})
415
416	t.Run("helper functions comprehensive test", func(t *testing.T) {
417		t.Parallel()
418		tempDir := t.TempDir()
419		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit")
420		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
421
422		// Test all helper functions in one hook
423		hookScript := `#!/bin/bash
424# Read stdin once into variable
425CONTEXT=$(cat)
426
427# Test input parsing
428PROMPT=$(echo "$CONTEXT" | crush_get_prompt)
429MODEL=$(echo "$CONTEXT" | crush_get_input model)
430
431# Test context helpers
432crush_add_context "Using model: $MODEL"
433
434# Test logging
435crush_log "Processing prompt"
436
437# Test modification
438export CRUSH_MODIFIED_PROMPT="Enhanced: $PROMPT"
439`
440		hookPath := filepath.Join(hooksDir, "01-helpers.sh")
441		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
442
443		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
444
445		result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
446			SessionID:  "test",
447			WorkingDir: tempDir,
448			Data: map[string]any{
449				"prompt": "original prompt",
450				"model":  "gpt-4",
451			},
452		})
453
454		require.NoError(t, err)
455		assert.True(t, result.Continue)
456		assert.Contains(t, result.ContextContent, "Using model: gpt-4")
457		require.NotNil(t, result.ModifiedPrompt)
458		// Trim any trailing whitespace/CRLF for cross-platform compatibility
459		assert.Equal(t, "Enhanced: original prompt", strings.TrimSpace(*result.ModifiedPrompt))
460	})
461
462	t.Run("is_first_message flag", func(t *testing.T) {
463		t.Parallel()
464		tempDir := t.TempDir()
465		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit")
466		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
467
468		// Hook that adds README only on first message
469		hookScript := `#!/bin/bash
470IS_FIRST=$(crush_get_input is_first_message)
471if [ "$IS_FIRST" = "true" ]; then
472  crush_add_context "This is the first message"
473else
474  crush_add_context "This is a follow-up message"
475fi
476`
477		hookPath := filepath.Join(hooksDir, "01-first-msg.sh")
478		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
479
480		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
481
482		// Test: First message
483		result1, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
484			SessionID:  "test",
485			WorkingDir: tempDir,
486			Data: map[string]any{
487				"prompt":           "first prompt",
488				"is_first_message": true,
489			},
490		})
491		require.NoError(t, err)
492		assert.Contains(t, result1.ContextContent, "This is the first message")
493
494		// Test: Follow-up message
495		result2, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{
496			SessionID:  "test",
497			WorkingDir: tempDir,
498			Data: map[string]any{
499				"prompt":           "follow-up prompt",
500				"is_first_message": false,
501			},
502		})
503		require.NoError(t, err)
504		assert.Contains(t, result2.ContextContent, "This is a follow-up message")
505	})
506}
507
508// TestReadmeQuickExamples tests the quick examples from the quick reference.
509func TestReadmeQuickExamples(t *testing.T) {
510	t.Parallel()
511
512	t.Run("hook ordering", func(t *testing.T) {
513		t.Parallel()
514		tempDir := t.TempDir()
515		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
516		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
517
518		// Create hooks with specific order
519		hook1 := `#!/bin/bash
520export CRUSH_MESSAGE="first"
521`
522		hook2 := `#!/bin/bash
523export CRUSH_MESSAGE="${CRUSH_MESSAGE:-}; second"
524`
525		hook3 := `#!/bin/bash
526export CRUSH_MESSAGE="${CRUSH_MESSAGE:-}; third"
527`
528
529		require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "01-first.sh"), []byte(hook1), 0o755))
530		require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "02-second.sh"), []byte(hook2), 0o755))
531		require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "99-third.sh"), []byte(hook3), 0o755))
532
533		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
534
535		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
536			SessionID:  "test",
537			WorkingDir: tempDir,
538			Data:       map[string]any{},
539		})
540
541		require.NoError(t, err)
542		// Messages should be merged in order
543		assert.Contains(t, result.Message, "first")
544		assert.Contains(t, result.Message, "second")
545		assert.Contains(t, result.Message, "third")
546	})
547
548	t.Run("mixed env vars and JSON", func(t *testing.T) {
549		t.Parallel()
550		tempDir := t.TempDir()
551		hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use")
552		require.NoError(t, os.MkdirAll(hooksDir, 0o755))
553
554		hookScript := `#!/bin/bash
555# Set via environment variable
556export CRUSH_PERMISSION=approve
557
558# Output via JSON
559echo '{"message": "Combined output", "modified_input": {"key": "value"}}'
560`
561		hookPath := filepath.Join(hooksDir, "01-mixed.sh")
562		require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755))
563
564		manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil)
565
566		result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{
567			SessionID:  "test",
568			WorkingDir: tempDir,
569			Data:       map[string]any{},
570		})
571
572		require.NoError(t, err)
573		assert.True(t, result.Continue)
574		assert.Equal(t, "approve", result.Permission)
575		assert.Equal(t, "Combined output", result.Message)
576		assert.Equal(t, "value", result.ModifiedInput["key"])
577	})
578}