From 927115cf068d22cd874a0c6a668ded76d856421e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 20 Nov 2025 13:36:54 +0100 Subject: [PATCH] wip: initial hook needs review and testing, a lot of AI. --- internal/hooks/README.md | 567 +++++++++++++++++++++++++++++++ internal/hooks/config.go | 35 ++ internal/hooks/examples_test.go | 578 ++++++++++++++++++++++++++++++++ internal/hooks/executor.go | 101 ++++++ internal/hooks/executor_test.go | 395 ++++++++++++++++++++++ internal/hooks/helpers.sh | 121 +++++++ internal/hooks/manager.go | 285 ++++++++++++++++ internal/hooks/manager_test.go | 524 +++++++++++++++++++++++++++++ internal/hooks/parser.go | 183 ++++++++++ internal/hooks/parser_test.go | 416 +++++++++++++++++++++++ internal/hooks/types.go | 93 +++++ internal/shell/shell.go | 24 +- 12 files changed, 3314 insertions(+), 8 deletions(-) create mode 100644 internal/hooks/README.md create mode 100644 internal/hooks/config.go create mode 100644 internal/hooks/examples_test.go create mode 100644 internal/hooks/executor.go create mode 100644 internal/hooks/executor_test.go create mode 100644 internal/hooks/helpers.sh create mode 100644 internal/hooks/manager.go create mode 100644 internal/hooks/manager_test.go create mode 100644 internal/hooks/parser.go create mode 100644 internal/hooks/parser_test.go create mode 100644 internal/hooks/types.go diff --git a/internal/hooks/README.md b/internal/hooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a7af873e96c12756d51c1874dc01c3dc6d28791b --- /dev/null +++ b/internal/hooks/README.md @@ -0,0 +1,567 @@ +# Hooks Package + +A Git-like hooks system for Crush that allows users to intercept and modify behavior at key points in the application lifecycle. + +## Overview + +The hooks package provides a flexible, shell-based system for customizing Crush behavior through executable scripts. Hooks can: + +- Add context to LLM requests +- Control tool execution permissions +- Modify prompts and tool parameters +- Audit and log activity +- Execute cleanup on shutdown + +### Cross-Platform Support + +The hooks system works on **Windows, macOS, and Linux**: + +- **Hook Files**: All hooks must be `.sh` files (shell scripts) +- **Shell Execution**: Uses Crush's internal POSIX shell emulator (`mvdan.cc/sh`) on all platforms +- **Hook Discovery**: + - **Unix/macOS**: `.sh` files must have execute permission (`chmod +x hook.sh`) + - **Windows**: `.sh` files are automatically recognized (no permission needed) +- **Path Separators**: Use forward slashes (`/`) in hook scripts for cross-platform compatibility + +**Example**: +```bash +# Works on Windows, macOS, and Linux +.crush/hooks/pre-tool-use/01-check.sh +``` + +## Quick Start + +### Creating a Hook + +1. Create an executable script in `.crush/hooks/{hook-type}/`: + +```bash +#!/bin/bash +# .crush/hooks/pre-tool-use/01-block-dangerous.sh + +if [ "$CRUSH_TOOL_NAME" = "bash" ]; then + COMMAND=$(crush_get_tool_input command) + if [[ "$COMMAND" =~ "rm -rf /" ]]; then + crush_deny "Blocked dangerous command" + fi +fi +``` + +2. Make it executable: + +```bash +chmod +x .crush/hooks/pre-tool-use/01-block-dangerous.sh +``` + +3. The hook will automatically execute when the event occurs. + +## Hook Types + +### 1. UserPromptSubmit + +**When**: After user submits prompt, before sending to LLM +**Use cases**: Add context, modify prompts, validate input +**Location**: `.crush/hooks/user-prompt-submit/` + +**Available data** (via stdin JSON): +- `prompt` - User's prompt text +- `attachments` - List of attached files +- `model` - Model name +- `is_first_message` - Boolean indicating if this is the first message in the conversation + +**Example**: +```bash +#!/bin/bash +# Add git context to every prompt, and README only for first message + +BRANCH=$(git branch --show-current 2>/dev/null) +if [ -n "$BRANCH" ]; then + crush_add_context "Current branch: $BRANCH" +fi + +# Only add README context for the first message to avoid repetition +IS_FIRST=$(crush_get_input is_first_message) +if [ "$IS_FIRST" = "true" ] && [ -f "README.md" ]; then + crush_add_context_file "README.md" +fi +``` + +### 2. PreToolUse + +**When**: After LLM requests tool use, before permission check & execution +**Use cases**: Auto-approve, deny dangerous commands, audit requests +**Location**: `.crush/hooks/pre-tool-use/` + +**Available data** (via stdin JSON): +- `tool_input` - Tool parameters (object) + +**Environment variables**: +- `$CRUSH_TOOL_NAME` - Name of the tool being called +- `$CRUSH_TOOL_CALL_ID` - Unique ID for this tool call + +**Example**: +```bash +#!/bin/bash +# Auto-approve read-only tools and modify parameters + +case "$CRUSH_TOOL_NAME" in + view|ls|grep|glob) + crush_approve "Auto-approved read-only tool" + ;; + bash) + COMMAND=$(crush_get_tool_input command) + if [[ "$COMMAND" =~ ^(ls|cat|grep) ]]; then + crush_approve "Auto-approved safe bash command" + fi + ;; + view) + # Limit file reads to 1000 lines max for performance + crush_modify_input "limit" "1000" + ;; +esac +``` + +### 3. PostToolUse + +**When**: After tool executes, before result sent to LLM +**Use cases**: Filter output, redact secrets, log results +**Location**: `.crush/hooks/post-tool-use/` + +**Available data** (via stdin JSON): +- `tool_input` - Tool parameters (object) +- `tool_output` - Tool result (object with `success`, `content`) +- `execution_time_ms` - How long the tool took + +**Environment variables**: +- `$CRUSH_TOOL_NAME` - Name of the tool +- `$CRUSH_TOOL_CALL_ID` - Unique ID for this tool call + +**Example**: +```bash +#!/bin/bash +# Redact sensitive information from tool output + +# Get tool output using helper (stdin is automatically available) +OUTPUT_CONTENT=$(crush_get_input tool_output | jq -r '.content // empty') + +# Check if output contains sensitive patterns +if echo "$OUTPUT_CONTENT" | grep -qE '(password|api[_-]?key|secret|token)'; then + # Redact sensitive data + REDACTED=$(echo "$OUTPUT_CONTENT" | sed -E 's/(password|api[_-]?key|secret|token)[[:space:]]*[:=][[:space:]]*[^[:space:]]+/\1=\[REDACTED\]/gi') + crush_modify_output "content" "$REDACTED" + crush_log "Redacted sensitive information from $CRUSH_TOOL_NAME output" +fi +``` + +### 4. Stop + +**When**: When agent conversation loop stops or is cancelled +**Use cases**: Save conversation state, cleanup session resources, archive logs +**Location**: `.crush/hooks/stop/` + +**Available data** (via stdin JSON): +- `reason` - Why the loop stopped (e.g., "completed", "cancelled", "error") +- `session_id` - The session ID that stopped + +**Example**: +```bash +#!/bin/bash +# Save conversation summary when agent loop stops + +REASON=$(crush_get_input reason) +SESSION_ID=$(crush_get_input session_id) + +# Archive session logs +if [ -f ".crush/session-$SESSION_ID.log" ]; then + ARCHIVE="logs/session-$SESSION_ID-$(date +%Y%m%d-%H%M%S).log" + mkdir -p logs + mv ".crush/session-$SESSION_ID.log" "$ARCHIVE" + gzip "$ARCHIVE" + crush_log "Archived session logs: $ARCHIVE.gz (reason: $REASON)" +fi +``` + +## Catch-All Hooks + +Place hooks at the **root level** (`.crush/hooks/*.sh`) to run for **ALL hook types**: + +```bash +#!/bin/bash +# .crush/hooks/00-global-log.sh +# This runs for every hook type + +echo "[$CRUSH_HOOK_TYPE] Session: $CRUSH_SESSION_ID" >> global.log +``` + +**Execution order**: +1. Catch-all hooks (alphabetically sorted) +2. Type-specific hooks (alphabetically sorted) + +Use `$CRUSH_HOOK_TYPE` to determine which event triggered the hook. + +## Helper Functions + +All hooks have access to these built-in functions (no sourcing required): + +### Permission Helpers + +#### `crush_approve [message]` +Approve the current tool call (PreToolUse only). + +```bash +crush_approve "Auto-approved read-only command" +``` + +#### `crush_deny [message]` +Deny the current tool call and stop execution (PreToolUse only). + +```bash +crush_deny "Blocked dangerous operation" +# Script exits immediately with code 2 +``` + +#### `crush_ask [message]` +Ask user for permission (default behavior). + +```bash +crush_ask "This command modifies files, please review" +``` + +### Context Helpers + +#### `crush_add_context "content"` +Add raw text content to LLM context. + +```bash +crush_add_context "Project uses React 18 with TypeScript" +``` + +#### `crush_add_context_file "path"` +Load a file and add its content to LLM context. + +```bash +crush_add_context_file "docs/ARCHITECTURE.md" +crush_add_context_file "package.json" +``` + +### Modification Helpers + +#### `crush_modify_prompt "new_prompt"` +Replace the user's prompt (UserPromptSubmit only). + +```bash +PROMPT=$(crush_get_prompt) +MODIFIED="$PROMPT\n\nNote: Always use TypeScript." +crush_modify_prompt "$MODIFIED" +``` + +#### `crush_modify_input "param_name" "value"` +Modify tool input parameters (PreToolUse only). + +Values are parsed as JSON when valid, supporting all JSON types (strings, numbers, booleans, arrays, objects). + +```bash +# Strings (no quotes needed for simple strings) +crush_modify_input "command" "ls -la" +crush_modify_input "working_dir" "/tmp" + +# Numbers (parsed as JSON) +crush_modify_input "offset" "100" +crush_modify_input "limit" "50" + +# Booleans (parsed as JSON) +crush_modify_input "run_in_background" "true" +crush_modify_input "replace_all" "false" + +# Arrays (JSON format) +crush_modify_input "ignore" '["*.log","*.tmp"]' + +# Quoted strings (for strings with spaces or special chars) +crush_modify_input "message" '"hello world"' +``` + +#### `crush_modify_output "field_name" "value"` +Modify tool output before sending to LLM (PostToolUse only). + +```bash +# Redact sensitive information from tool output content +crush_modify_output "content" "[REDACTED - sensitive data removed]" + +# Can also modify other fields in the tool_output object +crush_modify_output "success" "false" +``` + +#### `crush_stop [message]` +Stop execution immediately. + +```bash +if [ "$(date +%H)" -lt 9 ]; then + crush_stop "Crush is only available during business hours" +fi +``` + +### Input Parsing Helpers + +Hooks receive JSON context via stdin, which is automatically saved and available to all helper functions. You can call multiple helpers without manually reading stdin first. + +#### `crush_get_input "field_name"` +Get a top-level field from the hook context. + +```bash +# Can call multiple times without saving stdin +PROMPT=$(crush_get_input prompt) +MODEL=$(crush_get_input model) +``` + +#### `crush_get_tool_input "parameter"` +Get a tool parameter (PreToolUse/PostToolUse only). + +```bash +# Can call multiple times without saving stdin +COMMAND=$(crush_get_tool_input command) +FILE_PATH=$(crush_get_tool_input file_path) +``` + +#### `crush_get_prompt` +Get the user's prompt (UserPromptSubmit only). + +```bash +PROMPT=$(crush_get_prompt) +if [[ "$PROMPT" =~ "password" ]]; then + crush_stop "Never include passwords in prompts" +fi +``` + +### Logging Helper + +#### `crush_log "message"` +Write to Crush's log (stderr). + +```bash +crush_log "Processing hook for tool: $CRUSH_TOOL_NAME" +``` + +## Environment Variables + +All hooks have access to these environment variables: + +### Always Available +- `$CRUSH_HOOK_TYPE` - Type of hook: `user-prompt-submit`, `pre-tool-use`, `post-tool-use`, `stop` +- `$CRUSH_SESSION_ID` - Current session ID +- `$CRUSH_WORKING_DIR` - Working directory + +### Tool Hooks (PreToolUse, PostToolUse) +- `$CRUSH_TOOL_NAME` - Name of the tool being called +- `$CRUSH_TOOL_CALL_ID` - Unique ID for this tool call + +## Result Communication + +Hooks communicate results back to Crush in two ways: + +### 1. Environment Variables (Simple) + +Export variables to set hook results: + +```bash +export CRUSH_PERMISSION=approve +export CRUSH_MESSAGE="Auto-approved" +export CRUSH_CONTINUE=false +export CRUSH_CONTEXT_CONTENT="Additional context" +export CRUSH_CONTEXT_FILES="/path/to/file1.md:/path/to/file2.md" +``` + +**Available variables**: +- `CRUSH_PERMISSION` - `approve`, `ask`, or `deny` +- `CRUSH_MESSAGE` - User-facing message +- `CRUSH_CONTINUE` - `true` or `false` (stop execution) +- `CRUSH_MODIFIED_PROMPT` - New prompt text +- `CRUSH_MODIFIED_INPUT` - Modified tool input (format: `key=value:key2=value2`, values parsed as JSON) +- `CRUSH_MODIFIED_OUTPUT` - Modified tool output (format: `key=value:key2=value2`, values parsed as JSON) +- `CRUSH_CONTEXT_CONTENT` - Text to add to LLM context +- `CRUSH_CONTEXT_FILES` - Colon-separated file paths + +**Note**: `CRUSH_MODIFIED_INPUT` and `CRUSH_MODIFIED_OUTPUT` use `:` as delimiter between pairs. For complex values with multiple fields or nested structures, use JSON output instead (see below). + +### 2. JSON Output (Complex) + +Echo JSON to stdout for complex modifications: + +```bash +echo '{ + "permission": "approve", + "message": "Modified command", + "modified_input": { + "command": "ls -la --color=auto" + }, + "context_content": "Added context" +}' +``` + +**JSON fields**: +- `continue` (bool) - Continue execution +- `permission` (string) - `approve`, `ask`, `deny` +- `message` (string) - User-facing message +- `modified_prompt` (string) - New prompt +- `modified_input` (object) - Modified tool parameters +- `modified_output` (object) - Modified tool results +- `context_content` (string) - Context to add +- `context_files` (array) - File paths to load + +**Note**: Environment variables and JSON output are merged automatically. + +## Exit Codes + +- **0** - Success, continue execution +- **1** - Error (PreToolUse: denies permission, others: logs and continues) +- **2** - Deny/stop execution (sets `Continue=false`) + +```bash +# Example: Check rate limit +COUNT=$(grep -c "$(date +%Y-%m-%d)" usage.log) +if [ "$COUNT" -gt 100 ]; then + echo "Rate limit exceeded" >&2 + exit 2 # Stops execution +fi +``` + +## Hook Ordering + +Hooks execute **sequentially** in alphabetical order. Use numeric prefixes to control order: + +``` +.crush/hooks/ + 00-global-log.sh # Catch-all: runs first for all types + pre-tool-use/ + 01-rate-limit.sh # Runs first + 02-auto-approve.sh # Runs second + 99-audit.sh # Runs last +``` + +## Result Merging + +When multiple hooks execute, their results are merged: + +### Permission (Most Restrictive Wins) +- `deny` > `ask` > `approve` +- If any hook denies, the final result is deny + +### Continue (AND Logic) +- All hooks must set `Continue=true` (or not set it) +- If any hook sets `Continue=false`, execution stops + +### Context (Append) +- Context content from all hooks is concatenated +- Context files from all hooks are combined + +### Messages (Append) +- Messages are joined with `; ` separator + +### Modified Fields (Last Wins) +- Modified prompt: last hook's value wins +- Modified input/output: maps are merged, last value wins for conflicts + +## Configuration + +Configure hooks in `crush.json`: + +```json +{ + "hooks": { + "enabled": true, + "timeout_seconds": 30, + "directories": [ + "/path/to/custom/hooks", + ".crush/hooks" + ], + "disabled": [ + "pre-tool-use/slow-check.sh", + "user-prompt-submit/verbose.sh" + ], + "environment": { + "CUSTOM_VAR": "value" + }, + "inline": { + "pre-tool-use": [{ + "name": "rate-limit", + "script": "#!/bin/bash\n# Inline hook script here..." + }] + } + } +} +``` + +### Configuration Options + +- **enabled** (bool) - Enable/disable the entire hooks system (default: `true`) +- **timeout_seconds** (int) - Maximum execution time per hook (default: `30`) +- **directories** ([]string) - Additional directories to search for hooks +- **disabled** ([]string) - List of hook paths to skip (relative to hooks directory) +- **environment** (map) - Environment variables to pass to all hooks +- **inline** (map) - Hooks defined directly in config (by hook type) + +## Best Practices + +### 1. Keep Hooks Fast +Hooks run synchronously. Keep them under 1 second to avoid slowing down the UI. + +```bash +# Bad: Slow network call +curl -X POST https://api.example.com/log + +# Good: Log locally, sync in background +echo "$LOG_ENTRY" >> audit.log +``` + +### 2. Handle Errors Gracefully +Don't let hooks crash. Use error handling: + +```bash +BRANCH=$(git branch --show-current 2>/dev/null) +if [ -n "$BRANCH" ]; then + crush_add_context "Branch: $BRANCH" +fi +``` + +### 3. Use Descriptive Names +Use numeric prefixes and descriptive names: + +```bash +01-security-check.sh # Good +99-audit-log.sh # Good +hook.sh # Bad +``` + +### 4. Test Hooks Independently +Run hooks manually to test: + +```bash +export CRUSH_HOOK_TYPE=pre-tool-use +export CRUSH_TOOL_NAME=bash +echo '{"tool_input":{"command":"rm -rf /"}}' | .crush/hooks/pre-tool-use/01-block-dangerous.sh +echo "Exit code: $?" +``` + +### 5. Log for Debugging +Use `crush_log` to debug hook execution: + +```bash +crush_log "Checking command: $COMMAND" +if [[ "$COMMAND" =~ "dangerous" ]]; then + crush_log "Blocking dangerous command" + crush_deny "Command blocked" +fi +``` + +### 6. Don't Block on I/O +Avoid blocking operations: + +```bash +# Bad: Waits for user input +read -p "Continue? " answer + +# Bad: Long-running process +./expensive-analysis.sh + +# Good: Quick checks +[ -f ".allowed" ] && crush_approve +``` diff --git a/internal/hooks/config.go b/internal/hooks/config.go new file mode 100644 index 0000000000000000000000000000000000000000..bc04858e3634660db61f8ada9d3e17166d9e9149 --- /dev/null +++ b/internal/hooks/config.go @@ -0,0 +1,35 @@ +package hooks + +// Config defines hook system configuration. +type Config struct { + // Enabled controls whether hooks are executed. + Enabled bool + + // TimeoutSeconds is the maximum time a hook can run. + TimeoutSeconds int + + // Directories are additional directories to search for hooks. + // Defaults to [".crush/hooks"] if empty. + Directories []string + + // Inline hooks defined directly in configuration. + // Map key is the hook type (e.g., "pre-tool-use"). + Inline map[string][]InlineHook + + // Disabled is a list of hook paths to skip. + // Paths are relative to the hooks directory. + // Example: ["pre-tool-use/02-slow-check.sh"] + Disabled []string + + // Environment variables to pass to hooks. + Environment map[string]string +} + +// InlineHook is a hook defined inline in the config. +type InlineHook struct { + // Name is the name of the hook (used as filename). + Name string + + // Script is the bash script content. + Script string +} diff --git a/internal/hooks/examples_test.go b/internal/hooks/examples_test.go new file mode 100644 index 0000000000000000000000000000000000000000..77651aee1ac5f14516c42c590a1b474f028294dc --- /dev/null +++ b/internal/hooks/examples_test.go @@ -0,0 +1,578 @@ +package hooks + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestReadmeExamples tests that all examples from the README work as documented. +func TestReadmeExamples(t *testing.T) { + t.Parallel() + + t.Run("block dangerous commands", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hookScript := `#!/bin/bash +if [ "$CRUSH_TOOL_NAME" = "bash" ]; then + COMMAND=$(crush_get_tool_input command) + if [[ "$COMMAND" =~ "rm -rf /" ]]; then + crush_deny "Blocked dangerous command" + fi +fi +` + hookPath := filepath.Join(hooksDir, "01-block-dangerous.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + // Test: Should block "rm -rf /" + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-1", + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "rm -rf /", + }, + }, + }) + + require.NoError(t, err) + assert.False(t, result.Continue, "Should stop execution for dangerous command") + assert.Equal(t, "deny", result.Permission) + assert.Contains(t, result.Message, "Blocked dangerous command") + + // Test: Should allow safe commands + result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-2", + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "ls -la", + }, + }, + }) + + require.NoError(t, err) + assert.True(t, result2.Continue, "Should allow safe commands") + }) + + t.Run("auto-approve read-only tools", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hookScript := `#!/bin/bash +case "$CRUSH_TOOL_NAME" in + view|ls|grep|glob) + crush_approve "Auto-approved read-only tool" + ;; + bash) + COMMAND=$(crush_get_tool_input command) + if [[ "$COMMAND" =~ ^(ls|cat|grep) ]]; then + crush_approve "Auto-approved safe bash command" + fi + ;; +esac +` + hookPath := filepath.Join(hooksDir, "01-auto-approve.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + // Test: Should auto-approve view tool + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "view", + ToolCallID: "call-1", + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Equal(t, "approve", result.Permission) + assert.Contains(t, result.Message, "Auto-approved read-only tool") + + // Test: Should auto-approve safe bash commands + result2, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-2", + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "ls -la", + }, + }, + }) + + require.NoError(t, err) + assert.True(t, result2.Continue) + assert.Equal(t, "approve", result2.Permission) + assert.Contains(t, result2.Message, "Auto-approved safe bash command") + }) + + t.Run("add git context", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Initialize git repo with a branch + gitDir := filepath.Join(tempDir, ".git") + require.NoError(t, os.MkdirAll(gitDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(gitDir, "HEAD"), []byte("ref: refs/heads/main\n"), 0o644)) + + hookScript := `#!/bin/bash +BRANCH=$(git branch --show-current 2>/dev/null) +if [ -n "$BRANCH" ]; then + crush_add_context "Current branch: $BRANCH" +fi + +if [ -f "README.md" ]; then + crush_add_context_file "README.md" +fi +` + hookPath := filepath.Join(hooksDir, "01-add-context.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + // Create README.md + readmePath := filepath.Join(tempDir, "README.md") + require.NoError(t, os.WriteFile(readmePath, []byte("# Test Project\n"), 0o644)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{ + "prompt": "help me", + }, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + // Should add context file (using relative path) + require.Len(t, result.ContextFiles, 1) + assert.Equal(t, "README.md", result.ContextFiles[0]) + }) + + t.Run("audit logging", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "post-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + auditFile := filepath.Join(tempDir, "audit.log") + hookScript := `#!/bin/bash +AUDIT_FILE="` + auditFile + `" +TIMESTAMP=$(date -Iseconds) +echo "$TIMESTAMP|$CRUSH_TOOL_NAME|$CRUSH_TOOL_CALL_ID" >> "$AUDIT_FILE" +` + hookPath := filepath.Join(hooksDir, "01-audit.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPostToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-123", + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + + // Verify audit log was written + content, err := os.ReadFile(auditFile) + require.NoError(t, err) + assert.Contains(t, string(content), "bash|call-123") + }) + + t.Run("catch-all hook", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + logFile := filepath.Join(tempDir, "global.log") + hookScript := `#!/bin/bash +echo "Hook: $CRUSH_HOOK_TYPE" >> "` + logFile + `" +` + hookPath := filepath.Join(hooksDir, "00-global-log.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + // Test with different hook types + _, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + require.NoError(t, err) + + _, err = manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + require.NoError(t, err) + + // Verify both hook types were logged + content, err := os.ReadFile(logFile) + require.NoError(t, err) + assert.Contains(t, string(content), "Hook: pre-tool-use") + assert.Contains(t, string(content), "Hook: user-prompt-submit") + }) + + t.Run("rate limiting", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + usageLog := filepath.Join(tempDir, "usage.log") + // Pre-populate with entries + today := "2024-01-15" // Fixed date for testing + for i := 0; i < 5; i++ { + f, err := os.OpenFile(usageLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + require.NoError(t, err) + _, err = f.WriteString(today + "\n") + require.NoError(t, err) + f.Close() + } + + hookScript := `#!/bin/bash +COUNT=$(grep -c "2024-01-15" "` + usageLog + `" 2>/dev/null || echo "0") +if [ "$COUNT" -ge 3 ]; then + export CRUSH_CONTINUE=false + export CRUSH_MESSAGE="Rate limit exceeded" +fi +` + hookPath := filepath.Join(hooksDir, "01-rate-limit.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.False(t, result.Continue, "Should stop execution when rate limit exceeded") + assert.Contains(t, result.Message, "Rate limit exceeded") + }) + + t.Run("conditional context", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create package.json + packageJSON := filepath.Join(tempDir, "package.json") + require.NoError(t, os.WriteFile(packageJSON, []byte(`{"name": "test"}`), 0o644)) + + hookScript := `#!/bin/bash +if [ -f "package.json" ]; then + crush_add_context_file "package.json" +fi +` + hookPath := filepath.Join(hooksDir, "01-conditional.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + require.Len(t, result.ContextFiles, 1) + assert.Equal(t, "package.json", result.ContextFiles[0]) + }) + + t.Run("JSON output example", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hookScript := `#!/bin/bash +COMMAND=$(crush_get_tool_input command) +SAFE_CMD=$(echo "$COMMAND" | sed 's/--force//') +echo "{\"modified_input\": {\"command\": \"$SAFE_CMD\"}}" +` + hookPath := filepath.Join(hooksDir, "01-modify.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-1", + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "rm --force file.txt", + }, + }, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + require.NotNil(t, result.ModifiedInput) + assert.Equal(t, "rm file.txt", result.ModifiedInput["command"]) + }) + + t.Run("environment variables example", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hookScript := `#!/bin/bash +export CRUSH_PERMISSION=approve +export CRUSH_MESSAGE="Auto-approved" +` + hookPath := filepath.Join(hooksDir, "01-env-vars.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "Auto-approved", result.Message) + }) + + t.Run("exit codes example", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + usageLog := filepath.Join(tempDir, "usage.log") + // Create usage log with entries + for i := 0; i < 150; i++ { + f, err := os.OpenFile(usageLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + require.NoError(t, err) + _, err = f.WriteString("2024-01-15\n") + require.NoError(t, err) + f.Close() + } + + hookScript := `#!/bin/bash +COUNT=$(grep -c "2024-01-15" "` + usageLog + `") +if [ "$COUNT" -gt 100 ]; then + echo "Rate limit exceeded" >&2 + exit 2 # Stops execution +fi +` + hookPath := filepath.Join(hooksDir, "01-exit-code.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.False(t, result.Continue, "Exit code 2 should stop execution") + }) + + t.Run("helper functions comprehensive test", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Test all helper functions in one hook + hookScript := `#!/bin/bash +# Read stdin once into variable +CONTEXT=$(cat) + +# Test input parsing +PROMPT=$(echo "$CONTEXT" | crush_get_prompt) +MODEL=$(echo "$CONTEXT" | crush_get_input model) + +# Test context helpers +crush_add_context "Using model: $MODEL" + +# Test logging +crush_log "Processing prompt" + +# Test modification +export CRUSH_MODIFIED_PROMPT="Enhanced: $PROMPT" +` + hookPath := filepath.Join(hooksDir, "01-helpers.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{ + "prompt": "original prompt", + "model": "gpt-4", + }, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Contains(t, result.ContextContent, "Using model: gpt-4") + require.NotNil(t, result.ModifiedPrompt) + // Trim any trailing whitespace/CRLF for cross-platform compatibility + assert.Equal(t, "Enhanced: original prompt", strings.TrimSpace(*result.ModifiedPrompt)) + }) + + t.Run("is_first_message flag", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "user-prompt-submit") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook that adds README only on first message + hookScript := `#!/bin/bash +IS_FIRST=$(crush_get_input is_first_message) +if [ "$IS_FIRST" = "true" ]; then + crush_add_context "This is the first message" +else + crush_add_context "This is a follow-up message" +fi +` + hookPath := filepath.Join(hooksDir, "01-first-msg.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + // Test: First message + result1, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{ + "prompt": "first prompt", + "is_first_message": true, + }, + }) + require.NoError(t, err) + assert.Contains(t, result1.ContextContent, "This is the first message") + + // Test: Follow-up message + result2, err := manager.ExecuteHooks(context.Background(), HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{ + "prompt": "follow-up prompt", + "is_first_message": false, + }, + }) + require.NoError(t, err) + assert.Contains(t, result2.ContextContent, "This is a follow-up message") + }) +} + +// TestReadmeQuickExamples tests the quick examples from the quick reference. +func TestReadmeQuickExamples(t *testing.T) { + t.Parallel() + + t.Run("hook ordering", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create hooks with specific order + hook1 := `#!/bin/bash +export CRUSH_MESSAGE="first" +` + hook2 := `#!/bin/bash +export CRUSH_MESSAGE="${CRUSH_MESSAGE:-}; second" +` + hook3 := `#!/bin/bash +export CRUSH_MESSAGE="${CRUSH_MESSAGE:-}; third" +` + + require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "01-first.sh"), []byte(hook1), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "02-second.sh"), []byte(hook2), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(hooksDir, "99-third.sh"), []byte(hook3), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + // Messages should be merged in order + assert.Contains(t, result.Message, "first") + assert.Contains(t, result.Message, "second") + assert.Contains(t, result.Message, "third") + }) + + t.Run("mixed env vars and JSON", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + hooksDir := filepath.Join(tempDir, ".crush", "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hookScript := `#!/bin/bash +# Set via environment variable +export CRUSH_PERMISSION=approve + +# Output via JSON +echo '{"message": "Combined output", "modified_input": {"key": "value"}}' +` + hookPath := filepath.Join(hooksDir, "01-mixed.sh") + require.NoError(t, os.WriteFile(hookPath, []byte(hookScript), 0o755)) + + manager := NewManager(tempDir, filepath.Join(tempDir, ".crush"), nil) + + result, err := manager.ExecuteHooks(context.Background(), HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "Combined output", result.Message) + assert.Equal(t, "value", result.ModifiedInput["key"]) + }) +} diff --git a/internal/hooks/executor.go b/internal/hooks/executor.go new file mode 100644 index 0000000000000000000000000000000000000000..c037370138e4e3b439e2f607ed0777748247dc9c --- /dev/null +++ b/internal/hooks/executor.go @@ -0,0 +1,101 @@ +package hooks + +import ( + "context" + _ "embed" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/charmbracelet/crush/internal/shell" +) + +//go:embed helpers.sh +var helpersScript string + +// Executor executes individual hook scripts. +type Executor struct { + workingDir string +} + +// NewExecutor creates a new hook executor. +func NewExecutor(workingDir string) *Executor { + return &Executor{workingDir: workingDir} +} + +// Execute runs a single hook script and returns the result. +func (e *Executor) Execute(ctx context.Context, hookPath string, context HookContext) (*HookResult, error) { + hookScript, err := os.ReadFile(hookPath) + if err != nil { + return nil, fmt.Errorf("failed to read hook: %w", err) + } + + contextJSON, err := json.Marshal(context.Data) + if err != nil { + return nil, fmt.Errorf("failed to marshal context: %w", err) + } + + // Wrap user hook in a function and prepend helper functions + // Read stdin before calling the function, then export it + fullScript := fmt.Sprintf(`%s + +# Save stdin to variable before entering function +_CRUSH_STDIN=$(cat) +export _CRUSH_STDIN + +_crush_hook_main() { +%s +} + +_crush_hook_main +`, helpersScript, string(hookScript)) + + env := append(os.Environ(), + "CRUSH_HOOK_TYPE="+string(context.HookType), + "CRUSH_SESSION_ID="+context.SessionID, + "CRUSH_WORKING_DIR="+context.WorkingDir, + ) + + if context.ToolName != "" { + env = append(env, + "CRUSH_TOOL_NAME="+context.ToolName, + "CRUSH_TOOL_CALL_ID="+context.ToolCallID, + ) + } + + for k, v := range context.Environment { + env = append(env, k+"="+v) + } + + hookShell := shell.NewShell(&shell.Options{ + WorkingDir: context.WorkingDir, + Env: env, + }) + + // Pass JSON context via stdin instead of heredoc + stdin := strings.NewReader(string(contextJSON)) + stdout, stderr, err := hookShell.ExecWithStdin(ctx, fullScript, stdin) + + result := parseShellEnv(hookShell.GetEnv()) + exitCode := shell.ExitCode(err) + switch exitCode { + case 2: + result.Continue = false + case 1: + return nil, fmt.Errorf("hook failed with exit code 1: %w\nstderr: %s", err, stderr) + } + + if trimmed := strings.TrimSpace(stdout); len(trimmed) > 0 && trimmed[0] == '{' { + if jsonResult, parseErr := parseJSONResult([]byte(trimmed)); parseErr == nil { + mergeJSONResult(result, jsonResult) + } + } + + return result, nil +} + +// GetHelpersScript returns the embedded helper script for display. +func GetHelpersScript() string { + return helpersScript +} diff --git a/internal/hooks/executor_test.go b/internal/hooks/executor_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7b9aa5e8c11a461aa9de2302c81523b12098112d --- /dev/null +++ b/internal/hooks/executor_test.go @@ -0,0 +1,395 @@ +package hooks + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecutor(t *testing.T) { + // Create temp directory for test hooks. + tempDir := t.TempDir() + + t.Run("executes simple hook with env vars", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "test-hook.sh") + hookScript := `#!/bin/bash +export CRUSH_PERMISSION=approve +export CRUSH_MESSAGE="test message" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "ls", + }, + }, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "test message", result.Message) + }) + + t.Run("helper functions are available", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "helper-test.sh") + hookScript := `#!/bin/bash +crush_approve "auto approved" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "auto approved", result.Message) + }) + + t.Run("crush_deny sets continue=false and exits", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "deny-test.sh") + hookScript := `#!/bin/bash +crush_deny "blocked" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.False(t, result.Continue) + assert.Equal(t, "deny", result.Permission) + assert.Equal(t, "blocked", result.Message) + }) + + t.Run("reads JSON from stdin", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "stdin-test.sh") + hookScript := `#!/bin/bash +COMMAND=$(crush_get_tool_input command) +if [ "$COMMAND" = "dangerous" ]; then + crush_deny "dangerous command" +fi +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{ + "tool_input": map[string]any{ + "command": "dangerous", + }, + }, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.False(t, result.Continue) + assert.Equal(t, "deny", result.Permission) + }) + + t.Run("env variables are set correctly", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "env-test.sh") + hookScript := `#!/bin/bash +if [ "$CRUSH_HOOK_TYPE" = "pre-tool-use" ] && \ + [ "$CRUSH_SESSION_ID" = "test-123" ] && \ + [ "$CRUSH_TOOL_NAME" = "bash" ]; then + export CRUSH_MESSAGE="env vars correct" +fi +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-123", + WorkingDir: tempDir, + ToolName: "bash", + ToolCallID: "call-123", + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.Equal(t, "env vars correct", result.Message) + }) + + t.Run("supports JSON output for complex mutations", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "json-test.sh") + hookScript := `#!/bin/bash +cat <&2 +exit 1 +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + _, err = executor.Execute(ctx, hookPath, hookCtx) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "hook failed with exit code 1") + }) + + t.Run("context files helper", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "files-test.sh") + hookScript := `#!/bin/bash +crush_add_context_file "file1.md" +crush_add_context_file "file2.txt" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookUserPromptSubmit, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.Equal(t, []string{"file1.md", "file2.txt"}, result.ContextFiles) + }) + + t.Run("context content helper", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "content-test.sh") + hookScript := `#!/bin/bash +crush_add_context "This is additional context" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookUserPromptSubmit, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.Equal(t, "This is additional context", result.ContextContent) + }) + + t.Run("returns error if hook file doesn't exist", func(t *testing.T) { + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + _, err := executor.Execute(ctx, "/nonexistent/hook.sh", hookCtx) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to read hook") + }) + + t.Run("passes custom environment variables", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "custom-env-test.sh") + hookScript := `#!/bin/bash +if [ "$CUSTOM_API_KEY" = "secret123" ] && [ "$CUSTOM_REGION" = "us-west-2" ]; then + export CRUSH_MESSAGE="custom env vars set correctly" +fi +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + Environment: map[string]string{ + "CUSTOM_API_KEY": "secret123", + "CUSTOM_REGION": "us-west-2", + }, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + assert.Equal(t, "custom env vars set correctly", result.Message) + }) + + t.Run("modify input helper function", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "modify-input-test.sh") + hookScript := `#!/bin/bash +crush_modify_input "command" "ls -la" +crush_modify_input "working_dir" "/tmp" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + require.NotNil(t, result.ModifiedInput) + assert.Equal(t, "ls -la", result.ModifiedInput["command"]) + assert.Equal(t, "/tmp", result.ModifiedInput["working_dir"]) + }) + + t.Run("modify output helper function", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "modify-output-test.sh") + hookScript := `#!/bin/bash +crush_modify_output "status" "redacted" +crush_modify_output "data" "[REDACTED]" +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPostToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + require.NotNil(t, result.ModifiedOutput) + assert.Equal(t, "redacted", result.ModifiedOutput["status"]) + assert.Equal(t, "[REDACTED]", result.ModifiedOutput["data"]) + }) + + t.Run("modify input with JSON types", func(t *testing.T) { + hookPath := filepath.Join(tempDir, "modify-input-json-test.sh") + hookScript := `#!/bin/bash +crush_modify_input "offset" "100" +crush_modify_input "limit" "50" +crush_modify_input "run_in_background" "true" +crush_modify_input "ignore" '["*.log","*.tmp"]' +` + err := os.WriteFile(hookPath, []byte(hookScript), 0o755) + require.NoError(t, err) + + executor := NewExecutor(tempDir) + ctx := context.Background() + hookCtx := HookContext{ + HookType: HookPreToolUse, + SessionID: "test-session", + WorkingDir: tempDir, + Data: map[string]any{}, + } + + result, err := executor.Execute(ctx, hookPath, hookCtx) + + require.NoError(t, err) + require.NotNil(t, result.ModifiedInput) + assert.Equal(t, float64(100), result.ModifiedInput["offset"]) + assert.Equal(t, float64(50), result.ModifiedInput["limit"]) + assert.Equal(t, true, result.ModifiedInput["run_in_background"]) + assert.Equal(t, []any{"*.log", "*.tmp"}, result.ModifiedInput["ignore"]) + }) +} + +func TestGetHelpersScript(t *testing.T) { + script := GetHelpersScript() + + assert.NotEmpty(t, script) + assert.Contains(t, script, "crush_approve") + assert.Contains(t, script, "crush_deny") + assert.Contains(t, script, "crush_add_context") +} diff --git a/internal/hooks/helpers.sh b/internal/hooks/helpers.sh new file mode 100644 index 0000000000000000000000000000000000000000..b19749628e42e1310a8dfe5cfcd8843c82c257e0 --- /dev/null +++ b/internal/hooks/helpers.sh @@ -0,0 +1,121 @@ +#!/bin/bash +# Crush Hook Helper Functions +# These functions are automatically available in all hooks. +# No need to source this file - it's prepended automatically. + +# Permission helpers + +# Approve the current tool call. +# Usage: crush_approve ["message"] +crush_approve() { + export CRUSH_PERMISSION=approve + [ -n "$1" ] && export CRUSH_MESSAGE="$1" +} + +# Deny the current tool call. +# Usage: crush_deny ["message"] +crush_deny() { + export CRUSH_PERMISSION=deny + export CRUSH_CONTINUE=false + [ -n "$1" ] && export CRUSH_MESSAGE="$1" + exit 2 +} + +# Ask user for permission (default behavior). +# Usage: crush_ask ["message"] +crush_ask() { + export CRUSH_PERMISSION=ask + [ -n "$1" ] && export CRUSH_MESSAGE="$1" +} + +# Context helpers + +# Add raw text content to LLM context. +# Usage: crush_add_context "content" +crush_add_context() { + export CRUSH_CONTEXT_CONTENT="$1" +} + +# Add a file to be loaded into LLM context. +# Usage: crush_add_context_file "/path/to/file.md" +crush_add_context_file() { + if [ -z "$CRUSH_CONTEXT_FILES" ]; then + export CRUSH_CONTEXT_FILES="$1" + else + export CRUSH_CONTEXT_FILES="$CRUSH_CONTEXT_FILES:$1" + fi +} + +# Modification helpers + +# Modify the user prompt (UserPromptSubmit hooks only). +# Usage: crush_modify_prompt "new prompt text" +crush_modify_prompt() { + export CRUSH_MODIFIED_PROMPT="$1" +} + +# Modify tool input parameters (PreToolUse hooks only). +# Values are parsed as JSON when valid, supporting strings, numbers, booleans, arrays, objects. +# Usage: crush_modify_input "param_name" "value" +# Examples: +# crush_modify_input "command" "ls -la" +# crush_modify_input "offset" "100" +# crush_modify_input "run_in_background" "true" +# crush_modify_input "ignore" '["*.log","*.tmp"]' +crush_modify_input() { + local key="$1" + local value="$2" + if [ -z "$CRUSH_MODIFIED_INPUT" ]; then + export CRUSH_MODIFIED_INPUT="$key=$value" + else + export CRUSH_MODIFIED_INPUT="$CRUSH_MODIFIED_INPUT:$key=$value" + fi +} + +# Modify tool output (PostToolUse hooks only). +# Usage: crush_modify_output "field_name" "value" +crush_modify_output() { + local key="$1" + local value="$2" + if [ -z "$CRUSH_MODIFIED_OUTPUT" ]; then + export CRUSH_MODIFIED_OUTPUT="$key=$value" + else + export CRUSH_MODIFIED_OUTPUT="$CRUSH_MODIFIED_OUTPUT:$key=$value" + fi +} + +# Stop execution. +# Usage: crush_stop ["message"] +crush_stop() { + export CRUSH_CONTINUE=false + [ -n "$1" ] && export CRUSH_MESSAGE="$1" + exit 1 +} + +# Input parsing helpers +# These read from the JSON context saved in _CRUSH_STDIN + +# Get a field from the hook context. +# Usage: VALUE=$(crush_get_input "field_name") +crush_get_input() { + echo "$_CRUSH_STDIN" | jq -r ".$1 // empty" +} + +# Get a tool input parameter. +# Usage: COMMAND=$(crush_get_tool_input "command") +crush_get_tool_input() { + echo "$_CRUSH_STDIN" | jq -r ".tool_input.$1 // empty" +} + +# Get the user prompt. +# Usage: PROMPT=$(crush_get_prompt) +crush_get_prompt() { + echo "$_CRUSH_STDIN" | jq -r '.prompt // empty' +} + +# Logging helper. +# Writes to stderr which is captured by Crush. +# Usage: crush_log "debug message" +crush_log() { + echo "[CRUSH HOOK] $*" >&2 +} diff --git a/internal/hooks/manager.go b/internal/hooks/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..b17753647f184b748ab2769c4a38cf06baf516c2 --- /dev/null +++ b/internal/hooks/manager.go @@ -0,0 +1,285 @@ +package hooks + +import ( + "context" + "fmt" + "log/slog" + "maps" + "os" + "path/filepath" + "runtime" + "slices" + "sort" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/csync" +) + +type manager struct { + workingDir string + dataDir string + config *Config + executor *Executor + hooks *csync.Map[HookType, []string] +} + +// NewManager creates a new hook manager. +func NewManager(workingDir, dataDir string, cfg *Config) Manager { + if cfg == nil { + cfg = &Config{ + Enabled: true, + TimeoutSeconds: 30, + Directories: []string{filepath.Join(dataDir, "hooks")}, + } + } + + // Ensure default directory if not specified. + if len(cfg.Directories) == 0 { + cfg.Directories = []string{filepath.Join(dataDir, "hooks")} + } + + return &manager{ + workingDir: workingDir, + dataDir: dataDir, + config: cfg, + executor: NewExecutor(workingDir), + hooks: csync.NewMap[HookType, []string](), + } +} + +// isExecutable checks if a file is executable. +// On Unix: checks execute permission bits for .sh files. +// On Windows: only recognizes .sh extension (as we use POSIX shell emulator). +func isExecutable(info os.FileInfo) bool { + name := strings.ToLower(info.Name()) + if !strings.HasSuffix(name, ".sh") { + return false + } + + if runtime.GOOS == "windows" { + return true + } + return info.Mode()&0o111 != 0 +} + +// ExecuteHooks implements Manager. +func (m *manager) ExecuteHooks(ctx context.Context, hookType HookType, hookContext HookContext) (HookResult, error) { + if !m.config.Enabled { + return HookResult{Continue: true}, nil + } + + hookContext.HookType = hookType + hookContext.Environment = m.config.Environment + + hooks := m.discoverHooks(hookType) + if len(hooks) == 0 { + return HookResult{Continue: true}, nil + } + + slog.Debug("Executing hooks", "type", hookType, "count", len(hooks)) + + accumulated := HookResult{Continue: true} + for _, hookPath := range hooks { + if m.isDisabled(hookPath) { + slog.Debug("Skipping disabled hook", "path", hookPath) + continue + } + + hookCtx, cancel := context.WithTimeout(ctx, time.Duration(m.config.TimeoutSeconds)*time.Second) + + result, err := m.executor.Execute(hookCtx, hookPath, hookContext) + cancel() + + if err != nil { + slog.Error("Hook execution failed", "path", hookPath, "error", err) + if hookType == HookPreToolUse { + accumulated.Continue = false + accumulated.Permission = "deny" + accumulated.Message = fmt.Sprintf("Hook failed: %v", err) + return accumulated, nil + } + continue + } + + if result.Message != "" { + slog.Info("Hook message", "path", hookPath, "message", result.Message) + } + + m.mergeResults(&accumulated, result) + + if !result.Continue { + slog.Info("Hook stopped execution", "path", hookPath) + break + } + } + + return accumulated, nil +} + +// discoverHooks finds all executable hooks for the given type. +func (m *manager) discoverHooks(hookType HookType) []string { + if cached, ok := m.hooks.Get(hookType); ok { + return cached + } + + var hooks []string + + for _, dir := range m.config.Directories { + if _, err := os.Stat(dir); err == nil { + entries, err := os.ReadDir(dir) + if err == nil { + for _, entry := range entries { + if entry.IsDir() { + continue + } + + hookPath := filepath.Join(dir, entry.Name()) + + info, err := entry.Info() + if err != nil { + continue + } + + if !isExecutable(info) { + continue + } + + hooks = append(hooks, hookPath) + slog.Debug("Discovered catch-all hook", "path", hookPath, "type", hookType) + } + } + } + + hookDir := filepath.Join(dir, string(hookType)) + if _, err := os.Stat(hookDir); os.IsNotExist(err) { + continue + } + + entries, err := os.ReadDir(hookDir) + if err != nil { + slog.Error("Failed to read hooks directory", "dir", hookDir, "error", err) + continue + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + hookPath := filepath.Join(hookDir, entry.Name()) + + info, err := entry.Info() + if err != nil { + continue + } + + if !isExecutable(info) { + slog.Debug("Skipping non-executable hook", "path", hookPath) + continue + } + + hooks = append(hooks, hookPath) + } + } + + if inlineHooks, ok := m.config.Inline[string(hookType)]; ok { + for _, inline := range inlineHooks { + hookPath, err := m.writeInlineHook(hookType, inline) + if err != nil { + slog.Error("Failed to write inline hook", "name", inline.Name, "error", err) + continue + } + hooks = append(hooks, hookPath) + } + } + + sort.Strings(hooks) + m.hooks.Set(hookType, hooks) + return hooks +} + +// writeInlineHook writes an inline hook script to a temp file. +func (m *manager) writeInlineHook(hookType HookType, inline InlineHook) (string, error) { + tempDir := filepath.Join(m.dataDir, "hooks", ".inline", string(hookType)) + if err := os.MkdirAll(tempDir, 0o755); err != nil { + return "", err + } + + hookPath := filepath.Join(tempDir, inline.Name) + if err := os.WriteFile(hookPath, []byte(inline.Script), 0o755); err != nil { + return "", err + } + + return hookPath, nil +} + +// isDisabled checks if a hook is in the disabled list. +func (m *manager) isDisabled(hookPath string) bool { + for _, dir := range m.config.Directories { + if rel, err := filepath.Rel(dir, hookPath); err == nil { + // Normalize to forward slashes for cross-platform comparison + rel = filepath.ToSlash(rel) + if slices.Contains(m.config.Disabled, rel) { + return true + } + } + } + return false +} + +// mergeResults merges a new result into the accumulated result. +func (m *manager) mergeResults(accumulated *HookResult, new *HookResult) { + accumulated.Continue = accumulated.Continue && new.Continue + + if new.Permission != "" { + if new.Permission == "deny" { + accumulated.Permission = "deny" + } else if new.Permission == "ask" && accumulated.Permission != "deny" { + accumulated.Permission = "ask" + } else if new.Permission == "approve" && accumulated.Permission == "" { + accumulated.Permission = "approve" + } + } + + if new.ModifiedPrompt != nil { + accumulated.ModifiedPrompt = new.ModifiedPrompt + } + + if len(new.ModifiedInput) > 0 { + if accumulated.ModifiedInput == nil { + accumulated.ModifiedInput = make(map[string]any) + } + maps.Copy(accumulated.ModifiedInput, new.ModifiedInput) + } + + if len(new.ModifiedOutput) > 0 { + if accumulated.ModifiedOutput == nil { + accumulated.ModifiedOutput = make(map[string]any) + } + maps.Copy(accumulated.ModifiedOutput, new.ModifiedOutput) + } + + if new.ContextContent != "" { + if accumulated.ContextContent == "" { + accumulated.ContextContent = new.ContextContent + } else { + accumulated.ContextContent += "\n\n" + new.ContextContent + } + } + + accumulated.ContextFiles = append(accumulated.ContextFiles, new.ContextFiles...) + + if new.Message != "" { + if accumulated.Message == "" { + accumulated.Message = new.Message + } else { + accumulated.Message += "; " + new.Message + } + } +} + +// ListHooks implements Manager. +func (m *manager) ListHooks(hookType HookType) []string { + return m.discoverHooks(hookType) +} diff --git a/internal/hooks/manager_test.go b/internal/hooks/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e8ec12ea9d782408ff3f4e2ad25da4431bb56fef --- /dev/null +++ b/internal/hooks/manager_test.go @@ -0,0 +1,524 @@ +package hooks + +import ( + "context" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestManager(t *testing.T) { + t.Run("discovers hooks in order", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create hooks with numeric prefixes. + hooks := []string{"02-second.sh", "01-first.sh", "03-third.sh"} + for _, name := range hooks { + path := filepath.Join(hooksDir, name) + err := os.WriteFile(path, []byte("#!/bin/bash\necho test"), 0o755) + require.NoError(t, err) + } + + mgr := NewManager(tempDir, dataDir, nil) + discovered := mgr.ListHooks(HookPreToolUse) + + assert.Len(t, discovered, 3) + // Should be sorted alphabetically. + assert.Contains(t, discovered[0], "01-first.sh") + assert.Contains(t, discovered[1], "02-second.sh") + assert.Contains(t, discovered[2], "03-third.sh") + }) + + t.Run("skips non-executable files", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create non-executable file. + path := filepath.Join(hooksDir, "non-executable.sh") + err := os.WriteFile(path, []byte("#!/bin/bash\necho test"), 0o644) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + discovered := mgr.ListHooks(HookPreToolUse) + + // On Windows, .sh files are always considered executable + // On Unix, non-executable files (0o644) should be skipped + if runtime.GOOS == "windows" { + assert.Len(t, discovered, 1, "On Windows, .sh files are executable regardless of permissions") + } else { + assert.Len(t, discovered, 0, "On Unix, non-executable files should be skipped") + } + }) + + t.Run("discovers hooks by extension on all platforms", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Only .sh files are recognized as hooks. + // On Unix, they need execute permission. On Windows, extension is enough. + validHook := filepath.Join(hooksDir, "valid-hook.sh") + err := os.WriteFile(validHook, []byte("#!/bin/bash\necho test"), 0o755) + require.NoError(t, err) + + // These should NOT be discovered (wrong extensions). + invalidFiles := []string{"hook.bat", "hook.cmd", "hook.ps1", "hook.txt"} + for _, name := range invalidFiles { + path := filepath.Join(hooksDir, name) + err := os.WriteFile(path, []byte("echo test"), 0o755) + require.NoError(t, err) + } + + mgr := NewManager(tempDir, dataDir, nil) + discovered := mgr.ListHooks(HookPreToolUse) + + // Only the .sh file should be discovered. + assert.Len(t, discovered, 1) + assert.Contains(t, discovered[0], "valid-hook.sh") + }) + + t.Run("executes multiple hooks and merges results", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "user-prompt-submit") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook 1: Adds context. + hook1 := filepath.Join(hooksDir, "01-add-context.sh") + err := os.WriteFile(hook1, []byte(`#!/bin/bash +crush_add_context "Context from hook 1" +`), 0o755) + require.NoError(t, err) + + // Hook 2: Adds more context. + hook2 := filepath.Join(hooksDir, "02-add-more.sh") + err = os.WriteFile(hook2, []byte(`#!/bin/bash +crush_add_context "Context from hook 2" +`), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{ + "prompt": "test prompt", + }, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + // Contexts should be merged with \n\n separator. + assert.Equal(t, "Context from hook 1\n\nContext from hook 2", result.ContextContent) + }) + + t.Run("stops on first hook that sets continue=false", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook 1: Denies. + hook1 := filepath.Join(hooksDir, "01-deny.sh") + err := os.WriteFile(hook1, []byte(`#!/bin/bash +crush_deny "blocked" +`), 0o755) + require.NoError(t, err) + + // Hook 2: Should not execute. + hook2 := filepath.Join(hooksDir, "02-never-runs.sh") + err = os.WriteFile(hook2, []byte(`#!/bin/bash +export CRUSH_MESSAGE="should not see this" +`), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.False(t, result.Continue) + assert.Equal(t, "deny", result.Permission) + assert.Equal(t, "blocked", result.Message) + assert.NotContains(t, result.Message, "should not see this") + }) + + t.Run("merges permissions with deny winning", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook 1: Approves. + hook1 := filepath.Join(hooksDir, "01-approve.sh") + err := os.WriteFile(hook1, []byte(`#!/bin/bash +export CRUSH_PERMISSION=approve +`), 0o755) + require.NoError(t, err) + + // Hook 2: Denies (should win). + hook2 := filepath.Join(hooksDir, "02-deny.sh") + err = os.WriteFile(hook2, []byte(`#!/bin/bash +export CRUSH_PERMISSION=deny +`), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "deny", result.Permission) + }) + + t.Run("disabled hooks are skipped", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook 1: Should run. + hook1 := filepath.Join(hooksDir, "01-enabled.sh") + err := os.WriteFile(hook1, []byte(`#!/bin/bash +export CRUSH_MESSAGE="enabled" +`), 0o755) + require.NoError(t, err) + + // Hook 2: Disabled. + hook2 := filepath.Join(hooksDir, "02-disabled.sh") + err = os.WriteFile(hook2, []byte(`#!/bin/bash +export CRUSH_MESSAGE="disabled" +`), 0o755) + require.NoError(t, err) + + cfg := &Config{ + Enabled: true, + TimeoutSeconds: 30, + Directories: []string{filepath.Join(dataDir, "hooks")}, + Disabled: []string{"pre-tool-use/02-disabled.sh"}, + } + + mgr := NewManager(tempDir, dataDir, cfg) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "enabled", result.Message) + }) + + t.Run("inline hooks are executed", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + + cfg := &Config{ + Enabled: true, + TimeoutSeconds: 30, + Directories: []string{filepath.Join(dataDir, "hooks")}, + Inline: map[string][]InlineHook{ + "user-prompt-submit": { + { + Name: "inline-test.sh", + Script: `#!/bin/bash +export CRUSH_MESSAGE="inline hook executed" +`, + }, + }, + }, + } + + mgr := NewManager(tempDir, dataDir, cfg) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "inline hook executed", result.Message) + }) + + t.Run("returns empty result when hooks disabled", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + + cfg := &Config{ + Enabled: false, + } + + mgr := NewManager(tempDir, dataDir, cfg) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + assert.Empty(t, result.Message) + }) + + t.Run("returns empty result when no hooks found", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) + }) + + t.Run("handles hook failure on PreToolUse by denying", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook that fails with exit 1. + hook := filepath.Join(hooksDir, "01-fail.sh") + err := os.WriteFile(hook, []byte(`#!/bin/bash +exit 1 +`), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.False(t, result.Continue) + assert.Equal(t, "deny", result.Permission) + assert.Contains(t, result.Message, "Hook failed") + }) + + t.Run("caches discovered hooks", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + hook := filepath.Join(hooksDir, "01-test.sh") + err := os.WriteFile(hook, []byte("#!/bin/bash\necho test"), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + + // First call - discovers. + hooks1 := mgr.ListHooks(HookPreToolUse) + assert.Len(t, hooks1, 1) + + // Second call - should use cache. + hooks2 := mgr.ListHooks(HookPreToolUse) + assert.Equal(t, hooks1, hooks2) + }) + + t.Run("catch-all hooks at root execute for all types", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create catch-all hook at root level. + catchAllHook := filepath.Join(hooksDir, "00-catch-all.sh") + err := os.WriteFile(catchAllHook, []byte(`#!/bin/bash +export CRUSH_MESSAGE="catch-all: $CRUSH_HOOK_TYPE" +`), 0o755) + require.NoError(t, err) + + // Create specific hook for pre-tool-use. + specificDir := filepath.Join(hooksDir, "pre-tool-use") + require.NoError(t, os.MkdirAll(specificDir, 0o755)) + specificHook := filepath.Join(specificDir, "01-specific.sh") + err = os.WriteFile(specificHook, []byte(`#!/bin/bash +export CRUSH_MESSAGE="$CRUSH_MESSAGE; specific hook" +`), 0o755) + require.NoError(t, err) + + mgr := NewManager(tempDir, dataDir, nil) + + // Test PreToolUse - should execute both catch-all and specific. + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Contains(t, result.Message, "catch-all: pre-tool-use") + assert.Contains(t, result.Message, "specific hook") + + // Test UserPromptSubmit - should only execute catch-all. + result2, err := mgr.ExecuteHooks(ctx, HookUserPromptSubmit, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "catch-all: user-prompt-submit", result2.Message) + assert.NotContains(t, result2.Message, "specific hook") + }) + + t.Run("passes environment variables from config to hooks", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Hook that checks for custom environment variables. + hook := filepath.Join(hooksDir, "01-check-env.sh") + err := os.WriteFile(hook, []byte(`#!/bin/bash +if [ "$CUSTOM_API_KEY" = "test-key-123" ] && [ "$CUSTOM_ENV" = "production" ]; then + export CRUSH_MESSAGE="config environment variables received" +else + export CRUSH_MESSAGE="environment variables missing" +fi +`), 0o755) + require.NoError(t, err) + + cfg := &Config{ + Enabled: true, + TimeoutSeconds: 30, + Directories: []string{filepath.Join(dataDir, "hooks")}, + Environment: map[string]string{ + "CUSTOM_API_KEY": "test-key-123", + "CUSTOM_ENV": "production", + }, + } + + mgr := NewManager(tempDir, dataDir, cfg) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "config environment variables received", result.Message) + }) + + t.Run("handles inline hook write failure gracefully", func(t *testing.T) { + tempDir := t.TempDir() + // Use a read-only directory as dataDir to force write failure. + readOnlyDir := filepath.Join(tempDir, "readonly") + require.NoError(t, os.MkdirAll(readOnlyDir, 0o555)) // Read-only + + cfg := &Config{ + Enabled: true, + TimeoutSeconds: 30, + Directories: []string{filepath.Join(readOnlyDir, "hooks")}, + Inline: map[string][]InlineHook{ + "pre-tool-use": { + { + Name: "inline-fail.sh", + Script: "#!/bin/bash\necho test", + }, + }, + }, + } + + mgr := NewManager(tempDir, readOnlyDir, cfg) + ctx := context.Background() + + // Should not error even though inline hook write fails. + // The hook will be skipped and logged. + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.True(t, result.Continue) // Should continue despite write failure + }) + + t.Run("handles hooks directory read failure gracefully", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Create a hook file. + hook := filepath.Join(hooksDir, "01-test.sh") + require.NoError(t, os.WriteFile(hook, []byte("#!/bin/bash\necho test"), 0o755)) + + mgr := NewManager(tempDir, dataDir, nil) + + // Make directory unreadable after discovery to test error path. + // Note: This is tricky to test reliably cross-platform. + // On some systems, we can't make a directory unreadable if we own it. + // We'll test that hooks are cached and re-discovery works. + hooks1 := mgr.ListHooks(HookPreToolUse) + assert.Len(t, hooks1, 1) + + // Add another hook. + hook2 := filepath.Join(hooksDir, "02-test.sh") + require.NoError(t, os.WriteFile(hook2, []byte("#!/bin/bash\necho test2"), 0o755)) + + // Should still return cached hooks (won't see new one). + hooks2 := mgr.ListHooks(HookPreToolUse) + assert.Len(t, hooks2, 1, "hooks are cached, new hook not seen") + }) + + t.Run("approve permission is set when accumulated is empty", func(t *testing.T) { + tempDir := t.TempDir() + dataDir := filepath.Join(tempDir, ".crush") + hooksDir := filepath.Join(dataDir, "hooks", "pre-tool-use") + require.NoError(t, os.MkdirAll(hooksDir, 0o755)) + + // Single hook that approves. + hook := filepath.Join(hooksDir, "01-approve.sh") + require.NoError(t, os.WriteFile(hook, []byte(`#!/bin/bash +export CRUSH_PERMISSION=approve +export CRUSH_MESSAGE="auto-approved" +`), 0o755)) + + mgr := NewManager(tempDir, dataDir, nil) + ctx := context.Background() + result, err := mgr.ExecuteHooks(ctx, HookPreToolUse, HookContext{ + SessionID: "test", + WorkingDir: tempDir, + Data: map[string]any{}, + }) + + require.NoError(t, err) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "auto-approved", result.Message) + }) +} diff --git a/internal/hooks/parser.go b/internal/hooks/parser.go new file mode 100644 index 0000000000000000000000000000000000000000..3acbe1f99ad4b89a905a4cddabb0af9672af19c5 --- /dev/null +++ b/internal/hooks/parser.go @@ -0,0 +1,183 @@ +package hooks + +import ( + "encoding/base64" + "encoding/json" + "strings" +) + +// parseShellEnv parses hook results from environment variables. +func parseShellEnv(env []string) *HookResult { + result := &HookResult{Continue: true} + + for _, line := range env { + if !strings.HasPrefix(line, "CRUSH_") { + continue + } + + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + + switch key { + case "CRUSH_CONTINUE": + result.Continue = value == "true" + + case "CRUSH_PERMISSION": + result.Permission = value + + case "CRUSH_MESSAGE": + result.Message = value + + case "CRUSH_MODIFIED_PROMPT": + result.ModifiedPrompt = &value + + case "CRUSH_CONTEXT_CONTENT": + if decoded, err := base64.StdEncoding.DecodeString(value); err == nil { + result.ContextContent = string(decoded) + } else { + result.ContextContent = value + } + + case "CRUSH_CONTEXT_FILES": + if value != "" { + result.ContextFiles = strings.Split(value, ":") + } + + case "CRUSH_MODIFIED_INPUT": + if value != "" { + result.ModifiedInput = parseKeyValuePairs(value) + } + + case "CRUSH_MODIFIED_OUTPUT": + if value != "" { + result.ModifiedOutput = parseKeyValuePairs(value) + } + } + } + + return result +} + +// parseJSONResult parses hook results from JSON output. +func parseJSONResult(data []byte) (*HookResult, error) { + result := &HookResult{Continue: true} + + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + if v, ok := raw["continue"].(bool); ok { + result.Continue = v + } + + if v, ok := raw["permission"].(string); ok { + result.Permission = v + } + + if v, ok := raw["message"].(string); ok { + result.Message = v + } + + if v, ok := raw["modified_prompt"].(string); ok { + result.ModifiedPrompt = &v + } + + if v, ok := raw["modified_input"].(map[string]any); ok { + result.ModifiedInput = v + } + + if v, ok := raw["modified_output"].(map[string]any); ok { + result.ModifiedOutput = v + } + + if v, ok := raw["context_content"].(string); ok { + result.ContextContent = v + } + + if v, ok := raw["context_files"].([]any); ok { + for _, file := range v { + if s, ok := file.(string); ok { + result.ContextFiles = append(result.ContextFiles, s) + } + } + } + + return result, nil +} + +// mergeJSONResult merges JSON-parsed result into env-parsed result. +func mergeJSONResult(base *HookResult, jsonResult *HookResult) { + if !jsonResult.Continue { + base.Continue = false + } + + if jsonResult.Permission != "" { + base.Permission = jsonResult.Permission + } + + if jsonResult.Message != "" { + if base.Message == "" { + base.Message = jsonResult.Message + } else { + base.Message += "; " + jsonResult.Message + } + } + + if jsonResult.ModifiedPrompt != nil { + base.ModifiedPrompt = jsonResult.ModifiedPrompt + } + + if len(jsonResult.ModifiedInput) > 0 { + if base.ModifiedInput == nil { + base.ModifiedInput = make(map[string]any) + } + for k, v := range jsonResult.ModifiedInput { + base.ModifiedInput[k] = v + } + } + + if len(jsonResult.ModifiedOutput) > 0 { + if base.ModifiedOutput == nil { + base.ModifiedOutput = make(map[string]any) + } + for k, v := range jsonResult.ModifiedOutput { + base.ModifiedOutput[k] = v + } + } + + if jsonResult.ContextContent != "" { + if base.ContextContent == "" { + base.ContextContent = jsonResult.ContextContent + } else { + base.ContextContent += "\n\n" + jsonResult.ContextContent + } + } + + base.ContextFiles = append(base.ContextFiles, jsonResult.ContextFiles...) +} + +// parseKeyValuePairs parses "key=value:key2=value2" format into a map. +// Values are parsed as JSON when possible, otherwise treated as strings. +func parseKeyValuePairs(encoded string) map[string]any { + result := make(map[string]any) + pairs := strings.Split(encoded, ":") + for _, pair := range pairs { + key, value, ok := strings.Cut(pair, "=") + if !ok { + continue + } + + // Try to parse value as JSON to support numbers, booleans, arrays, objects + var jsonValue any + if err := json.Unmarshal([]byte(value), &jsonValue); err == nil { + result[key] = jsonValue + } else { + // Fall back to string if not valid JSON + result[key] = value + } + } + return result +} diff --git a/internal/hooks/parser_test.go b/internal/hooks/parser_test.go new file mode 100644 index 0000000000000000000000000000000000000000..11fb73e43d8d9fb29840acd6b05e5794dced58fa --- /dev/null +++ b/internal/hooks/parser_test.go @@ -0,0 +1,416 @@ +package hooks + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseShellEnv(t *testing.T) { + t.Run("parses basic fields", func(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "CRUSH_CONTINUE=false", + "CRUSH_PERMISSION=approve", + "CRUSH_MESSAGE=test message", + "HOME=/home/user", + } + + result := parseShellEnv(env) + + assert.False(t, result.Continue) + assert.Equal(t, "approve", result.Permission) + assert.Equal(t, "test message", result.Message) + }) + + t.Run("parses modified prompt", func(t *testing.T) { + env := []string{ + "CRUSH_MODIFIED_PROMPT=new prompt text", + } + + result := parseShellEnv(env) + + require.NotNil(t, result.ModifiedPrompt) + assert.Equal(t, "new prompt text", *result.ModifiedPrompt) + }) + + t.Run("parses context content", func(t *testing.T) { + env := []string{ + "CRUSH_CONTEXT_CONTENT=some context", + } + + result := parseShellEnv(env) + + assert.Equal(t, "some context", result.ContextContent) + }) + + t.Run("parses base64 context content", func(t *testing.T) { + text := "multiline\ncontext\nhere" + encoded := base64.StdEncoding.EncodeToString([]byte(text)) + + env := []string{ + "CRUSH_CONTEXT_CONTENT=" + encoded, + } + + result := parseShellEnv(env) + + assert.Equal(t, text, result.ContextContent) + }) + + t.Run("parses context files", func(t *testing.T) { + env := []string{ + "CRUSH_CONTEXT_FILES=file1.md:file2.txt:file3.go", + } + + result := parseShellEnv(env) + + assert.Equal(t, []string{"file1.md", "file2.txt", "file3.go"}, result.ContextFiles) + }) + + t.Run("defaults to continue=true", func(t *testing.T) { + env := []string{} + + result := parseShellEnv(env) + + assert.True(t, result.Continue) + }) + + t.Run("ignores non-CRUSH env vars", func(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "HOME=/home/user", + "CRUSH_MESSAGE=test", + } + + result := parseShellEnv(env) + + assert.Equal(t, "test", result.Message) + }) + + t.Run("falls back to raw value for invalid base64", func(t *testing.T) { + // Invalid base64 string should be used as-is. + env := []string{ + "CRUSH_CONTEXT_CONTENT=this is not base64!@#$", + } + + result := parseShellEnv(env) + + assert.Equal(t, "this is not base64!@#$", result.ContextContent) + }) + + t.Run("parses modified input", func(t *testing.T) { + env := []string{ + "CRUSH_MODIFIED_INPUT=command=ls -la:working_dir=/tmp", + } + + result := parseShellEnv(env) + + require.NotNil(t, result.ModifiedInput) + assert.Equal(t, "ls -la", result.ModifiedInput["command"]) + assert.Equal(t, "/tmp", result.ModifiedInput["working_dir"]) + }) + + t.Run("parses modified output", func(t *testing.T) { + env := []string{ + "CRUSH_MODIFIED_OUTPUT=status=redacted:data=[REDACTED]", + } + + result := parseShellEnv(env) + + require.NotNil(t, result.ModifiedOutput) + assert.Equal(t, "redacted", result.ModifiedOutput["status"]) + assert.Equal(t, "[REDACTED]", result.ModifiedOutput["data"]) + }) + + t.Run("parses modified input with JSON types", func(t *testing.T) { + env := []string{ + `CRUSH_MODIFIED_INPUT=offset=100:limit=50:run_in_background=true:ignore=["*.log","*.tmp"]`, + } + + result := parseShellEnv(env) + + require.NotNil(t, result.ModifiedInput) + assert.Equal(t, float64(100), result.ModifiedInput["offset"]) // JSON numbers are float64 + assert.Equal(t, float64(50), result.ModifiedInput["limit"]) + assert.Equal(t, true, result.ModifiedInput["run_in_background"]) + assert.Equal(t, []any{"*.log", "*.tmp"}, result.ModifiedInput["ignore"]) + }) + + t.Run("parses modified input with strings containing colons", func(t *testing.T) { + // Colons in file paths should work if the value doesn't contain '=' + env := []string{ + `CRUSH_MODIFIED_INPUT=path=/usr/local/bin:name=test`, + } + + result := parseShellEnv(env) + + require.NotNil(t, result.ModifiedInput) + // First pair: path=/usr/local/bin + // Second pair: name=test + // Note: This splits on first '=' in each pair + assert.Equal(t, "/usr/local/bin", result.ModifiedInput["path"]) + assert.Equal(t, "test", result.ModifiedInput["name"]) + }) +} + +func TestParseJSONResult(t *testing.T) { + t.Run("parses basic fields", func(t *testing.T) { + json := []byte(`{ + "continue": false, + "permission": "deny", + "message": "blocked" + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.False(t, result.Continue) + assert.Equal(t, "deny", result.Permission) + assert.Equal(t, "blocked", result.Message) + }) + + t.Run("parses modified_input", func(t *testing.T) { + json := []byte(`{ + "modified_input": { + "command": "ls -la", + "working_dir": "/tmp" + } + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.Equal(t, map[string]any{ + "command": "ls -la", + "working_dir": "/tmp", + }, result.ModifiedInput) + }) + + t.Run("parses modified_output", func(t *testing.T) { + json := []byte(`{ + "modified_output": { + "content": "filtered output" + } + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.Equal(t, map[string]any{ + "content": "filtered output", + }, result.ModifiedOutput) + }) + + t.Run("parses context_files array", func(t *testing.T) { + json := []byte(`{ + "context_files": ["file1.md", "file2.txt"] + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.Equal(t, []string{"file1.md", "file2.txt"}, result.ContextFiles) + }) + + t.Run("returns error on invalid JSON", func(t *testing.T) { + json := []byte(`{invalid}`) + + _, err := parseJSONResult(json) + + assert.Error(t, err) + }) + + t.Run("defaults to continue=true", func(t *testing.T) { + json := []byte(`{"message": "test"}`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.True(t, result.Continue) + }) + + t.Run("handles wrong type for modified_input", func(t *testing.T) { + // modified_input should be a map, but here it's a string. + json := []byte(`{ + "modified_input": "not a map" + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + // Should be nil/empty since type assertion failed. + assert.Nil(t, result.ModifiedInput) + }) + + t.Run("handles wrong type for modified_output", func(t *testing.T) { + // modified_output should be a map, but here it's an array. + json := []byte(`{ + "modified_output": ["not", "a", "map"] + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + assert.Nil(t, result.ModifiedOutput) + }) + + t.Run("handles non-string elements in context_files", func(t *testing.T) { + // context_files should be array of strings, but has numbers. + json := []byte(`{ + "context_files": ["file1.md", 123, "file2.md", null] + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + // Should only include valid strings. + assert.Equal(t, []string{"file1.md", "file2.md"}, result.ContextFiles) + }) + + t.Run("handles wrong type for context_files", func(t *testing.T) { + // context_files should be an array, but here it's a string. + json := []byte(`{ + "context_files": "not an array" + }`) + + result, err := parseJSONResult(json) + + require.NoError(t, err) + // Should be empty since type assertion failed. + assert.Empty(t, result.ContextFiles) + }) +} + +func TestMergeJSONResult(t *testing.T) { + t.Run("merges continue flag", func(t *testing.T) { + base := &HookResult{Continue: true} + json := &HookResult{Continue: false} + + mergeJSONResult(base, json) + + assert.False(t, base.Continue) + }) + + t.Run("merges permission", func(t *testing.T) { + base := &HookResult{} + json := &HookResult{Permission: "approve"} + + mergeJSONResult(base, json) + + assert.Equal(t, "approve", base.Permission) + }) + + t.Run("appends messages", func(t *testing.T) { + base := &HookResult{Message: "first"} + json := &HookResult{Message: "second"} + + mergeJSONResult(base, json) + + assert.Equal(t, "first; second", base.Message) + }) + + t.Run("merges modified_input maps", func(t *testing.T) { + base := &HookResult{ + ModifiedInput: map[string]any{ + "field1": "value1", + }, + } + json := &HookResult{ + ModifiedInput: map[string]any{ + "field2": "value2", + }, + } + + mergeJSONResult(base, json) + + assert.Equal(t, map[string]any{ + "field1": "value1", + "field2": "value2", + }, base.ModifiedInput) + }) + + t.Run("overwrites conflicting modified_input fields", func(t *testing.T) { + base := &HookResult{ + ModifiedInput: map[string]any{ + "field": "old", + }, + } + json := &HookResult{ + ModifiedInput: map[string]any{ + "field": "new", + }, + } + + mergeJSONResult(base, json) + + assert.Equal(t, "new", base.ModifiedInput["field"]) + }) + + t.Run("appends context content", func(t *testing.T) { + base := &HookResult{ContextContent: "first"} + json := &HookResult{ContextContent: "second"} + + mergeJSONResult(base, json) + + assert.Equal(t, "first\n\nsecond", base.ContextContent) + }) + + t.Run("appends context files", func(t *testing.T) { + base := &HookResult{ContextFiles: []string{"file1.md"}} + json := &HookResult{ContextFiles: []string{"file2.md", "file3.md"}} + + mergeJSONResult(base, json) + + assert.Equal(t, []string{"file1.md", "file2.md", "file3.md"}, base.ContextFiles) + }) + + t.Run("initializes ModifiedInput when nil", func(t *testing.T) { + // Base has nil ModifiedInput. + base := &HookResult{} + json := &HookResult{ + ModifiedInput: map[string]any{ + "field": "value", + }, + } + + mergeJSONResult(base, json) + + require.NotNil(t, base.ModifiedInput) + assert.Equal(t, "value", base.ModifiedInput["field"]) + }) + + t.Run("initializes ModifiedOutput when nil", func(t *testing.T) { + // Base has nil ModifiedOutput. + base := &HookResult{} + json := &HookResult{ + ModifiedOutput: map[string]any{ + "filtered": true, + }, + } + + mergeJSONResult(base, json) + + require.NotNil(t, base.ModifiedOutput) + assert.Equal(t, true, base.ModifiedOutput["filtered"]) + }) + + t.Run("sets context content when base is empty", func(t *testing.T) { + base := &HookResult{} + json := &HookResult{ContextContent: "new content"} + + mergeJSONResult(base, json) + + assert.Equal(t, "new content", base.ContextContent) + }) + + t.Run("sets message when base is empty", func(t *testing.T) { + base := &HookResult{} + json := &HookResult{Message: "new message"} + + mergeJSONResult(base, json) + + assert.Equal(t, "new message", base.Message) + }) +} diff --git a/internal/hooks/types.go b/internal/hooks/types.go new file mode 100644 index 0000000000000000000000000000000000000000..5fc2b6196efd2fd1adcc73809def33a944d33a17 --- /dev/null +++ b/internal/hooks/types.go @@ -0,0 +1,93 @@ +// Package hooks provides a Git-like hooks system for Crush. +// +// Hooks are executable scripts that run at specific points in the application +// lifecycle. They can modify behavior, add context, control permissions, and +// audit activity. +package hooks + +import "context" + +// HookType represents the type of hook. +type HookType string + +const ( + // HookUserPromptSubmit executes after user submits prompt, before sending to LLM. + HookUserPromptSubmit HookType = "user-prompt-submit" + + // HookPreToolUse executes after LLM requests tool use, before permission check & execution. + HookPreToolUse HookType = "pre-tool-use" + + // HookPostToolUse executes after tool executes, before result sent to LLM. + HookPostToolUse HookType = "post-tool-use" + + // HookStop executes when agent conversation loop stops or is cancelled. + HookStop HookType = "stop" +) + +// HookContext contains the data passed to hooks. +type HookContext struct { + // HookType is the type of hook being executed. + HookType HookType + + // SessionID is the current session ID. + SessionID string + + // WorkingDir is the working directory. + WorkingDir string + + // Data is hook-specific data marshaled to JSON and passed via stdin. + // For UserPromptSubmit: prompt, attachments, model, is_first_message + // For PreToolUse: tool_name, tool_call_id, tool_input + // For PostToolUse: tool_name, tool_call_id, tool_input, tool_output, execution_time_ms + // For Stop: reason + Data map[string]any + + // ToolName is the tool name (for tool hooks only). + ToolName string + + // ToolCallID is the tool call ID (for tool hooks only). + ToolCallID string + + // Environment contains additional environment variables to pass to the hook. + Environment map[string]string +} + +// HookResult contains the result of hook execution. +type HookResult struct { + // Continue indicates whether to continue execution. + // If false, execution stops. + Continue bool + + // Permission decision (for PreToolUse hooks only). + // Values: "ask" (default), "approve", "deny" + Permission string + + // ModifiedPrompt is the modified user prompt (for UserPromptSubmit). + ModifiedPrompt *string + + // ModifiedInput is the modified tool input parameters (for PreToolUse). + // This is a map that can be merged with the original tool input. + ModifiedInput map[string]any + + // ModifiedOutput is the modified tool output (for PostToolUse). + ModifiedOutput map[string]any + + // ContextContent is raw text content to add to LLM context. + ContextContent string + + // ContextFiles is a list of file paths to load and add to LLM context. + ContextFiles []string + + // Message is a user-facing message (logged and potentially displayed). + Message string +} + +// Manager coordinates hook discovery and execution. +type Manager interface { + // ExecuteHooks executes all hooks for the given type in order. + // Returns accumulated results from all hooks. + ExecuteHooks(ctx context.Context, hookType HookType, context HookContext) (HookResult, error) + + // ListHooks returns all discovered hooks for a given type. + ListHooks(hookType HookType) []string +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index f9f4656b82bbb6ee14b38469a20d493d98354b4a..39ee77226177f7b0cfed56757654b32e484f4bfa 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -100,7 +100,15 @@ func (s *Shell) Exec(ctx context.Context, command string) (string, string, error s.mu.Lock() defer s.mu.Unlock() - return s.exec(ctx, command) + return s.exec(ctx, command, nil) +} + +// ExecWithStdin executes a command in the shell with provided stdin +func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin io.Reader) (string, string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.exec(ctx, command, stdin) } // ExecStream executes a command in the shell with streaming output to provided writers @@ -237,9 +245,9 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand } // newInterp creates a new interpreter with the current shell state -func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) { +func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) { return interp.New( - interp.StdIO(nil, stdout, stderr), + interp.StdIO(stdin, stdout, stderr), interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), @@ -257,13 +265,13 @@ func (s *Shell) updateShellFromRunner(runner *interp.Runner) { } // execCommon is the shared implementation for executing commands -func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) error { +func (s *Shell) execCommon(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error { line, err := syntax.NewParser().Parse(strings.NewReader(command), "") if err != nil { return fmt.Errorf("could not parse command: %w", err) } - runner, err := s.newInterp(stdout, stderr) + runner, err := s.newInterp(stdin, stdout, stderr) if err != nil { return fmt.Errorf("could not run command: %w", err) } @@ -275,15 +283,15 @@ func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr i } // exec executes commands using a cross-platform shell interpreter. -func (s *Shell) exec(ctx context.Context, command string) (string, string, error) { +func (s *Shell) exec(ctx context.Context, command string, stdin io.Reader) (string, string, error) { var stdout, stderr bytes.Buffer - err := s.execCommon(ctx, command, &stdout, &stderr) + err := s.execCommon(ctx, command, stdin, &stdout, &stderr) return stdout.String(), stderr.String(), err } // execStream executes commands using POSIX shell emulation with streaming output func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error { - return s.execCommon(ctx, command, stdout, stderr) + return s.execCommon(ctx, command, nil, stdout, stderr) } func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {