1package hooks
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"

	"mvdan.cc/sh/v3/interp"
)
11
12// crushGetInput reads a field from the hook context JSON.
13// Usage: VALUE=$(crush_get_input "field_name")
14func crushGetInput(ctx context.Context, args []string) error {
15 hc := interp.HandlerCtx(ctx)
16
17 if len(args) != 2 {
18 fmt.Fprintln(hc.Stderr, "Usage: crush_get_input <field_name>")
19 return interp.ExitStatus(1)
20 }
21
22 fieldName := args[1]
23 stdin := hc.Env.Get("_CRUSH_STDIN").Str
24
25 var data map[string]any
26 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
27 fmt.Fprintf(hc.Stderr, "crush_get_input: failed to parse JSON: %v\n", err)
28 return interp.ExitStatus(1)
29 }
30
31 if value, ok := data[fieldName]; ok && value != nil {
32 fmt.Fprint(hc.Stdout, formatJSONValue(value))
33 }
34
35 return nil
36}
37
38// crushGetToolInput reads a tool input parameter from the hook context JSON.
39// Usage: COMMAND=$(crush_get_tool_input "command")
40func crushGetToolInput(ctx context.Context, args []string) error {
41 hc := interp.HandlerCtx(ctx)
42
43 if len(args) != 2 {
44 fmt.Fprintln(hc.Stderr, "Usage: crush_get_tool_input <param_name>")
45 return interp.ExitStatus(1)
46 }
47
48 paramName := args[1]
49 stdin := hc.Env.Get("_CRUSH_STDIN").Str
50
51 var data map[string]any
52 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
53 fmt.Fprintf(hc.Stderr, "crush_get_tool_input: failed to parse JSON: %v\n", err)
54 return interp.ExitStatus(1)
55 }
56
57 toolInput, ok := data["tool_input"].(map[string]any)
58 if !ok {
59 return nil
60 }
61
62 if value, ok := toolInput[paramName]; ok && value != nil {
63 fmt.Fprint(hc.Stdout, formatJSONValue(value))
64 }
65
66 return nil
67}
68
69// crushGetPrompt reads the user prompt from the hook context JSON.
70// Usage: PROMPT=$(crush_get_prompt)
71func crushGetPrompt(ctx context.Context, args []string) error {
72 hc := interp.HandlerCtx(ctx)
73
74 stdin := hc.Env.Get("_CRUSH_STDIN").Str
75
76 var data map[string]any
77 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
78 fmt.Fprintf(hc.Stderr, "crush_get_prompt: failed to parse JSON: %v\n", err)
79 return interp.ExitStatus(1)
80 }
81
82 if prompt, ok := data["prompt"]; ok && prompt != nil {
83 fmt.Fprint(hc.Stdout, formatJSONValue(prompt))
84 }
85
86 return nil
87}
88
89// crushLog writes a log message using slog.Debug.
90// Usage: crush_log "debug message"
91func crushLog(ctx context.Context, args []string) error {
92 if len(args) < 2 {
93 return nil
94 }
95
96 slog.Debug(joinArgs(args[1:]))
97 return nil
98}
99
// formatJSONValue converts a value produced by encoding/json (decoded into
// `any`) into a string suitable for shell output:
//
//   - strings are returned verbatim, without quotes;
//   - integral numbers within the int64 range print as plain integers ("42");
//   - all other numbers use Go's default %v float formatting;
//   - booleans print as "true" or "false";
//   - null prints as the empty string;
//   - arrays, objects, and any other values are re-encoded as compact JSON,
//     falling back to %v if encoding fails.
func formatJSONValue(value any) string {
	switch v := value.(type) {
	case string:
		return v
	case float64:
		// encoding/json decodes every number into float64. Integral values are
		// printed as integers, but only after a range check: converting an
		// out-of-range float to int64 is implementation-dependent in Go (it
		// wraps on amd64 and saturates on arm64). Without this check, 2^63
		// would round-trip to itself on arm64 and be printed as MaxInt64.
		if v >= -(1<<63) && v < 1<<63 && v == float64(int64(v)) {
			return fmt.Sprintf("%d", int64(v))
		}
		return fmt.Sprintf("%v", v)
	case bool:
		return fmt.Sprintf("%t", v)
	case nil:
		return ""
	default:
		// Complex types (arrays, objects): return their JSON representation.
		b, err := json.Marshal(v)
		if err != nil {
			return fmt.Sprintf("%v", v)
		}
		return string(b)
	}
}
124
// joinArgs joins arguments with single spaces. An empty or nil slice yields
// the empty string.
func joinArgs(args []string) string {
	// strings.Join sizes its buffer once, instead of reallocating on every
	// iteration the way repeated += concatenation does.
	return strings.Join(args, " ")
}
136
137// RegisterBuiltins returns an ExecHandlerFunc that registers all Crush hook builtins.
138func RegisterBuiltins(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
139 builtins := map[string]func(context.Context, []string) error{
140 "crush_get_input": crushGetInput,
141 "crush_get_tool_input": crushGetToolInput,
142 "crush_get_prompt": crushGetPrompt,
143 "crush_log": crushLog,
144 }
145
146 return func(ctx context.Context, args []string) error {
147 if len(args) == 0 {
148 return next(ctx, args)
149 }
150
151 if fn, ok := builtins[args[0]]; ok {
152 return fn(ctx, args)
153 }
154
155 return next(ctx, args)
156 }
157}