1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "os"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
15func TestBashTool_Info(t *testing.T) {
16 tool := NewBashTool(newMockPermissionService(true))
17 info := tool.Info()
18
19 assert.Equal(t, BashToolName, info.Name)
20 assert.NotEmpty(t, info.Description)
21 assert.Contains(t, info.Parameters, "command")
22 assert.Contains(t, info.Parameters, "timeout")
23 assert.Contains(t, info.Required, "command")
24}
25
26func TestBashTool_Run(t *testing.T) {
27 // Save original working directory
28 origWd, err := os.Getwd()
29 require.NoError(t, err)
30 defer func() {
31 os.Chdir(origWd)
32 }()
33
34 t.Run("executes command successfully", func(t *testing.T) {
35 tool := NewBashTool(newMockPermissionService(true))
36 params := BashParams{
37 Command: "echo 'Hello World'",
38 }
39
40 paramsJSON, err := json.Marshal(params)
41 require.NoError(t, err)
42
43 call := ToolCall{
44 Name: BashToolName,
45 Input: string(paramsJSON),
46 }
47
48 response, err := tool.Run(context.Background(), call)
49 require.NoError(t, err)
50 assert.Equal(t, "Hello World\n", response.Content)
51 })
52
53 t.Run("handles invalid parameters", func(t *testing.T) {
54 tool := NewBashTool(newMockPermissionService(true))
55 call := ToolCall{
56 Name: BashToolName,
57 Input: "invalid json",
58 }
59
60 response, err := tool.Run(context.Background(), call)
61 require.NoError(t, err)
62 assert.Contains(t, response.Content, "invalid parameters")
63 })
64
65 t.Run("handles missing command", func(t *testing.T) {
66 tool := NewBashTool(newMockPermissionService(true))
67 params := BashParams{
68 Command: "",
69 }
70
71 paramsJSON, err := json.Marshal(params)
72 require.NoError(t, err)
73
74 call := ToolCall{
75 Name: BashToolName,
76 Input: string(paramsJSON),
77 }
78
79 response, err := tool.Run(context.Background(), call)
80 require.NoError(t, err)
81 assert.Contains(t, response.Content, "missing command")
82 })
83
84 t.Run("handles banned commands", func(t *testing.T) {
85 tool := NewBashTool(newMockPermissionService(true))
86
87 for _, bannedCmd := range bannedCommands {
88 params := BashParams{
89 Command: bannedCmd + " arg1 arg2",
90 }
91
92 paramsJSON, err := json.Marshal(params)
93 require.NoError(t, err)
94
95 call := ToolCall{
96 Name: BashToolName,
97 Input: string(paramsJSON),
98 }
99
100 response, err := tool.Run(context.Background(), call)
101 require.NoError(t, err)
102 assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
103 }
104 })
105
106 t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
107 tool := NewBashTool(newMockPermissionService(false))
108
109 // Test with multi-word safe commands
110 multiWordCommands := []string{
111 "go env",
112 }
113
114 for _, cmd := range multiWordCommands {
115 params := BashParams{
116 Command: cmd,
117 }
118
119 paramsJSON, err := json.Marshal(params)
120 require.NoError(t, err)
121
122 call := ToolCall{
123 Name: BashToolName,
124 Input: string(paramsJSON),
125 }
126
127 response, err := tool.Run(context.Background(), call)
128 require.NoError(t, err)
129 assert.NotContains(t, response.Content, "permission denied",
130 "Command %s should be allowed without permission", cmd)
131 }
132 })
133
134 t.Run("handles permission denied", func(t *testing.T) {
135 tool := NewBashTool(newMockPermissionService(false))
136
137 // Test with a command that requires permission
138 params := BashParams{
139 Command: "mkdir test_dir",
140 }
141
142 paramsJSON, err := json.Marshal(params)
143 require.NoError(t, err)
144
145 call := ToolCall{
146 Name: BashToolName,
147 Input: string(paramsJSON),
148 }
149
150 response, err := tool.Run(context.Background(), call)
151 require.NoError(t, err)
152 assert.Contains(t, response.Content, "permission denied")
153 })
154
155 t.Run("handles command timeout", func(t *testing.T) {
156 tool := NewBashTool(newMockPermissionService(true))
157 params := BashParams{
158 Command: "sleep 2",
159 Timeout: 100, // 100ms timeout
160 }
161
162 paramsJSON, err := json.Marshal(params)
163 require.NoError(t, err)
164
165 call := ToolCall{
166 Name: BashToolName,
167 Input: string(paramsJSON),
168 }
169
170 response, err := tool.Run(context.Background(), call)
171 require.NoError(t, err)
172 assert.Contains(t, response.Content, "aborted")
173 })
174
175 t.Run("handles command with stderr output", func(t *testing.T) {
176 tool := NewBashTool(newMockPermissionService(true))
177 params := BashParams{
178 Command: "echo 'error message' >&2",
179 }
180
181 paramsJSON, err := json.Marshal(params)
182 require.NoError(t, err)
183
184 call := ToolCall{
185 Name: BashToolName,
186 Input: string(paramsJSON),
187 }
188
189 response, err := tool.Run(context.Background(), call)
190 require.NoError(t, err)
191 assert.Contains(t, response.Content, "error message")
192 })
193
194 t.Run("handles command with both stdout and stderr", func(t *testing.T) {
195 tool := NewBashTool(newMockPermissionService(true))
196 params := BashParams{
197 Command: "echo 'stdout message' && echo 'stderr message' >&2",
198 }
199
200 paramsJSON, err := json.Marshal(params)
201 require.NoError(t, err)
202
203 call := ToolCall{
204 Name: BashToolName,
205 Input: string(paramsJSON),
206 }
207
208 response, err := tool.Run(context.Background(), call)
209 require.NoError(t, err)
210 assert.Contains(t, response.Content, "stdout message")
211 assert.Contains(t, response.Content, "stderr message")
212 })
213
214 t.Run("handles context cancellation", func(t *testing.T) {
215 tool := NewBashTool(newMockPermissionService(true))
216 params := BashParams{
217 Command: "sleep 5",
218 }
219
220 paramsJSON, err := json.Marshal(params)
221 require.NoError(t, err)
222
223 call := ToolCall{
224 Name: BashToolName,
225 Input: string(paramsJSON),
226 }
227
228 ctx, cancel := context.WithCancel(context.Background())
229
230 // Cancel the context after a short delay
231 go func() {
232 time.Sleep(100 * time.Millisecond)
233 cancel()
234 }()
235
236 response, err := tool.Run(ctx, call)
237 require.NoError(t, err)
238 assert.Contains(t, response.Content, "aborted")
239 })
240
241 t.Run("respects max timeout", func(t *testing.T) {
242 tool := NewBashTool(newMockPermissionService(true))
243 params := BashParams{
244 Command: "echo 'test'",
245 Timeout: MaxTimeout + 1000, // Exceeds max timeout
246 }
247
248 paramsJSON, err := json.Marshal(params)
249 require.NoError(t, err)
250
251 call := ToolCall{
252 Name: BashToolName,
253 Input: string(paramsJSON),
254 }
255
256 response, err := tool.Run(context.Background(), call)
257 require.NoError(t, err)
258 assert.Equal(t, "test\n", response.Content)
259 })
260
261 t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
262 tool := NewBashTool(newMockPermissionService(true))
263 params := BashParams{
264 Command: "echo 'test'",
265 Timeout: -100, // Negative timeout
266 }
267
268 paramsJSON, err := json.Marshal(params)
269 require.NoError(t, err)
270
271 call := ToolCall{
272 Name: BashToolName,
273 Input: string(paramsJSON),
274 }
275
276 response, err := tool.Run(context.Background(), call)
277 require.NoError(t, err)
278 assert.Equal(t, "test\n", response.Content)
279 })
280}
281
282func TestTruncateOutput(t *testing.T) {
283 t.Run("does not truncate short output", func(t *testing.T) {
284 output := "short output"
285 result := truncateOutput(output)
286 assert.Equal(t, output, result)
287 })
288
289 t.Run("truncates long output", func(t *testing.T) {
290 // Create a string longer than MaxOutputLength
291 longOutput := strings.Repeat("a\n", MaxOutputLength)
292 result := truncateOutput(longOutput)
293
294 // Check that the result is shorter than the original
295 assert.Less(t, len(result), len(longOutput))
296
297 // Check that the truncation message is included
298 assert.Contains(t, result, "lines truncated")
299
300 // Check that we have the beginning and end of the original string
301 assert.True(t, strings.HasPrefix(result, "a\n"))
302 assert.True(t, strings.HasSuffix(result, "a\n"))
303 })
304}
305
306func TestCountLines(t *testing.T) {
307 testCases := []struct {
308 name string
309 input string
310 expected int
311 }{
312 {
313 name: "empty string",
314 input: "",
315 expected: 0,
316 },
317 {
318 name: "single line",
319 input: "line1",
320 expected: 1,
321 },
322 {
323 name: "multiple lines",
324 input: "line1\nline2\nline3",
325 expected: 3,
326 },
327 {
328 name: "trailing newline",
329 input: "line1\nline2\n",
330 expected: 3, // Empty string after last newline counts as a line
331 },
332 }
333
334 for _, tc := range testCases {
335 t.Run(tc.name, func(t *testing.T) {
336 result := countLines(tc.input)
337 assert.Equal(t, tc.expected, result)
338 })
339 }
340}