Detailed changes
@@ -0,0 +1,157 @@
+package hooks
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+
+ "mvdan.cc/sh/v3/interp"
+)
+
+// crushGetInput reads a field from the hook context JSON.
+// Usage: VALUE=$(crush_get_input "field_name")
+func crushGetInput(ctx context.Context, args []string) error {
+ hc := interp.HandlerCtx(ctx)
+
+ if len(args) != 2 {
+ fmt.Fprintln(hc.Stderr, "Usage: crush_get_input <field_name>")
+ return interp.ExitStatus(1)
+ }
+
+ fieldName := args[1]
+ stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+ var data map[string]any
+ if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+ fmt.Fprintf(hc.Stderr, "crush_get_input: failed to parse JSON: %v\n", err)
+ return interp.ExitStatus(1)
+ }
+
+ if value, ok := data[fieldName]; ok && value != nil {
+ fmt.Fprint(hc.Stdout, formatJSONValue(value))
+ }
+
+ return nil
+}
+
+// crushGetToolInput reads a tool input parameter from the hook context JSON.
+// Usage: COMMAND=$(crush_get_tool_input "command")
+func crushGetToolInput(ctx context.Context, args []string) error {
+ hc := interp.HandlerCtx(ctx)
+
+ if len(args) != 2 {
+ fmt.Fprintln(hc.Stderr, "Usage: crush_get_tool_input <param_name>")
+ return interp.ExitStatus(1)
+ }
+
+ paramName := args[1]
+ stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+ var data map[string]any
+ if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+ fmt.Fprintf(hc.Stderr, "crush_get_tool_input: failed to parse JSON: %v\n", err)
+ return interp.ExitStatus(1)
+ }
+
+ toolInput, ok := data["tool_input"].(map[string]any)
+ if !ok {
+ return nil
+ }
+
+ if value, ok := toolInput[paramName]; ok && value != nil {
+ fmt.Fprint(hc.Stdout, formatJSONValue(value))
+ }
+
+ return nil
+}
+
+// crushGetPrompt reads the user prompt from the hook context JSON.
+// Usage: PROMPT=$(crush_get_prompt)
+func crushGetPrompt(ctx context.Context, args []string) error {
+ hc := interp.HandlerCtx(ctx)
+
+ stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+ var data map[string]any
+ if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+ fmt.Fprintf(hc.Stderr, "crush_get_prompt: failed to parse JSON: %v\n", err)
+ return interp.ExitStatus(1)
+ }
+
+ if prompt, ok := data["prompt"]; ok && prompt != nil {
+ fmt.Fprint(hc.Stdout, formatJSONValue(prompt))
+ }
+
+ return nil
+}
+
+// crushLog writes a log message using slog.Debug.
+// Usage: crush_log "debug message"
+func crushLog(ctx context.Context, args []string) error {
+ if len(args) < 2 {
+ return nil
+ }
+
+ slog.Debug(joinArgs(args[1:]))
+ return nil
+}
+
+// formatJSONValue converts a JSON value to a string suitable for shell output.
+func formatJSONValue(value any) string {
+ switch v := value.(type) {
+ case string:
+ return v
+ case float64:
+ // JSON numbers are float64 by default
+ if v == float64(int64(v)) {
+ return fmt.Sprintf("%d", int64(v))
+ }
+ return fmt.Sprintf("%v", v)
+ case bool:
+ return fmt.Sprintf("%t", v)
+ case nil:
+ return ""
+ default:
+ // For complex types (arrays, objects), return JSON representation
+ b, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("%v", v)
+ }
+ return string(b)
+ }
+}
+
+// joinArgs joins arguments with spaces.
+func joinArgs(args []string) string {
+ if len(args) == 0 {
+ return ""
+ }
+ result := args[0]
+ for _, arg := range args[1:] {
+ result += " " + arg
+ }
+ return result
+}
+
+// RegisterBuiltins returns an ExecHandlerFunc that registers all Crush hook builtins.
+func RegisterBuiltins(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+ builtins := map[string]func(context.Context, []string) error{
+ "crush_get_input": crushGetInput,
+ "crush_get_tool_input": crushGetToolInput,
+ "crush_get_prompt": crushGetPrompt,
+ "crush_log": crushLog,
+ }
+
+ return func(ctx context.Context, args []string) error {
+ if len(args) == 0 {
+ return next(ctx, args)
+ }
+
+ if fn, ok := builtins[args[0]]; ok {
+ return fn(ctx, args)
+ }
+
+ return next(ctx, args)
+ }
+}
@@ -0,0 +1,185 @@
+package hooks
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/shell"
+ "github.com/stretchr/testify/require"
+ "mvdan.cc/sh/v3/interp"
+)
+
+func TestBuiltinsIntegration(t *testing.T) {
+ t.Parallel()
+
+ jsonInput := `{
+ "prompt": "test prompt",
+ "tool_input": {
+ "command": "ls -la",
+ "offset": 100
+ },
+ "custom_field": "custom_value"
+ }`
+
+ script := `
+PROMPT=$(crush_get_prompt)
+COMMAND=$(crush_get_tool_input "command")
+OFFSET=$(crush_get_tool_input "offset")
+CUSTOM=$(crush_get_input "custom_field")
+
+echo "prompt=$PROMPT"
+echo "command=$COMMAND"
+echo "offset=$OFFSET"
+echo "custom=$CUSTOM"
+
+crush_log "Processing complete"
+`
+
+ hookShell := shell.NewShell(&shell.Options{
+ WorkingDir: t.TempDir(),
+ ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+ })
+
+ // Need to set _CRUSH_STDIN before running the script
+ stdin := strings.NewReader(jsonInput)
+ setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+ stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+ require.NoError(t, err)
+ require.Contains(t, stdout, "prompt=test prompt")
+ require.Contains(t, stdout, "command=ls -la")
+ require.Contains(t, stdout, "offset=100")
+ require.Contains(t, stdout, "custom=custom_value")
+}
+
+func TestBuiltinErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ script string
+ stdin string
+ wantErr bool
+ }{
+ {
+ name: "invalid json",
+ script: `crush_get_input "field"`,
+ stdin: `{invalid}`,
+ wantErr: true,
+ },
+ {
+ name: "wrong number of args",
+ script: `crush_get_input`,
+ stdin: `{"field":"value"}`,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ hookShell := shell.NewShell(&shell.Options{
+ WorkingDir: t.TempDir(),
+ ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+ })
+
+ setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + tt.script
+
+ stdin := strings.NewReader(tt.stdin)
+ _, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestBuiltinWithMissingFields(t *testing.T) {
+ t.Parallel()
+
+ jsonInput := `{"prompt": "test"}`
+
+ script := `
+MISSING=$(crush_get_input "missing_field")
+TOOL_MISSING=$(crush_get_tool_input "missing_param")
+
+if [ -z "$MISSING" ]; then
+ echo "missing is empty"
+fi
+
+if [ -z "$TOOL_MISSING" ]; then
+ echo "tool_missing is empty"
+fi
+`
+
+ hookShell := shell.NewShell(&shell.Options{
+ WorkingDir: t.TempDir(),
+ ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+ })
+
+ stdin := strings.NewReader(jsonInput)
+ setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+ stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+ require.NoError(t, err)
+ require.Contains(t, stdout, "missing is empty")
+ require.Contains(t, stdout, "tool_missing is empty")
+}
+
+func TestBuiltinWithComplexTypes(t *testing.T) {
+ t.Parallel()
+
+ jsonInput := `{
+ "array_field": [1, 2, 3],
+ "object_field": {"key": "value"},
+ "bool_field": true,
+ "null_field": null
+ }`
+
+ script := `
+ARRAY=$(crush_get_input "array_field")
+OBJECT=$(crush_get_input "object_field")
+BOOL=$(crush_get_input "bool_field")
+NULL=$(crush_get_input "null_field")
+
+echo "array=$ARRAY"
+echo "object=$OBJECT"
+echo "bool=$BOOL"
+echo "null=$NULL"
+`
+
+ hookShell := shell.NewShell(&shell.Options{
+ WorkingDir: t.TempDir(),
+ ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+ })
+
+ stdin := strings.NewReader(jsonInput)
+ setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+ stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+ require.NoError(t, err)
+ require.Contains(t, stdout, "array=[1,2,3]")
+ require.Contains(t, stdout, `object={"key":"value"}`)
+ require.Contains(t, stdout, "bool=true")
+ require.Contains(t, stdout, "null=")
+}
@@ -9,6 +9,7 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/shell"
+ "mvdan.cc/sh/v3/interp"
)
//go:embed helpers.sh
@@ -36,7 +37,7 @@ func (e *Executor) Execute(ctx context.Context, hookPath string, context HookCon
return nil, fmt.Errorf("failed to marshal context: %w", err)
}
- // Wrap user hook in a function and prepend helper functions
+ // Wrap user hook in a function and prepend helper functions
// Read stdin before calling the function, then export it
fullScript := fmt.Sprintf(`%s
@@ -69,8 +70,9 @@ _crush_hook_main
}
hookShell := shell.NewShell(&shell.Options{
- WorkingDir: context.WorkingDir,
- Env: env,
+ WorkingDir: context.WorkingDir,
+ Env: env,
+ ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
})
// Pass JSON context via stdin instead of heredoc
@@ -92,30 +92,3 @@ crush_stop() {
exit 1
}
-# Input parsing helpers
-# These read from the JSON context saved in _CRUSH_STDIN
-
-# Get a field from the hook context.
-# Usage: VALUE=$(crush_get_input "field_name")
-crush_get_input() {
- echo "$_CRUSH_STDIN" | jq -r ".$1 // empty"
-}
-
-# Get a tool input parameter.
-# Usage: COMMAND=$(crush_get_tool_input "command")
-crush_get_tool_input() {
- echo "$_CRUSH_STDIN" | jq -r ".tool_input.$1 // empty"
-}
-
-# Get the user prompt.
-# Usage: PROMPT=$(crush_get_prompt)
-crush_get_prompt() {
- echo "$_CRUSH_STDIN" | jq -r '.prompt // empty'
-}
-
-# Logging helper.
-# Writes to stderr which is captured by Crush.
-# Usage: crush_log "debug message"
-crush_log() {
- echo "[CRUSH HOOK] $*" >&2
-}
@@ -51,19 +51,21 @@ type BlockFunc func(args []string) bool
// Shell provides cross-platform shell execution with optional state persistence
type Shell struct {
- env []string
- cwd string
- mu sync.Mutex
- logger Logger
- blockFuncs []BlockFunc
+ env []string
+ cwd string
+ mu sync.Mutex
+ logger Logger
+ blockFuncs []BlockFunc
+ customExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
}
// Options for creating a new shell
type Options struct {
- WorkingDir string
- Env []string
- Logger Logger
- BlockFuncs []BlockFunc
+ WorkingDir string
+ Env []string
+ Logger Logger
+ BlockFuncs []BlockFunc
+ ExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
}
// NewShell creates a new shell instance with the given options
@@ -88,10 +90,11 @@ func NewShell(opts *Options) *Shell {
}
return &Shell{
- cwd: cwd,
- env: env,
- logger: logger,
- blockFuncs: opts.BlockFuncs,
+ cwd: cwd,
+ env: env,
+ logger: logger,
+ blockFuncs: opts.BlockFuncs,
+ customExecHandlers: opts.ExecHandlers,
}
}
@@ -246,13 +249,15 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand
// newInterp creates a new interpreter with the current shell state
func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
- return interp.New(
+ opts := []interp.RunnerOption{
interp.StdIO(stdin, stdout, stderr),
interp.Interactive(false),
interp.Env(expand.ListEnviron(s.env...)),
interp.Dir(s.cwd),
interp.ExecHandlers(s.execHandlers()...),
- )
+ }
+
+ return interp.New(opts...)
}
// updateShellFromRunner updates the shell from the interpreter after execution
@@ -298,6 +303,8 @@ func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHa
handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
s.blockHandler(),
}
+ // Add custom exec handlers first (they get priority)
+ handlers = append(handlers, s.customExecHandlers...)
if useGoCoreUtils {
handlers = append(handlers, coreutils.ExecHandler)
}