1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "os"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/pubsub"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17func TestBashTool_Info(t *testing.T) {
18 tool := NewBashTool()
19 info := tool.Info()
20
21 assert.Equal(t, BashToolName, info.Name)
22 assert.NotEmpty(t, info.Description)
23 assert.Contains(t, info.Parameters, "command")
24 assert.Contains(t, info.Parameters, "timeout")
25 assert.Contains(t, info.Required, "command")
26}
27
28func TestBashTool_Run(t *testing.T) {
29 // Setup a mock permission handler that always allows
30 origPermission := permission.Default
31 defer func() {
32 permission.Default = origPermission
33 }()
34 permission.Default = newMockPermissionService(true)
35
36 // Save original working directory
37 origWd, err := os.Getwd()
38 require.NoError(t, err)
39 defer func() {
40 os.Chdir(origWd)
41 }()
42
43 t.Run("executes command successfully", func(t *testing.T) {
44 permission.Default = newMockPermissionService(true)
45 tool := NewBashTool()
46 params := BashParams{
47 Command: "echo 'Hello World'",
48 }
49
50 paramsJSON, err := json.Marshal(params)
51 require.NoError(t, err)
52
53 call := ToolCall{
54 Name: BashToolName,
55 Input: string(paramsJSON),
56 }
57
58 response, err := tool.Run(context.Background(), call)
59 require.NoError(t, err)
60 assert.Equal(t, "Hello World\n", response.Content)
61 })
62
63 t.Run("handles invalid parameters", func(t *testing.T) {
64 permission.Default = newMockPermissionService(true)
65
66 tool := NewBashTool()
67 call := ToolCall{
68 Name: BashToolName,
69 Input: "invalid json",
70 }
71
72 response, err := tool.Run(context.Background(), call)
73 require.NoError(t, err)
74 assert.Contains(t, response.Content, "invalid parameters")
75 })
76
77 t.Run("handles missing command", func(t *testing.T) {
78 permission.Default = newMockPermissionService(true)
79
80 tool := NewBashTool()
81 params := BashParams{
82 Command: "",
83 }
84
85 paramsJSON, err := json.Marshal(params)
86 require.NoError(t, err)
87
88 call := ToolCall{
89 Name: BashToolName,
90 Input: string(paramsJSON),
91 }
92
93 response, err := tool.Run(context.Background(), call)
94 require.NoError(t, err)
95 assert.Contains(t, response.Content, "missing command")
96 })
97
98 t.Run("handles banned commands", func(t *testing.T) {
99 permission.Default = newMockPermissionService(true)
100
101 tool := NewBashTool()
102
103 for _, bannedCmd := range BannedCommands {
104 params := BashParams{
105 Command: bannedCmd + " arg1 arg2",
106 }
107
108 paramsJSON, err := json.Marshal(params)
109 require.NoError(t, err)
110
111 call := ToolCall{
112 Name: BashToolName,
113 Input: string(paramsJSON),
114 }
115
116 response, err := tool.Run(context.Background(), call)
117 require.NoError(t, err)
118 assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
119 }
120 })
121
122 t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
123 permission.Default = newMockPermissionService(false)
124
125 tool := NewBashTool()
126
127 // Test with multi-word safe commands
128 multiWordCommands := []string{
129 "git status",
130 "git log -n 5",
131 "docker ps",
132 "go test ./...",
133 "kubectl get pods",
134 }
135
136 for _, cmd := range multiWordCommands {
137 params := BashParams{
138 Command: cmd,
139 }
140
141 paramsJSON, err := json.Marshal(params)
142 require.NoError(t, err)
143
144 call := ToolCall{
145 Name: BashToolName,
146 Input: string(paramsJSON),
147 }
148
149 response, err := tool.Run(context.Background(), call)
150 require.NoError(t, err)
151 assert.NotContains(t, response.Content, "permission denied",
152 "Command %s should be allowed without permission", cmd)
153 }
154 })
155
156 t.Run("handles permission denied", func(t *testing.T) {
157 permission.Default = newMockPermissionService(false)
158
159 tool := NewBashTool()
160
161 // Test with a command that requires permission
162 params := BashParams{
163 Command: "mkdir test_dir",
164 }
165
166 paramsJSON, err := json.Marshal(params)
167 require.NoError(t, err)
168
169 call := ToolCall{
170 Name: BashToolName,
171 Input: string(paramsJSON),
172 }
173
174 response, err := tool.Run(context.Background(), call)
175 require.NoError(t, err)
176 assert.Contains(t, response.Content, "permission denied")
177 })
178
179 t.Run("handles command timeout", func(t *testing.T) {
180 permission.Default = newMockPermissionService(true)
181 tool := NewBashTool()
182 params := BashParams{
183 Command: "sleep 2",
184 Timeout: 100, // 100ms timeout
185 }
186
187 paramsJSON, err := json.Marshal(params)
188 require.NoError(t, err)
189
190 call := ToolCall{
191 Name: BashToolName,
192 Input: string(paramsJSON),
193 }
194
195 response, err := tool.Run(context.Background(), call)
196 require.NoError(t, err)
197 assert.Contains(t, response.Content, "aborted")
198 })
199
200 t.Run("handles command with stderr output", func(t *testing.T) {
201 permission.Default = newMockPermissionService(true)
202 tool := NewBashTool()
203 params := BashParams{
204 Command: "echo 'error message' >&2",
205 }
206
207 paramsJSON, err := json.Marshal(params)
208 require.NoError(t, err)
209
210 call := ToolCall{
211 Name: BashToolName,
212 Input: string(paramsJSON),
213 }
214
215 response, err := tool.Run(context.Background(), call)
216 require.NoError(t, err)
217 assert.Contains(t, response.Content, "error message")
218 })
219
220 t.Run("handles command with both stdout and stderr", func(t *testing.T) {
221 permission.Default = newMockPermissionService(true)
222 tool := NewBashTool()
223 params := BashParams{
224 Command: "echo 'stdout message' && echo 'stderr message' >&2",
225 }
226
227 paramsJSON, err := json.Marshal(params)
228 require.NoError(t, err)
229
230 call := ToolCall{
231 Name: BashToolName,
232 Input: string(paramsJSON),
233 }
234
235 response, err := tool.Run(context.Background(), call)
236 require.NoError(t, err)
237 assert.Contains(t, response.Content, "stdout message")
238 assert.Contains(t, response.Content, "stderr message")
239 })
240
241 t.Run("handles context cancellation", func(t *testing.T) {
242 permission.Default = newMockPermissionService(true)
243 tool := NewBashTool()
244 params := BashParams{
245 Command: "sleep 5",
246 }
247
248 paramsJSON, err := json.Marshal(params)
249 require.NoError(t, err)
250
251 call := ToolCall{
252 Name: BashToolName,
253 Input: string(paramsJSON),
254 }
255
256 ctx, cancel := context.WithCancel(context.Background())
257
258 // Cancel the context after a short delay
259 go func() {
260 time.Sleep(100 * time.Millisecond)
261 cancel()
262 }()
263
264 response, err := tool.Run(ctx, call)
265 require.NoError(t, err)
266 assert.Contains(t, response.Content, "aborted")
267 })
268
269 t.Run("respects max timeout", func(t *testing.T) {
270 permission.Default = newMockPermissionService(true)
271 tool := NewBashTool()
272 params := BashParams{
273 Command: "echo 'test'",
274 Timeout: MaxTimeout + 1000, // Exceeds max timeout
275 }
276
277 paramsJSON, err := json.Marshal(params)
278 require.NoError(t, err)
279
280 call := ToolCall{
281 Name: BashToolName,
282 Input: string(paramsJSON),
283 }
284
285 response, err := tool.Run(context.Background(), call)
286 require.NoError(t, err)
287 assert.Equal(t, "test\n", response.Content)
288 })
289
290 t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
291 permission.Default = newMockPermissionService(true)
292 tool := NewBashTool()
293 params := BashParams{
294 Command: "echo 'test'",
295 Timeout: -100, // Negative timeout
296 }
297
298 paramsJSON, err := json.Marshal(params)
299 require.NoError(t, err)
300
301 call := ToolCall{
302 Name: BashToolName,
303 Input: string(paramsJSON),
304 }
305
306 response, err := tool.Run(context.Background(), call)
307 require.NoError(t, err)
308 assert.Equal(t, "test\n", response.Content)
309 })
310}
311
312func TestTruncateOutput(t *testing.T) {
313 t.Run("does not truncate short output", func(t *testing.T) {
314 output := "short output"
315 result := truncateOutput(output)
316 assert.Equal(t, output, result)
317 })
318
319 t.Run("truncates long output", func(t *testing.T) {
320 // Create a string longer than MaxOutputLength
321 longOutput := strings.Repeat("a\n", MaxOutputLength)
322 result := truncateOutput(longOutput)
323
324 // Check that the result is shorter than the original
325 assert.Less(t, len(result), len(longOutput))
326
327 // Check that the truncation message is included
328 assert.Contains(t, result, "lines truncated")
329
330 // Check that we have the beginning and end of the original string
331 assert.True(t, strings.HasPrefix(result, "a\n"))
332 assert.True(t, strings.HasSuffix(result, "a\n"))
333 })
334}
335
336func TestCountLines(t *testing.T) {
337 testCases := []struct {
338 name string
339 input string
340 expected int
341 }{
342 {
343 name: "empty string",
344 input: "",
345 expected: 0,
346 },
347 {
348 name: "single line",
349 input: "line1",
350 expected: 1,
351 },
352 {
353 name: "multiple lines",
354 input: "line1\nline2\nline3",
355 expected: 3,
356 },
357 {
358 name: "trailing newline",
359 input: "line1\nline2\n",
360 expected: 3, // Empty string after last newline counts as a line
361 },
362 }
363
364 for _, tc := range testCases {
365 t.Run(tc.name, func(t *testing.T) {
366 result := countLines(tc.input)
367 assert.Equal(t, tc.expected, result)
368 })
369 }
370}
371
372// Mock permission service for testing
373type mockPermissionService struct {
374 *pubsub.Broker[permission.PermissionRequest]
375 allow bool
376}
377
378func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
379 // Not needed for tests
380}
381
382func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
383 // Not needed for tests
384}
385
386func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
387 // Not needed for tests
388}
389
390func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
391 return m.allow
392}
393
394func newMockPermissionService(allow bool) permission.Service {
395 return &mockPermissionService{
396 Broker: pubsub.NewBroker[permission.PermissionRequest](),
397 allow: allow,
398 }
399}
400