1package hooks
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"

	"mvdan.cc/sh/v3/interp"
)
11
12// crushGetInput reads a field from the hook context JSON.
13// Usage: VALUE=$(crush_get_input "field_name")
14func crushGetInput(ctx context.Context, args []string) error {
15 hc := interp.HandlerCtx(ctx)
16
17 if len(args) != 2 {
18 fmt.Fprintf(hc.Stderr, "Wrong number of arguments to `crush_get_input`. Expected 2, got %d.\n", len(args))
19 return interp.ExitStatus(1)
20 }
21
22 fieldName := args[1]
23 stdin := hc.Env.Get("_CRUSH_STDIN").Str
24
25 var data map[string]any
26 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
27 fmt.Fprintf(hc.Stderr, "crush_get_input: failed to parse JSON: %v\n", err)
28 return interp.ExitStatus(1)
29 }
30
31 if value, ok := data[fieldName]; ok && value != nil {
32 fmt.Fprint(hc.Stdout, formatJSONValue(value))
33 }
34
35 return nil
36}
37
38// crushGetToolInput reads a tool input parameter from the hook context JSON.
39// Usage: COMMAND=$(crush_get_tool_input "command")
40func crushGetToolInput(ctx context.Context, args []string) error {
41 hc := interp.HandlerCtx(ctx)
42 if len(args) != 2 {
43 fmt.Fprintf(hc.Stderr, "Wrong number of arguments to `crush_get_tool_input`. Expected 2, got %d.\n", len(args))
44 return interp.ExitStatus(1)
45 }
46
47 paramName := args[1]
48 stdin := hc.Env.Get("_CRUSH_STDIN").Str
49
50 var data map[string]any
51 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
52 fmt.Fprintf(hc.Stderr, "crush_get_tool_input: failed to parse JSON: %v\n", err)
53 return interp.ExitStatus(1)
54 }
55
56 toolInput, ok := data["tool_input"].(map[string]any)
57 if !ok {
58 return nil
59 }
60
61 if value, ok := toolInput[paramName]; ok && value != nil {
62 fmt.Fprint(hc.Stdout, formatJSONValue(value))
63 }
64
65 return nil
66}
67
68// crushGetPrompt reads the user prompt from the hook context JSON.
69// Usage: PROMPT=$(crush_get_prompt)
70func crushGetPrompt(ctx context.Context, args []string) error {
71 hc := interp.HandlerCtx(ctx)
72
73 if len(args) != 1 {
74 fmt.Fprintf(hc.Stderr, "Wrong number of arguments to `crush_get_prompt`. Expected 1, got %d.\n", len(args))
75 return interp.ExitStatus(1)
76 }
77
78 stdin := hc.Env.Get("_CRUSH_STDIN").Str
79
80 var data map[string]any
81 if err := json.Unmarshal([]byte(stdin), &data); err != nil {
82 fmt.Fprintf(hc.Stderr, "crush_get_prompt: failed to parse JSON: %v\n", err)
83 return interp.ExitStatus(1)
84 }
85
86 if prompt, ok := data["prompt"]; ok && prompt != nil {
87 fmt.Fprint(hc.Stdout, formatJSONValue(prompt))
88 }
89
90 return nil
91}
92
93// crushLog writes a log message using slog.Debug.
94// Usage: crush_log "debug message"
95func crushLog(ctx context.Context, args []string) error {
96 switch len(args) {
97 case 0, 1:
98 return nil
99 default:
100 slog.Debug(joinArgs(args[1:]))
101 return nil
102 }
103}
104
// formatJSONValue renders a value decoded by encoding/json as text for shell
// output.
//
//   - Strings are returned verbatim, without quotes.
//   - Numbers (float64) that are integral and within the int64 range are
//     printed as integers. Any other number uses Go's %v float formatting.
//   - Booleans print as "true" or "false".
//   - nil prints as the empty string.
//   - Anything else (arrays, objects) is re-encoded as compact JSON. If
//     encoding fails, %v is used instead.
func formatJSONValue(value any) string {
	switch v := value.(type) {
	case string:
		return v
	case float64:
		// encoding/json decodes every JSON number into float64, so integral
		// values are printed without a fractional part or exponent.
		//
		// The range check must come first. The Go spec leaves the result of
		// converting an out-of-range float to int64 implementation-dependent.
		// Where that conversion saturates, 2^63 would round-trip and print as
		// 9223372036854775807.
		const limit = 1 << 63
		if v >= -limit && v < limit && v == float64(int64(v)) {
			return fmt.Sprintf("%d", int64(v))
		}
		return fmt.Sprintf("%v", v)
	case bool:
		return fmt.Sprintf("%t", v)
	case nil:
		return ""
	default:
		b, err := json.Marshal(v)
		if err != nil {
			return fmt.Sprintf("%v", v)
		}
		return string(b)
	}
}
129
// joinArgs joins args with single spaces. It returns "" when args is empty.
//
// This delegates to strings.Join, which sizes its buffer once. Repeated +=
// concatenation would copy the growing string on every argument.
func joinArgs(args []string) string {
	return strings.Join(args, " ")
}
141
142// RegisterBuiltins returns an ExecHandlerFunc that registers all Crush hook builtins.
143func RegisterBuiltins(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
144 builtins := map[string]func(context.Context, []string) error{
145 "crush_get_input": crushGetInput,
146 "crush_get_tool_input": crushGetToolInput,
147 "crush_get_prompt": crushGetPrompt,
148 "crush_log": crushLog,
149 }
150
151 return func(ctx context.Context, args []string) error {
152 if len(args) == 0 {
153 return next(ctx, args)
154 }
155
156 if fn, ok := builtins[args[0]]; ok {
157 return fn(ctx, args)
158 }
159
160 return next(ctx, args)
161 }
162}