1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "os"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/pubsub"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17func TestBashTool_Info(t *testing.T) {
18 tool := NewBashTool()
19 info := tool.Info()
20
21 assert.Equal(t, BashToolName, info.Name)
22 assert.NotEmpty(t, info.Description)
23 assert.Contains(t, info.Parameters, "command")
24 assert.Contains(t, info.Parameters, "timeout")
25 assert.Contains(t, info.Required, "command")
26}
27
28func TestBashTool_Run(t *testing.T) {
29 // Setup a mock permission handler that always allows
30 origPermission := permission.Default
31 defer func() {
32 permission.Default = origPermission
33 }()
34 permission.Default = newMockPermissionService(true)
35
36 // Save original working directory
37 origWd, err := os.Getwd()
38 require.NoError(t, err)
39 defer func() {
40 os.Chdir(origWd)
41 }()
42
43 t.Run("executes command successfully", func(t *testing.T) {
44 permission.Default = newMockPermissionService(true)
45 tool := NewBashTool()
46 params := BashParams{
47 Command: "echo 'Hello World'",
48 }
49
50 paramsJSON, err := json.Marshal(params)
51 require.NoError(t, err)
52
53 call := ToolCall{
54 Name: BashToolName,
55 Input: string(paramsJSON),
56 }
57
58 response, err := tool.Run(context.Background(), call)
59 require.NoError(t, err)
60 assert.Equal(t, "Hello World\n", response.Content)
61 })
62
63 t.Run("handles invalid parameters", func(t *testing.T) {
64 permission.Default = newMockPermissionService(true)
65
66 tool := NewBashTool()
67 call := ToolCall{
68 Name: BashToolName,
69 Input: "invalid json",
70 }
71
72 response, err := tool.Run(context.Background(), call)
73 require.NoError(t, err)
74 assert.Contains(t, response.Content, "invalid parameters")
75 })
76
77 t.Run("handles missing command", func(t *testing.T) {
78 permission.Default = newMockPermissionService(true)
79
80 tool := NewBashTool()
81 params := BashParams{
82 Command: "",
83 }
84
85 paramsJSON, err := json.Marshal(params)
86 require.NoError(t, err)
87
88 call := ToolCall{
89 Name: BashToolName,
90 Input: string(paramsJSON),
91 }
92
93 response, err := tool.Run(context.Background(), call)
94 require.NoError(t, err)
95 assert.Contains(t, response.Content, "missing command")
96 })
97
98 t.Run("handles banned commands", func(t *testing.T) {
99 permission.Default = newMockPermissionService(true)
100
101 tool := NewBashTool()
102
103 for _, bannedCmd := range BannedCommands {
104 params := BashParams{
105 Command: bannedCmd + " arg1 arg2",
106 }
107
108 paramsJSON, err := json.Marshal(params)
109 require.NoError(t, err)
110
111 call := ToolCall{
112 Name: BashToolName,
113 Input: string(paramsJSON),
114 }
115
116 response, err := tool.Run(context.Background(), call)
117 require.NoError(t, err)
118 assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
119 }
120 })
121
122 t.Run("handles safe read-only commands without permission check", func(t *testing.T) {
123 permission.Default = newMockPermissionService(false)
124
125 tool := NewBashTool()
126
127 // Test with a safe read-only command
128 params := BashParams{
129 Command: "echo 'test'",
130 }
131
132 paramsJSON, err := json.Marshal(params)
133 require.NoError(t, err)
134
135 call := ToolCall{
136 Name: BashToolName,
137 Input: string(paramsJSON),
138 }
139
140 response, err := tool.Run(context.Background(), call)
141 require.NoError(t, err)
142 assert.Equal(t, "test\n", response.Content)
143 })
144
145 t.Run("handles permission denied", func(t *testing.T) {
146 permission.Default = newMockPermissionService(false)
147
148 tool := NewBashTool()
149
150 // Test with a command that requires permission
151 params := BashParams{
152 Command: "mkdir test_dir",
153 }
154
155 paramsJSON, err := json.Marshal(params)
156 require.NoError(t, err)
157
158 call := ToolCall{
159 Name: BashToolName,
160 Input: string(paramsJSON),
161 }
162
163 response, err := tool.Run(context.Background(), call)
164 require.NoError(t, err)
165 assert.Contains(t, response.Content, "permission denied")
166 })
167
168 t.Run("handles command timeout", func(t *testing.T) {
169 permission.Default = newMockPermissionService(true)
170 tool := NewBashTool()
171 params := BashParams{
172 Command: "sleep 2",
173 Timeout: 100, // 100ms timeout
174 }
175
176 paramsJSON, err := json.Marshal(params)
177 require.NoError(t, err)
178
179 call := ToolCall{
180 Name: BashToolName,
181 Input: string(paramsJSON),
182 }
183
184 response, err := tool.Run(context.Background(), call)
185 require.NoError(t, err)
186 assert.Contains(t, response.Content, "aborted")
187 })
188
189 t.Run("handles command with stderr output", func(t *testing.T) {
190 permission.Default = newMockPermissionService(true)
191 tool := NewBashTool()
192 params := BashParams{
193 Command: "echo 'error message' >&2",
194 }
195
196 paramsJSON, err := json.Marshal(params)
197 require.NoError(t, err)
198
199 call := ToolCall{
200 Name: BashToolName,
201 Input: string(paramsJSON),
202 }
203
204 response, err := tool.Run(context.Background(), call)
205 require.NoError(t, err)
206 assert.Contains(t, response.Content, "error message")
207 })
208
209 t.Run("handles command with both stdout and stderr", func(t *testing.T) {
210 permission.Default = newMockPermissionService(true)
211 tool := NewBashTool()
212 params := BashParams{
213 Command: "echo 'stdout message' && echo 'stderr message' >&2",
214 }
215
216 paramsJSON, err := json.Marshal(params)
217 require.NoError(t, err)
218
219 call := ToolCall{
220 Name: BashToolName,
221 Input: string(paramsJSON),
222 }
223
224 response, err := tool.Run(context.Background(), call)
225 require.NoError(t, err)
226 assert.Contains(t, response.Content, "stdout message")
227 assert.Contains(t, response.Content, "stderr message")
228 })
229
230 t.Run("handles context cancellation", func(t *testing.T) {
231 permission.Default = newMockPermissionService(true)
232 tool := NewBashTool()
233 params := BashParams{
234 Command: "sleep 5",
235 }
236
237 paramsJSON, err := json.Marshal(params)
238 require.NoError(t, err)
239
240 call := ToolCall{
241 Name: BashToolName,
242 Input: string(paramsJSON),
243 }
244
245 ctx, cancel := context.WithCancel(context.Background())
246
247 // Cancel the context after a short delay
248 go func() {
249 time.Sleep(100 * time.Millisecond)
250 cancel()
251 }()
252
253 response, err := tool.Run(ctx, call)
254 require.NoError(t, err)
255 assert.Contains(t, response.Content, "aborted")
256 })
257
258 t.Run("respects max timeout", func(t *testing.T) {
259 permission.Default = newMockPermissionService(true)
260 tool := NewBashTool()
261 params := BashParams{
262 Command: "echo 'test'",
263 Timeout: MaxTimeout + 1000, // Exceeds max timeout
264 }
265
266 paramsJSON, err := json.Marshal(params)
267 require.NoError(t, err)
268
269 call := ToolCall{
270 Name: BashToolName,
271 Input: string(paramsJSON),
272 }
273
274 response, err := tool.Run(context.Background(), call)
275 require.NoError(t, err)
276 assert.Equal(t, "test\n", response.Content)
277 })
278
279 t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
280 permission.Default = newMockPermissionService(true)
281 tool := NewBashTool()
282 params := BashParams{
283 Command: "echo 'test'",
284 Timeout: -100, // Negative timeout
285 }
286
287 paramsJSON, err := json.Marshal(params)
288 require.NoError(t, err)
289
290 call := ToolCall{
291 Name: BashToolName,
292 Input: string(paramsJSON),
293 }
294
295 response, err := tool.Run(context.Background(), call)
296 require.NoError(t, err)
297 assert.Equal(t, "test\n", response.Content)
298 })
299}
300
301func TestTruncateOutput(t *testing.T) {
302 t.Run("does not truncate short output", func(t *testing.T) {
303 output := "short output"
304 result := truncateOutput(output)
305 assert.Equal(t, output, result)
306 })
307
308 t.Run("truncates long output", func(t *testing.T) {
309 // Create a string longer than MaxOutputLength
310 longOutput := strings.Repeat("a\n", MaxOutputLength)
311 result := truncateOutput(longOutput)
312
313 // Check that the result is shorter than the original
314 assert.Less(t, len(result), len(longOutput))
315
316 // Check that the truncation message is included
317 assert.Contains(t, result, "lines truncated")
318
319 // Check that we have the beginning and end of the original string
320 assert.True(t, strings.HasPrefix(result, "a\n"))
321 assert.True(t, strings.HasSuffix(result, "a\n"))
322 })
323}
324
325func TestCountLines(t *testing.T) {
326 testCases := []struct {
327 name string
328 input string
329 expected int
330 }{
331 {
332 name: "empty string",
333 input: "",
334 expected: 0,
335 },
336 {
337 name: "single line",
338 input: "line1",
339 expected: 1,
340 },
341 {
342 name: "multiple lines",
343 input: "line1\nline2\nline3",
344 expected: 3,
345 },
346 {
347 name: "trailing newline",
348 input: "line1\nline2\n",
349 expected: 3, // Empty string after last newline counts as a line
350 },
351 }
352
353 for _, tc := range testCases {
354 t.Run(tc.name, func(t *testing.T) {
355 result := countLines(tc.input)
356 assert.Equal(t, tc.expected, result)
357 })
358 }
359}
360
361// Mock permission service for testing
362type mockPermissionService struct {
363 *pubsub.Broker[permission.PermissionRequest]
364 allow bool
365}
366
367func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
368 // Not needed for tests
369}
370
371func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
372 // Not needed for tests
373}
374
375func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
376 // Not needed for tests
377}
378
379func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
380 return m.allow
381}
382
383func newMockPermissionService(allow bool) permission.Service {
384 return &mockPermissionService{
385 Broker: pubsub.NewBroker[permission.PermissionRequest](),
386 allow: allow,
387 }
388}
389