diff --git a/internal/hooks/builtins.go b/internal/hooks/builtins.go new file mode 100644 index 0000000000000000000000000000000000000000..7221d63aef87b8cc553321c95015346cd11f4f60 --- /dev/null +++ b/internal/hooks/builtins.go @@ -0,0 +1,157 @@ +package hooks + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "mvdan.cc/sh/v3/interp" +) + +// crushGetInput reads a field from the hook context JSON. +// Usage: VALUE=$(crush_get_input "field_name") +func crushGetInput(ctx context.Context, args []string) error { + hc := interp.HandlerCtx(ctx) + + if len(args) != 2 { + fmt.Fprintln(hc.Stderr, "Usage: crush_get_input ") + return interp.ExitStatus(1) + } + + fieldName := args[1] + stdin := hc.Env.Get("_CRUSH_STDIN").Str + + var data map[string]any + if err := json.Unmarshal([]byte(stdin), &data); err != nil { + fmt.Fprintf(hc.Stderr, "crush_get_input: failed to parse JSON: %v\n", err) + return interp.ExitStatus(1) + } + + if value, ok := data[fieldName]; ok && value != nil { + fmt.Fprint(hc.Stdout, formatJSONValue(value)) + } + + return nil +} + +// crushGetToolInput reads a tool input parameter from the hook context JSON. +// Usage: COMMAND=$(crush_get_tool_input "command") +func crushGetToolInput(ctx context.Context, args []string) error { + hc := interp.HandlerCtx(ctx) + + if len(args) != 2 { + fmt.Fprintln(hc.Stderr, "Usage: crush_get_tool_input ") + return interp.ExitStatus(1) + } + + paramName := args[1] + stdin := hc.Env.Get("_CRUSH_STDIN").Str + + var data map[string]any + if err := json.Unmarshal([]byte(stdin), &data); err != nil { + fmt.Fprintf(hc.Stderr, "crush_get_tool_input: failed to parse JSON: %v\n", err) + return interp.ExitStatus(1) + } + + toolInput, ok := data["tool_input"].(map[string]any) + if !ok { + return nil + } + + if value, ok := toolInput[paramName]; ok && value != nil { + fmt.Fprint(hc.Stdout, formatJSONValue(value)) + } + + return nil +} + +// crushGetPrompt reads the user prompt from the hook context JSON. +// Usage: PROMPT=$(crush_get_prompt) +func crushGetPrompt(ctx context.Context, args []string) error { + hc := interp.HandlerCtx(ctx) + + stdin := hc.Env.Get("_CRUSH_STDIN").Str + + var data map[string]any + if err := json.Unmarshal([]byte(stdin), &data); err != nil { + fmt.Fprintf(hc.Stderr, "crush_get_prompt: failed to parse JSON: %v\n", err) + return interp.ExitStatus(1) + } + + if prompt, ok := data["prompt"]; ok && prompt != nil { + fmt.Fprint(hc.Stdout, formatJSONValue(prompt)) + } + + return nil +} + +// crushLog writes a log message using slog.Debug. +// Usage: crush_log "debug message" +func crushLog(ctx context.Context, args []string) error { + if len(args) < 2 { + return nil + } + + slog.Debug(joinArgs(args[1:])) + return nil +} + +// formatJSONValue converts a JSON value to a string suitable for shell output. +func formatJSONValue(value any) string { + switch v := value.(type) { + case string: + return v + case float64: + // JSON numbers are float64 by default + if v == float64(int64(v)) { + return fmt.Sprintf("%d", int64(v)) + } + return fmt.Sprintf("%v", v) + case bool: + return fmt.Sprintf("%t", v) + case nil: + return "" + default: + // For complex types (arrays, objects), return JSON representation + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) + } +} + +// joinArgs joins arguments with spaces. +func joinArgs(args []string) string { + if len(args) == 0 { + return "" + } + result := args[0] + for _, arg := range args[1:] { + result += " " + arg + } + return result +} + +// RegisterBuiltins returns an ExecHandlerFunc that registers all Crush hook builtins. +func RegisterBuiltins(next interp.ExecHandlerFunc) interp.ExecHandlerFunc { + builtins := map[string]func(context.Context, []string) error{ + "crush_get_input": crushGetInput, + "crush_get_tool_input": crushGetToolInput, + "crush_get_prompt": crushGetPrompt, + "crush_log": crushLog, + } + + return func(ctx context.Context, args []string) error { + if len(args) == 0 { + return next(ctx, args) + } + + if fn, ok := builtins[args[0]]; ok { + return fn(ctx, args) + } + + return next(ctx, args) + } +} diff --git a/internal/hooks/builtins_test.go b/internal/hooks/builtins_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2a1bff29c093b9638ac7ac83537492c913376cad --- /dev/null +++ b/internal/hooks/builtins_test.go @@ -0,0 +1,185 @@ +package hooks + +import ( + "context" + "strings" + "testing" + + "github.com/charmbracelet/crush/internal/shell" + "github.com/stretchr/testify/require" + "mvdan.cc/sh/v3/interp" +) + +func TestBuiltinsIntegration(t *testing.T) { + t.Parallel() + + jsonInput := `{ + "prompt": "test prompt", + "tool_input": { + "command": "ls -la", + "offset": 100 + }, + "custom_field": "custom_value" + }` + + script := ` +PROMPT=$(crush_get_prompt) +COMMAND=$(crush_get_tool_input "command") +OFFSET=$(crush_get_tool_input "offset") +CUSTOM=$(crush_get_input "custom_field") + +echo "prompt=$PROMPT" +echo "command=$COMMAND" +echo "offset=$OFFSET" +echo "custom=$CUSTOM" + +crush_log "Processing complete" +` + + hookShell := shell.NewShell(&shell.Options{ + WorkingDir: t.TempDir(), + ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins}, + }) + + // Need to set _CRUSH_STDIN before running the script + stdin := strings.NewReader(jsonInput) + setupScript := ` +_CRUSH_STDIN=$(cat) +export _CRUSH_STDIN +` + script + + stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin) + + require.NoError(t, err) + require.Contains(t, stdout, "prompt=test prompt") + require.Contains(t, stdout, "command=ls -la") + require.Contains(t, stdout, "offset=100") + require.Contains(t, stdout, "custom=custom_value") +} + +func TestBuiltinErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + script string + stdin string + wantErr bool + }{ + { + name: "invalid json", + script: `crush_get_input "field"`, + stdin: `{invalid}`, + wantErr: true, + }, + { + name: "wrong number of args", + script: `crush_get_input`, + stdin: `{"field":"value"}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + hookShell := shell.NewShell(&shell.Options{ + WorkingDir: t.TempDir(), + ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins}, + }) + + setupScript := ` +_CRUSH_STDIN=$(cat) +export _CRUSH_STDIN +` + tt.script + + stdin := strings.NewReader(tt.stdin) + _, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBuiltinWithMissingFields(t *testing.T) { + t.Parallel() + + jsonInput := `{"prompt": "test"}` + + script := ` +MISSING=$(crush_get_input "missing_field") +TOOL_MISSING=$(crush_get_tool_input "missing_param") + +if [ -z "$MISSING" ]; then + echo "missing is empty" +fi + +if [ -z "$TOOL_MISSING" ]; then + echo "tool_missing is empty" +fi +` + + hookShell := shell.NewShell(&shell.Options{ + WorkingDir: t.TempDir(), + ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins}, + }) + + stdin := strings.NewReader(jsonInput) + setupScript := ` +_CRUSH_STDIN=$(cat) +export _CRUSH_STDIN +` + script + + stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin) + + require.NoError(t, err) + require.Contains(t, stdout, "missing is empty") + require.Contains(t, stdout, "tool_missing is empty") +} + +func TestBuiltinWithComplexTypes(t *testing.T) { + t.Parallel() + + jsonInput := `{ + "array_field": [1, 2, 3], + "object_field": {"key": "value"}, + "bool_field": true, + "null_field": null + }` + + script := ` +ARRAY=$(crush_get_input "array_field") +OBJECT=$(crush_get_input "object_field") +BOOL=$(crush_get_input "bool_field") +NULL=$(crush_get_input "null_field") + +echo "array=$ARRAY" +echo "object=$OBJECT" +echo "bool=$BOOL" +echo "null=$NULL" +` + + hookShell := shell.NewShell(&shell.Options{ + WorkingDir: t.TempDir(), + ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins}, + }) + + stdin := strings.NewReader(jsonInput) + setupScript := ` +_CRUSH_STDIN=$(cat) +export _CRUSH_STDIN +` + script + + stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin) + + require.NoError(t, err) + require.Contains(t, stdout, "array=[1,2,3]") + require.Contains(t, stdout, `object={"key":"value"}`) + require.Contains(t, stdout, "bool=true") + require.Contains(t, stdout, "null=") +} diff --git a/internal/hooks/executor.go b/internal/hooks/executor.go index c037370138e4e3b439e2f607ed0777748247dc9c..140f683df30500d13ff016c7468da623bf26f3ba 100644 --- a/internal/hooks/executor.go +++ b/internal/hooks/executor.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/charmbracelet/crush/internal/shell" + "mvdan.cc/sh/v3/interp" ) //go:embed helpers.sh @@ -36,7 +37,7 @@ func (e *Executor) Execute(ctx context.Context, hookPath string, context HookCon return nil, fmt.Errorf("failed to marshal context: %w", err) } - // Wrap user hook in a function and prepend helper functions + // Wrap user hook in a function and prepend helper functions // Read stdin before calling the function, then export it fullScript := fmt.Sprintf(`%s @@ -69,8 +70,9 @@ _crush_hook_main } hookShell := shell.NewShell(&shell.Options{ - WorkingDir: context.WorkingDir, - Env: env, + WorkingDir: context.WorkingDir, + Env: env, + ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins}, }) // Pass JSON context via stdin instead of heredoc diff --git a/internal/hooks/helpers.sh b/internal/hooks/helpers.sh index b19749628e42e1310a8dfe5cfcd8843c82c257e0..21ef57e4aad4b3a6944d3b3a885ac3ccbb43a32c 100644 --- a/internal/hooks/helpers.sh +++ b/internal/hooks/helpers.sh @@ -92,30 +92,3 @@ crush_stop() { exit 1 } -# Input parsing helpers -# These read from the JSON context saved in _CRUSH_STDIN - -# Get a field from the hook context. -# Usage: VALUE=$(crush_get_input "field_name") -crush_get_input() { - echo "$_CRUSH_STDIN" | jq -r ".$1 // empty" -} - -# Get a tool input parameter. -# Usage: COMMAND=$(crush_get_tool_input "command") -crush_get_tool_input() { - echo "$_CRUSH_STDIN" | jq -r ".tool_input.$1 // empty" -} - -# Get the user prompt. -# Usage: PROMPT=$(crush_get_prompt) -crush_get_prompt() { - echo "$_CRUSH_STDIN" | jq -r '.prompt // empty' -} - -# Logging helper. -# Writes to stderr which is captured by Crush. -# Usage: crush_log "debug message" -crush_log() { - echo "[CRUSH HOOK] $*" >&2 -} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 39ee77226177f7b0cfed56757654b32e484f4bfa..9bb78620b8da49c0076be10e669d02f6f22f94dd 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -51,19 +51,21 @@ type BlockFunc func(args []string) bool // Shell provides cross-platform shell execution with optional state persistence type Shell struct { - env []string - cwd string - mu sync.Mutex - logger Logger - blockFuncs []BlockFunc + env []string + cwd string + mu sync.Mutex + logger Logger + blockFuncs []BlockFunc + customExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc } // Options for creating a new shell type Options struct { - WorkingDir string - Env []string - Logger Logger - BlockFuncs []BlockFunc + WorkingDir string + Env []string + Logger Logger + BlockFuncs []BlockFunc + ExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc } // NewShell creates a new shell instance with the given options @@ -88,10 +90,11 @@ func NewShell(opts *Options) *Shell { } return &Shell{ - cwd: cwd, - env: env, - logger: logger, - blockFuncs: opts.BlockFuncs, + cwd: cwd, + env: env, + logger: logger, + blockFuncs: opts.BlockFuncs, + customExecHandlers: opts.ExecHandlers, } } @@ -246,13 +249,15 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand // newInterp creates a new interpreter with the current shell state func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) { - return interp.New( + opts := []interp.RunnerOption{ interp.StdIO(stdin, stdout, stderr), interp.Interactive(false), interp.Env(expand.ListEnviron(s.env...)), interp.Dir(s.cwd), interp.ExecHandlers(s.execHandlers()...), - ) + } + + return interp.New(opts...) } // updateShellFromRunner updates the shell from the interpreter after execution @@ -298,6 +303,8 @@ func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHa handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{ s.blockHandler(), } + // Add custom exec handlers first (they get priority) + handlers = append(handlers, s.customExecHandlers...) if useGoCoreUtils { handlers = append(handlers, coreutils.ExecHandler) }