1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "os"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/pubsub"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17func TestBashTool_Info(t *testing.T) {
18 tool := NewBashTool(newMockPermissionService(true))
19 info := tool.Info()
20
21 assert.Equal(t, BashToolName, info.Name)
22 assert.NotEmpty(t, info.Description)
23 assert.Contains(t, info.Parameters, "command")
24 assert.Contains(t, info.Parameters, "timeout")
25 assert.Contains(t, info.Required, "command")
26}
27
28func TestBashTool_Run(t *testing.T) {
29 // Save original working directory
30 origWd, err := os.Getwd()
31 require.NoError(t, err)
32 defer func() {
33 os.Chdir(origWd)
34 }()
35
36 t.Run("executes command successfully", func(t *testing.T) {
37 tool := NewBashTool(newMockPermissionService(true))
38 params := BashParams{
39 Command: "echo 'Hello World'",
40 }
41
42 paramsJSON, err := json.Marshal(params)
43 require.NoError(t, err)
44
45 call := ToolCall{
46 Name: BashToolName,
47 Input: string(paramsJSON),
48 }
49
50 response, err := tool.Run(context.Background(), call)
51 require.NoError(t, err)
52 assert.Equal(t, "Hello World\n", response.Content)
53 })
54
55 t.Run("handles invalid parameters", func(t *testing.T) {
56 tool := NewBashTool(newMockPermissionService(true))
57 call := ToolCall{
58 Name: BashToolName,
59 Input: "invalid json",
60 }
61
62 response, err := tool.Run(context.Background(), call)
63 require.NoError(t, err)
64 assert.Contains(t, response.Content, "invalid parameters")
65 })
66
67 t.Run("handles missing command", func(t *testing.T) {
68 tool := NewBashTool(newMockPermissionService(true))
69 params := BashParams{
70 Command: "",
71 }
72
73 paramsJSON, err := json.Marshal(params)
74 require.NoError(t, err)
75
76 call := ToolCall{
77 Name: BashToolName,
78 Input: string(paramsJSON),
79 }
80
81 response, err := tool.Run(context.Background(), call)
82 require.NoError(t, err)
83 assert.Contains(t, response.Content, "missing command")
84 })
85
86 t.Run("handles banned commands", func(t *testing.T) {
87 tool := NewBashTool(newMockPermissionService(true))
88
89 for _, bannedCmd := range bannedCommands {
90 params := BashParams{
91 Command: bannedCmd + " arg1 arg2",
92 }
93
94 paramsJSON, err := json.Marshal(params)
95 require.NoError(t, err)
96
97 call := ToolCall{
98 Name: BashToolName,
99 Input: string(paramsJSON),
100 }
101
102 response, err := tool.Run(context.Background(), call)
103 require.NoError(t, err)
104 assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
105 }
106 })
107
108 t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
109 tool := NewBashTool(newMockPermissionService(false))
110
111 // Test with multi-word safe commands
112 multiWordCommands := []string{
113 "go env",
114 }
115
116 for _, cmd := range multiWordCommands {
117 params := BashParams{
118 Command: cmd,
119 }
120
121 paramsJSON, err := json.Marshal(params)
122 require.NoError(t, err)
123
124 call := ToolCall{
125 Name: BashToolName,
126 Input: string(paramsJSON),
127 }
128
129 response, err := tool.Run(context.Background(), call)
130 require.NoError(t, err)
131 assert.NotContains(t, response.Content, "permission denied",
132 "Command %s should be allowed without permission", cmd)
133 }
134 })
135
136 t.Run("handles permission denied", func(t *testing.T) {
137 tool := NewBashTool(newMockPermissionService(false))
138
139 // Test with a command that requires permission
140 params := BashParams{
141 Command: "mkdir test_dir",
142 }
143
144 paramsJSON, err := json.Marshal(params)
145 require.NoError(t, err)
146
147 call := ToolCall{
148 Name: BashToolName,
149 Input: string(paramsJSON),
150 }
151
152 response, err := tool.Run(context.Background(), call)
153 require.NoError(t, err)
154 assert.Contains(t, response.Content, "permission denied")
155 })
156
157 t.Run("handles command timeout", func(t *testing.T) {
158 tool := NewBashTool(newMockPermissionService(true))
159 params := BashParams{
160 Command: "sleep 2",
161 Timeout: 100, // 100ms timeout
162 }
163
164 paramsJSON, err := json.Marshal(params)
165 require.NoError(t, err)
166
167 call := ToolCall{
168 Name: BashToolName,
169 Input: string(paramsJSON),
170 }
171
172 response, err := tool.Run(context.Background(), call)
173 require.NoError(t, err)
174 assert.Contains(t, response.Content, "aborted")
175 })
176
177 t.Run("handles command with stderr output", func(t *testing.T) {
178 tool := NewBashTool(newMockPermissionService(true))
179 params := BashParams{
180 Command: "echo 'error message' >&2",
181 }
182
183 paramsJSON, err := json.Marshal(params)
184 require.NoError(t, err)
185
186 call := ToolCall{
187 Name: BashToolName,
188 Input: string(paramsJSON),
189 }
190
191 response, err := tool.Run(context.Background(), call)
192 require.NoError(t, err)
193 assert.Contains(t, response.Content, "error message")
194 })
195
196 t.Run("handles command with both stdout and stderr", func(t *testing.T) {
197 tool := NewBashTool(newMockPermissionService(true))
198 params := BashParams{
199 Command: "echo 'stdout message' && echo 'stderr message' >&2",
200 }
201
202 paramsJSON, err := json.Marshal(params)
203 require.NoError(t, err)
204
205 call := ToolCall{
206 Name: BashToolName,
207 Input: string(paramsJSON),
208 }
209
210 response, err := tool.Run(context.Background(), call)
211 require.NoError(t, err)
212 assert.Contains(t, response.Content, "stdout message")
213 assert.Contains(t, response.Content, "stderr message")
214 })
215
216 t.Run("handles context cancellation", func(t *testing.T) {
217 tool := NewBashTool(newMockPermissionService(true))
218 params := BashParams{
219 Command: "sleep 5",
220 }
221
222 paramsJSON, err := json.Marshal(params)
223 require.NoError(t, err)
224
225 call := ToolCall{
226 Name: BashToolName,
227 Input: string(paramsJSON),
228 }
229
230 ctx, cancel := context.WithCancel(context.Background())
231
232 // Cancel the context after a short delay
233 go func() {
234 time.Sleep(100 * time.Millisecond)
235 cancel()
236 }()
237
238 response, err := tool.Run(ctx, call)
239 require.NoError(t, err)
240 assert.Contains(t, response.Content, "aborted")
241 })
242
243 t.Run("respects max timeout", func(t *testing.T) {
244 tool := NewBashTool(newMockPermissionService(true))
245 params := BashParams{
246 Command: "echo 'test'",
247 Timeout: MaxTimeout + 1000, // Exceeds max timeout
248 }
249
250 paramsJSON, err := json.Marshal(params)
251 require.NoError(t, err)
252
253 call := ToolCall{
254 Name: BashToolName,
255 Input: string(paramsJSON),
256 }
257
258 response, err := tool.Run(context.Background(), call)
259 require.NoError(t, err)
260 assert.Equal(t, "test\n", response.Content)
261 })
262
263 t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
264 tool := NewBashTool(newMockPermissionService(true))
265 params := BashParams{
266 Command: "echo 'test'",
267 Timeout: -100, // Negative timeout
268 }
269
270 paramsJSON, err := json.Marshal(params)
271 require.NoError(t, err)
272
273 call := ToolCall{
274 Name: BashToolName,
275 Input: string(paramsJSON),
276 }
277
278 response, err := tool.Run(context.Background(), call)
279 require.NoError(t, err)
280 assert.Equal(t, "test\n", response.Content)
281 })
282}
283
284func TestTruncateOutput(t *testing.T) {
285 t.Run("does not truncate short output", func(t *testing.T) {
286 output := "short output"
287 result := truncateOutput(output)
288 assert.Equal(t, output, result)
289 })
290
291 t.Run("truncates long output", func(t *testing.T) {
292 // Create a string longer than MaxOutputLength
293 longOutput := strings.Repeat("a\n", MaxOutputLength)
294 result := truncateOutput(longOutput)
295
296 // Check that the result is shorter than the original
297 assert.Less(t, len(result), len(longOutput))
298
299 // Check that the truncation message is included
300 assert.Contains(t, result, "lines truncated")
301
302 // Check that we have the beginning and end of the original string
303 assert.True(t, strings.HasPrefix(result, "a\n"))
304 assert.True(t, strings.HasSuffix(result, "a\n"))
305 })
306}
307
308func TestCountLines(t *testing.T) {
309 testCases := []struct {
310 name string
311 input string
312 expected int
313 }{
314 {
315 name: "empty string",
316 input: "",
317 expected: 0,
318 },
319 {
320 name: "single line",
321 input: "line1",
322 expected: 1,
323 },
324 {
325 name: "multiple lines",
326 input: "line1\nline2\nline3",
327 expected: 3,
328 },
329 {
330 name: "trailing newline",
331 input: "line1\nline2\n",
332 expected: 3, // Empty string after last newline counts as a line
333 },
334 }
335
336 for _, tc := range testCases {
337 t.Run(tc.name, func(t *testing.T) {
338 result := countLines(tc.input)
339 assert.Equal(t, tc.expected, result)
340 })
341 }
342}
343
344// Mock permission service for testing
345type mockPermissionService struct {
346 *pubsub.Broker[permission.PermissionRequest]
347 allow bool
348}
349
350func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
351 // Not needed for tests
352}
353
354func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
355 // Not needed for tests
356}
357
358func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
359 // Not needed for tests
360}
361
362func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
363 return m.allow
364}
365
366func newMockPermissionService(allow bool) permission.Service {
367 return &mockPermissionService{
368 Broker: pubsub.NewBroker[permission.PermissionRequest](),
369 allow: allow,
370 }
371}