1package hooks
2
3import (
4 "context"
5 "strings"
6 "testing"
7
8 "github.com/charmbracelet/crush/internal/shell"
9 "github.com/stretchr/testify/require"
10 "mvdan.cc/sh/v3/interp"
11)
12
13func TestBuiltinsIntegration(t *testing.T) {
14 t.Parallel()
15
16 jsonInput := `{
17 "prompt": "test prompt",
18 "tool_input": {
19 "command": "ls -la",
20 "offset": 100
21 },
22 "custom_field": "custom_value"
23 }`
24
25 script := `
26PROMPT=$(crush_get_prompt)
27COMMAND=$(crush_get_tool_input "command")
28OFFSET=$(crush_get_tool_input "offset")
29CUSTOM=$(crush_get_input "custom_field")
30
31echo "prompt=$PROMPT"
32echo "command=$COMMAND"
33echo "offset=$OFFSET"
34echo "custom=$CUSTOM"
35
36crush_log "Processing complete"
37`
38
39 hookShell := shell.NewShell(&shell.Options{
40 WorkingDir: t.TempDir(),
41 ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
42 })
43
44 // Need to set _CRUSH_STDIN before running the script
45 stdin := strings.NewReader(jsonInput)
46 setupScript := `
47_CRUSH_STDIN=$(cat)
48export _CRUSH_STDIN
49` + script
50
51 stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
52
53 require.NoError(t, err)
54 require.Contains(t, stdout, "prompt=test prompt")
55 require.Contains(t, stdout, "command=ls -la")
56 require.Contains(t, stdout, "offset=100")
57 require.Contains(t, stdout, "custom=custom_value")
58}
59
60func TestBuiltinErrors(t *testing.T) {
61 t.Parallel()
62
63 tests := []struct {
64 name string
65 script string
66 stdin string
67 wantErr bool
68 }{
69 {
70 name: "invalid json",
71 script: `crush_get_input "field"`,
72 stdin: `{invalid}`,
73 wantErr: true,
74 },
75 {
76 name: "wrong number of args",
77 script: `crush_get_input`,
78 stdin: `{"field":"value"}`,
79 wantErr: true,
80 },
81 }
82
83 for _, tt := range tests {
84 t.Run(tt.name, func(t *testing.T) {
85 t.Parallel()
86
87 hookShell := shell.NewShell(&shell.Options{
88 WorkingDir: t.TempDir(),
89 ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
90 })
91
92 setupScript := `
93_CRUSH_STDIN=$(cat)
94export _CRUSH_STDIN
95` + tt.script
96
97 stdin := strings.NewReader(tt.stdin)
98 _, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
99
100 if tt.wantErr {
101 require.Error(t, err)
102 } else {
103 require.NoError(t, err)
104 }
105 })
106 }
107}
108
109func TestBuiltinWithMissingFields(t *testing.T) {
110 t.Parallel()
111
112 jsonInput := `{"prompt": "test"}`
113
114 script := `
115MISSING=$(crush_get_input "missing_field")
116TOOL_MISSING=$(crush_get_tool_input "missing_param")
117
118if [ -z "$MISSING" ]; then
119 echo "missing is empty"
120fi
121
122if [ -z "$TOOL_MISSING" ]; then
123 echo "tool_missing is empty"
124fi
125`
126
127 hookShell := shell.NewShell(&shell.Options{
128 WorkingDir: t.TempDir(),
129 ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
130 })
131
132 stdin := strings.NewReader(jsonInput)
133 setupScript := `
134_CRUSH_STDIN=$(cat)
135export _CRUSH_STDIN
136` + script
137
138 stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
139
140 require.NoError(t, err)
141 require.Contains(t, stdout, "missing is empty")
142 require.Contains(t, stdout, "tool_missing is empty")
143}
144
145func TestBuiltinWithComplexTypes(t *testing.T) {
146 t.Parallel()
147
148 jsonInput := `{
149 "array_field": [1, 2, 3],
150 "object_field": {"key": "value"},
151 "bool_field": true,
152 "null_field": null
153 }`
154
155 script := `
156ARRAY=$(crush_get_input "array_field")
157OBJECT=$(crush_get_input "object_field")
158BOOL=$(crush_get_input "bool_field")
159NULL=$(crush_get_input "null_field")
160
161echo "array=$ARRAY"
162echo "object=$OBJECT"
163echo "bool=$BOOL"
164echo "null=$NULL"
165`
166
167 hookShell := shell.NewShell(&shell.Options{
168 WorkingDir: t.TempDir(),
169 ExecHandlers: []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc{RegisterBuiltins},
170 })
171
172 stdin := strings.NewReader(jsonInput)
173 setupScript := `
174_CRUSH_STDIN=$(cat)
175export _CRUSH_STDIN
176` + script
177
178 stdout, _, err := hookShell.ExecWithStdin(context.Background(), setupScript, stdin)
179
180 require.NoError(t, err)
181 require.Contains(t, stdout, "array=[1,2,3]")
182 require.Contains(t, stdout, `object={"key":"value"}`)
183 require.Contains(t, stdout, "bool=true")
184 require.Contains(t, stdout, "null=")
185}