chore: migrate to using some buildin tools

Kujtim Hoxha created

Change summary

internal/hooks/builtins.go      | 157 +++++++++++++++++++++++++++++
internal/hooks/builtins_test.go | 185 +++++++++++++++++++++++++++++++++++
internal/hooks/executor.go      |   8 
internal/hooks/helpers.sh       |  27 -----
internal/shell/shell.go         |  37 ++++--
5 files changed, 369 insertions(+), 45 deletions(-)

Detailed changes

internal/hooks/builtins.go 🔗

@@ -0,0 +1,157 @@
+package hooks
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log/slog"
+
+	"mvdan.cc/sh/v3/interp"
+)
+
+// crushGetInput reads a field from the hook context JSON.
+// Usage: VALUE=$(crush_get_input "field_name")
+func crushGetInput(ctx context.Context, args []string) error {
+	hc := interp.HandlerCtx(ctx)
+
+	if len(args) != 2 {
+		fmt.Fprintln(hc.Stderr, "Usage: crush_get_input <field_name>")
+		return interp.ExitStatus(1)
+	}
+
+	fieldName := args[1]
+	stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+	var data map[string]any
+	if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+		fmt.Fprintf(hc.Stderr, "crush_get_input: failed to parse JSON: %v\n", err)
+		return interp.ExitStatus(1)
+	}
+
+	if value, ok := data[fieldName]; ok && value != nil {
+		fmt.Fprint(hc.Stdout, formatJSONValue(value))
+	}
+
+	return nil
+}
+
+// crushGetToolInput reads a tool input parameter from the hook context JSON.
+// Usage: COMMAND=$(crush_get_tool_input "command")
+func crushGetToolInput(ctx context.Context, args []string) error {
+	hc := interp.HandlerCtx(ctx)
+
+	if len(args) != 2 {
+		fmt.Fprintln(hc.Stderr, "Usage: crush_get_tool_input <param_name>")
+		return interp.ExitStatus(1)
+	}
+
+	paramName := args[1]
+	stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+	var data map[string]any
+	if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+		fmt.Fprintf(hc.Stderr, "crush_get_tool_input: failed to parse JSON: %v\n", err)
+		return interp.ExitStatus(1)
+	}
+
+	toolInput, ok := data["tool_input"].(map[string]any)
+	if !ok {
+		return nil
+	}
+
+	if value, ok := toolInput[paramName]; ok && value != nil {
+		fmt.Fprint(hc.Stdout, formatJSONValue(value))
+	}
+
+	return nil
+}
+
+// crushGetPrompt reads the user prompt from the hook context JSON.
+// Usage: PROMPT=$(crush_get_prompt)
+func crushGetPrompt(ctx context.Context, args []string) error {
+	hc := interp.HandlerCtx(ctx)
+
+	stdin := hc.Env.Get("_CRUSH_STDIN").Str
+
+	var data map[string]any
+	if err := json.Unmarshal([]byte(stdin), &data); err != nil {
+		fmt.Fprintf(hc.Stderr, "crush_get_prompt: failed to parse JSON: %v\n", err)
+		return interp.ExitStatus(1)
+	}
+
+	if prompt, ok := data["prompt"]; ok && prompt != nil {
+		fmt.Fprint(hc.Stdout, formatJSONValue(prompt))
+	}
+
+	return nil
+}
+
+// crushLog writes a log message using slog.Debug.
+// Usage: crush_log "debug message"
+func crushLog(ctx context.Context, args []string) error {
+	if len(args) < 2 {
+		return nil
+	}
+
+	slog.Debug(joinArgs(args[1:]))
+	return nil
+}
+
+// formatJSONValue converts a JSON value to a string suitable for shell output.
+func formatJSONValue(value any) string {
+	switch v := value.(type) {
+	case string:
+		return v
+	case float64:
+		// JSON numbers are float64 by default
+		if v == float64(int64(v)) {
+			return fmt.Sprintf("%d", int64(v))
+		}
+		return fmt.Sprintf("%v", v)
+	case bool:
+		return fmt.Sprintf("%t", v)
+	case nil:
+		return ""
+	default:
+		// For complex types (arrays, objects), return JSON representation
+		b, err := json.Marshal(v)
+		if err != nil {
+			return fmt.Sprintf("%v", v)
+		}
+		return string(b)
+	}
+}
+
+// joinArgs joins arguments with spaces.
+func joinArgs(args []string) string {
+	if len(args) == 0 {
+		return ""
+	}
+	result := args[0]
+	for _, arg := range args[1:] {
+		result += " " + arg
+	}
+	return result
+}
+
+// RegisterBuiltins returns an ExecHandlerFunc that registers all Crush hook builtins.
+func RegisterBuiltins(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
+	builtins := map[string]func(context.Context, []string) error{
+		"crush_get_input":      crushGetInput,
+		"crush_get_tool_input": crushGetToolInput,
+		"crush_get_prompt":     crushGetPrompt,
+		"crush_log":            crushLog,
+	}
+
+	return func(ctx context.Context, args []string) error {
+		if len(args) == 0 {
+			return next(ctx, args)
+		}
+
+		if fn, ok := builtins[args[0]]; ok {
+			return fn(ctx, args)
+		}
+
+		return next(ctx, args)
+	}
+}

internal/hooks/builtins_test.go 🔗

@@ -0,0 +1,185 @@
+package hooks
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/shell"
+	"github.com/stretchr/testify/require"
+	"mvdan.cc/sh/v3/interp"
+)
+
+func TestBuiltinsIntegration(t *testing.T) {
+	t.Parallel()
+
+	jsonInput := `{
+		"prompt": "test prompt",
+		"tool_input": {
+			"command": "ls -la",
+			"offset": 100
+		},
+		"custom_field": "custom_value"
+	}`
+
+	script := `
+PROMPT=$(crush_get_prompt)
+COMMAND=$(crush_get_tool_input "command")
+OFFSET=$(crush_get_tool_input "offset")
+CUSTOM=$(crush_get_input "custom_field")
+
+echo "prompt=$PROMPT"
+echo "command=$COMMAND"
+echo "offset=$OFFSET"
+echo "custom=$CUSTOM"
+
+crush_log "Processing complete"
+`
+
+	hookShell := shell.NewShell(&shell.Options{
+		WorkingDir:   t.TempDir(),
+		ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+	})
+
+	// Need to set _CRUSH_STDIN before running the script
+	stdin := strings.NewReader(jsonInput)
+	setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+	stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+	require.NoError(t, err)
+	require.Contains(t, stdout, "prompt=test prompt")
+	require.Contains(t, stdout, "command=ls -la")
+	require.Contains(t, stdout, "offset=100")
+	require.Contains(t, stdout, "custom=custom_value")
+}
+
+func TestBuiltinErrors(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		script  string
+		stdin   string
+		wantErr bool
+	}{
+		{
+			name:    "invalid json",
+			script:  `crush_get_input "field"`,
+			stdin:   `{invalid}`,
+			wantErr: true,
+		},
+		{
+			name:    "wrong number of args",
+			script:  `crush_get_input`,
+			stdin:   `{"field":"value"}`,
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			hookShell := shell.NewShell(&shell.Options{
+				WorkingDir:   t.TempDir(),
+				ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+			})
+
+			setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + tt.script
+
+			stdin := strings.NewReader(tt.stdin)
+			_, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+			if tt.wantErr {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}
+
+func TestBuiltinWithMissingFields(t *testing.T) {
+	t.Parallel()
+
+	jsonInput := `{"prompt": "test"}`
+
+	script := `
+MISSING=$(crush_get_input "missing_field")
+TOOL_MISSING=$(crush_get_tool_input "missing_param")
+
+if [ -z "$MISSING" ]; then
+  echo "missing is empty"
+fi
+
+if [ -z "$TOOL_MISSING" ]; then
+  echo "tool_missing is empty"
+fi
+`
+
+	hookShell := shell.NewShell(&shell.Options{
+		WorkingDir:   t.TempDir(),
+		ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+	})
+
+	stdin := strings.NewReader(jsonInput)
+	setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+	stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+	require.NoError(t, err)
+	require.Contains(t, stdout, "missing is empty")
+	require.Contains(t, stdout, "tool_missing is empty")
+}
+
+func TestBuiltinWithComplexTypes(t *testing.T) {
+	t.Parallel()
+
+	jsonInput := `{
+		"array_field": [1, 2, 3],
+		"object_field": {"key": "value"},
+		"bool_field": true,
+		"null_field": null
+	}`
+
+	script := `
+ARRAY=$(crush_get_input "array_field")
+OBJECT=$(crush_get_input "object_field")
+BOOL=$(crush_get_input "bool_field")
+NULL=$(crush_get_input "null_field")
+
+echo "array=$ARRAY"
+echo "object=$OBJECT"
+echo "bool=$BOOL"
+echo "null=$NULL"
+`
+
+	hookShell := shell.NewShell(&shell.Options{
+		WorkingDir:   t.TempDir(),
+		ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
+	})
+
+	stdin := strings.NewReader(jsonInput)
+	setupScript := `
+_CRUSH_STDIN=$(cat)
+export _CRUSH_STDIN
+` + script
+
+	stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
+
+	require.NoError(t, err)
+	require.Contains(t, stdout, "array=[1,2,3]")
+	require.Contains(t, stdout, `object={"key":"value"}`)
+	require.Contains(t, stdout, "bool=true")
+	require.Contains(t, stdout, "null=")
+}

internal/hooks/executor.go 🔗

@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/charmbracelet/crush/internal/shell"
+	"mvdan.cc/sh/v3/interp"
 )
 
 //go:embed helpers.sh
@@ -36,7 +37,7 @@ func (e *Executor) Execute(ctx context.Context, hookPath string, context HookCon
 		return nil, fmt.Errorf("failed to marshal context: %w", err)
 	}
 
-	// Wrap user hook in a function and prepend helper functions  
+	// Wrap user hook in a function and prepend helper functions
 	// Read stdin before calling the function, then export it
 	fullScript := fmt.Sprintf(`%s
 
@@ -69,8 +70,9 @@ _crush_hook_main
 	}
 
 	hookShell := shell.NewShell(&shell.Options{
-		WorkingDir: context.WorkingDir,
-		Env:        env,
+		WorkingDir:   context.WorkingDir,
+		Env:          env,
+		ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
 	})
 
 	// Pass JSON context via stdin instead of heredoc

internal/hooks/helpers.sh 🔗

@@ -92,30 +92,3 @@ crush_stop() {
   exit 1
 }
 
-# Input parsing helpers
-# These read from the JSON context saved in _CRUSH_STDIN
-
-# Get a field from the hook context.
-# Usage: VALUE=$(crush_get_input "field_name")
-crush_get_input() {
-  echo "$_CRUSH_STDIN" | jq -r ".$1 // empty"
-}
-
-# Get a tool input parameter.
-# Usage: COMMAND=$(crush_get_tool_input "command")
-crush_get_tool_input() {
-  echo "$_CRUSH_STDIN" | jq -r ".tool_input.$1 // empty"
-}
-
-# Get the user prompt.
-# Usage: PROMPT=$(crush_get_prompt)
-crush_get_prompt() {
-  echo "$_CRUSH_STDIN" | jq -r '.prompt // empty'
-}
-
-# Logging helper.
-# Writes to stderr which is captured by Crush.
-# Usage: crush_log "debug message"
-crush_log() {
-  echo "[CRUSH HOOK] $*" >&2
-}

internal/shell/shell.go 🔗

@@ -51,19 +51,21 @@ type BlockFunc func(args []string) bool
 
 // Shell provides cross-platform shell execution with optional state persistence
 type Shell struct {
-	env        []string
-	cwd        string
-	mu         sync.Mutex
-	logger     Logger
-	blockFuncs []BlockFunc
+	env                []string
+	cwd                string
+	mu                 sync.Mutex
+	logger             Logger
+	blockFuncs         []BlockFunc
+	customExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
 }
 
 // Options for creating a new shell
 type Options struct {
-	WorkingDir string
-	Env        []string
-	Logger     Logger
-	BlockFuncs []BlockFunc
+	WorkingDir   string
+	Env          []string
+	Logger       Logger
+	BlockFuncs   []BlockFunc
+	ExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
 }
 
 // NewShell creates a new shell instance with the given options
@@ -88,10 +90,11 @@ func NewShell(opts *Options) *Shell {
 	}
 
 	return &Shell{
-		cwd:        cwd,
-		env:        env,
-		logger:     logger,
-		blockFuncs: opts.BlockFuncs,
+		cwd:                cwd,
+		env:                env,
+		logger:             logger,
+		blockFuncs:         opts.BlockFuncs,
+		customExecHandlers: opts.ExecHandlers,
 	}
 }
 
@@ -246,13 +249,15 @@ func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHand
 
 // newInterp creates a new interpreter with the current shell state
 func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
-	return interp.New(
+	opts := []interp.RunnerOption{
 		interp.StdIO(stdin, stdout, stderr),
 		interp.Interactive(false),
 		interp.Env(expand.ListEnviron(s.env...)),
 		interp.Dir(s.cwd),
 		interp.ExecHandlers(s.execHandlers()...),
-	)
+	}
+
+	return interp.New(opts...)
 }
 
 // updateShellFromRunner updates the shell from the interpreter after execution
@@ -298,6 +303,8 @@ func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHa
 	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
 		s.blockHandler(),
 	}
+	// Add custom exec handlers first (they get priority)
+	handlers = append(handlers, s.customExecHandlers...)
 	if useGoCoreUtils {
 		handlers = append(handlers, coreutils.ExecHandler)
 	}