1package tools
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 _ "embed"
8 "fmt"
9 "html/template"
10 "os"
11 "path/filepath"
12 "runtime"
13 "strings"
14 "time"
15
16 "charm.land/fantasy"
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/shell"
20)
21
// BashParams is the input the model supplies when it calls the bash tool.
// The `description` struct tags presumably feed the tool's JSON schema via
// fantasy.NewAgentTool — TODO confirm.
//
// The field names, types and order must stay identical to
// BashPermissionsParams: NewBashTool converts BashParams to
// BashPermissionsParams directly when requesting permission.
type BashParams struct {
	// Short human-readable summary of the command; echoed back in the response metadata.
	Description string `json:"description" description:"A brief description of what the command does, try to keep it under 30 characters or so"`
	// The shell command line to run; an empty command is rejected.
	Command string `json:"command" description:"The command to execute"`
	// NOTE(review): when empty, NewBashTool falls back to the workingDir the
	// tool was created with, which is not necessarily the process's current
	// directory as the schema text suggests.
	WorkingDir string `json:"working_dir,omitempty" description:"The working directory to execute the command in (defaults to current directory)"`
	// When true, the command is detached immediately and its job ID is returned.
	RunInBackground bool `json:"run_in_background,omitempty" description:"Set to true (boolean) to run this command in the background. Use job_output to read the output later."`
}
28
// BashPermissionsParams is the copy of the call parameters attached to a
// permission request for a bash command.
//
// It must keep the same field names, types and order as BashParams, because
// NewBashTool builds it with a direct type conversion, BashPermissionsParams(params).
type BashPermissionsParams struct {
	Description     string `json:"description"`
	Command         string `json:"command"`
	WorkingDir      string `json:"working_dir"`
	RunInBackground bool   `json:"run_in_background"`
}
35
// BashResponseMetadata is attached to the non-error responses of the bash tool.
type BashResponseMetadata struct {
	// Unix time in milliseconds, taken just before the shell was started.
	StartTime int64 `json:"start_time"`
	// Unix time in milliseconds at which the response was built. For commands
	// left running in the background this is not the command's end time.
	EndTime int64 `json:"end_time"`
	// Formatted stdout/stderr/exit information of a finished command; left
	// empty when the command is still running in the background.
	Output string `json:"output"`
	// The description supplied with the call.
	Description string `json:"description"`
	// Directory the shell reported it ran in.
	WorkingDirectory string `json:"working_directory"`
	// True for commands still running as background jobs, and for commands
	// that were requested with run_in_background.
	Background bool `json:"background,omitempty"`
	// Background job ID; only set while the command is still running, for use
	// with the job_output / job_kill tools.
	ShellID string `json:"shell_id,omitempty"`
}
45
// BashToolName is the name under which the bash tool is registered.
const BashToolName = "bash"

const (
	// AutoBackgroundThreshold is how long a foreground command may run before
	// it is left running as a background job and the tool call returns.
	AutoBackgroundThreshold = time.Minute
	// MaxOutputLength is the byte limit applied separately to stdout and
	// stderr before their middle section is truncated.
	MaxOutputLength = 30000
	// BashNoOutput is the response text for a finished command with no output.
	BashNoOutput = "no output"
)
53
// bashDescriptionTmpl is the raw text of bash.tpl, embedded at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed tool-description template. template.Must
// panics during package initialization if bash.tpl does not parse.
//
// NOTE(review): this package's `template` is html/template (see imports),
// which escapes interpolated values for an HTML context ('<', '>', '&' and
// quotes become entities). The rendered text is used as the tool description
// passed to fantasy.NewAgentTool, presumably plain text, so text/template is
// likely what was intended — confirm against bash.tpl and the attribution /
// model-name values it renders.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
61
// bashDescriptionData holds the values bashDescription renders into the
// bash.tpl template. Field names are referenced by the template, so renaming
// one requires updating bash.tpl as well.
type bashDescriptionData struct {
	BannedCommands  string             // comma-separated list from resolveBannedCommandsList
	MaxOutputLength int                // the MaxOutputLength truncation limit
	Attribution     config.Attribution // attribution settings given to NewBashTool
	ModelName       string             // model name given to NewBashTool
}
68
// defaultBannedCommands are command names the bash tool refuses to run unless
// the configuration sets DisableDefaultCommands (see
// resolveBannedCommandsList). The list is also joined into the tool
// description, so its order is what the model sees.
var defaultBannedCommands = []string{
	// Shell builtin (not a network tool)
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
141
142func bashDescription(attribution *config.Attribution, modelName string, bashConfig config.ToolBash) string {
143 bannedCommandsList := resolveBannedCommandsList(bashConfig)
144 bannedCommandsStr := strings.Join(bannedCommandsList, ", ")
145 var out bytes.Buffer
146 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
147 BannedCommands: bannedCommandsStr,
148 MaxOutputLength: MaxOutputLength,
149 Attribution: *attribution,
150 ModelName: modelName,
151 }); err != nil {
152 // this should never happen.
153 panic("failed to execute bash description template: " + err.Error())
154 }
155 return out.String()
156}
157
// defaultBannedSubCommands are the built-in argument-level blockers that
// blockFuncs appends when includeSubCommandDefaults is true. They cover
// install commands of system package managers and of several language
// toolchains (for npm, pnpm, pip and yarn only the global/user forms), plus
// `go test -exec`.
//
// NOTE(review): matching is implemented by shell.ArgumentsBlocker. Judging by
// these call sites, the second argument lists sub-command words and the third
// lists flags, and each spelling of a flag (e.g. npm --global vs -g) needs its
// own entry — confirm against the shell package before relying on it.
var defaultBannedSubCommands = []shell.BlockFunc{
	// System package managers
	shell.ArgumentsBlocker("apk", []string{"add"}, nil),
	shell.ArgumentsBlocker("apt", []string{"install"}, nil),
	shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
	shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
	shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
	shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
	shell.ArgumentsBlocker("yum", []string{"install"}, nil),
	shell.ArgumentsBlocker("zypper", []string{"install"}, nil),

	// Language-specific package managers
	shell.ArgumentsBlocker("brew", []string{"install"}, nil),
	shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
	shell.ArgumentsBlocker("gem", []string{"install"}, nil),
	shell.ArgumentsBlocker("go", []string{"install"}, nil),
	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
	shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
	shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
	shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
	shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
	shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),

	// `go test -exec` can run arbitrary commands
	shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
}
185
186func blockFuncs(bannedCommands []string, bannedSubCommands []config.BannedToolArgsAndOrParams, includeSubCommandDefaults bool) []shell.BlockFunc {
187 blockFuncs := []shell.BlockFunc{}
188 blockFuncs = append(blockFuncs, shell.CommandsBlocker(bannedCommands))
189
190 for _, bannedSubCmd := range bannedSubCommands {
191 blockFuncs = append(blockFuncs, shell.ArgumentsBlocker(bannedSubCmd.Command, bannedSubCmd.Args, bannedSubCmd.Flags))
192 }
193
194 if includeSubCommandDefaults {
195 blockFuncs = append(blockFuncs, defaultBannedSubCommands...)
196 }
197 return blockFuncs
198}
199
200func resolveBannedCommandsList(cfg config.ToolBash) []string {
201 bannedCommands := cfg.BannedCommands
202 if !cfg.DisableDefaultCommands {
203 if len(bannedCommands) == 0 {
204 return defaultBannedCommands
205 }
206 bannedCommands = append(bannedCommands, defaultBannedCommands...)
207 }
208 return bannedCommands
209}
210
211func resolveBlockFuncs(cfg config.ToolBash) []shell.BlockFunc {
212 return blockFuncs(resolveBannedCommandsList(cfg), cfg.BannedSubCommands, cfg.DisableDefaultSubCommands)
213}
214
215func NewBashTool(
216 permissions permission.Service,
217 workingDir string, attribution *config.Attribution,
218 modelName string,
219 bashConfig config.ToolBash,
220) fantasy.AgentTool {
221 return fantasy.NewAgentTool(
222 BashToolName,
223 string(bashDescription(attribution, modelName, bashConfig)),
224 func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
225 if params.Command == "" {
226 return fantasy.NewTextErrorResponse("missing command"), nil
227 }
228
229 // Determine working directory
230 execWorkingDir := cmp.Or(params.WorkingDir, workingDir)
231
232 isSafeReadOnly := false
233 cmdLower := strings.ToLower(params.Command)
234
235 for _, safe := range safeCommands {
236 if strings.HasPrefix(cmdLower, safe) {
237 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
238 isSafeReadOnly = true
239 break
240 }
241 }
242 }
243
244 sessionID := GetSessionFromContext(ctx)
245 if sessionID == "" {
246 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
247 }
248 if !isSafeReadOnly {
249 p := permissions.Request(
250 permission.CreatePermissionRequest{
251 SessionID: sessionID,
252 Path: execWorkingDir,
253 ToolCallID: call.ID,
254 ToolName: BashToolName,
255 Action: "execute",
256 Description: fmt.Sprintf("Execute command: %s", params.Command),
257 Params: BashPermissionsParams(params),
258 },
259 )
260 if !p {
261 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
262 }
263 }
264
265 // If explicitly requested as background, start immediately with detached context
266 if params.RunInBackground {
267 startTime := time.Now()
268 bgManager := shell.GetBackgroundShellManager()
269 bgManager.Cleanup()
270 // Use background context so it continues after tool returns
271 bgShell, err := bgManager.Start(context.Background(), execWorkingDir, resolveBlockFuncs(bashConfig), params.Command, params.Description)
272 if err != nil {
273 return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
274 }
275
276 // Wait a short time to detect fast failures (blocked commands, syntax errors, etc.)
277 time.Sleep(1 * time.Second)
278 stdout, stderr, done, execErr := bgShell.GetOutput()
279
280 if done {
281 // Command failed or completed very quickly
282 bgManager.Remove(bgShell.ID)
283
284 interrupted := shell.IsInterrupt(execErr)
285 exitCode := shell.ExitCode(execErr)
286 if exitCode == 0 && !interrupted && execErr != nil {
287 return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
288 }
289
290 stdout = formatOutput(stdout, stderr, execErr)
291
292 metadata := BashResponseMetadata{
293 StartTime: startTime.UnixMilli(),
294 EndTime: time.Now().UnixMilli(),
295 Output: stdout,
296 Description: params.Description,
297 Background: params.RunInBackground,
298 WorkingDirectory: bgShell.WorkingDir,
299 }
300 if stdout == "" {
301 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
302 }
303 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
304 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
305 }
306
307 // Still running after fast-failure check - return as background job
308 metadata := BashResponseMetadata{
309 StartTime: startTime.UnixMilli(),
310 EndTime: time.Now().UnixMilli(),
311 Description: params.Description,
312 WorkingDirectory: bgShell.WorkingDir,
313 Background: true,
314 ShellID: bgShell.ID,
315 }
316 response := fmt.Sprintf("Background shell started with ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
317 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
318 }
319
320 // Start synchronous execution with auto-background support
321 startTime := time.Now()
322
323 // Start with detached context so it can survive if moved to background
324 bgManager := shell.GetBackgroundShellManager()
325 bgManager.Cleanup()
326 bgShell, err := bgManager.Start(context.Background(), execWorkingDir, resolveBlockFuncs(bashConfig), params.Command, params.Description)
327 if err != nil {
328 return fantasy.ToolResponse{}, fmt.Errorf("error starting shell: %w", err)
329 }
330
331 // Wait for either completion, auto-background threshold, or context cancellation
332 ticker := time.NewTicker(100 * time.Millisecond)
333 defer ticker.Stop()
334 timeout := time.After(AutoBackgroundThreshold)
335
336 var stdout, stderr string
337 var done bool
338 var execErr error
339
340 waitLoop:
341 for {
342 select {
343 case <-ticker.C:
344 stdout, stderr, done, execErr = bgShell.GetOutput()
345 if done {
346 break waitLoop
347 }
348 case <-timeout:
349 stdout, stderr, done, execErr = bgShell.GetOutput()
350 break waitLoop
351 case <-ctx.Done():
352 // Incoming context was cancelled before we moved to background
353 // Kill the shell and return error
354 bgManager.Kill(bgShell.ID)
355 return fantasy.ToolResponse{}, ctx.Err()
356 }
357 }
358
359 if done {
360 // Command completed within threshold - return synchronously
361 // Remove from background manager since we're returning directly
362 // Don't call Kill() as it cancels the context and corrupts the exit code
363 bgManager.Remove(bgShell.ID)
364
365 interrupted := shell.IsInterrupt(execErr)
366 exitCode := shell.ExitCode(execErr)
367 if exitCode == 0 && !interrupted && execErr != nil {
368 return fantasy.ToolResponse{}, fmt.Errorf("[Job %s] error executing command: %w", bgShell.ID, execErr)
369 }
370
371 stdout = formatOutput(stdout, stderr, execErr)
372
373 metadata := BashResponseMetadata{
374 StartTime: startTime.UnixMilli(),
375 EndTime: time.Now().UnixMilli(),
376 Output: stdout,
377 Description: params.Description,
378 Background: params.RunInBackground,
379 WorkingDirectory: bgShell.WorkingDir,
380 }
381 if stdout == "" {
382 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
383 }
384 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(bgShell.WorkingDir))
385 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
386 }
387
388 // Still running - keep as background job
389 metadata := BashResponseMetadata{
390 StartTime: startTime.UnixMilli(),
391 EndTime: time.Now().UnixMilli(),
392 Description: params.Description,
393 WorkingDirectory: bgShell.WorkingDir,
394 Background: true,
395 ShellID: bgShell.ID,
396 }
397 response := fmt.Sprintf("Command is taking longer than expected and has been moved to background.\n\nBackground shell ID: %s\n\nUse job_output tool to view output or job_kill to terminate.", bgShell.ID)
398 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
399 })
400}
401
402// formatOutput formats the output of a completed command with error handling
403func formatOutput(stdout, stderr string, execErr error) string {
404 interrupted := shell.IsInterrupt(execErr)
405 exitCode := shell.ExitCode(execErr)
406
407 stdout = truncateOutput(stdout)
408 stderr = truncateOutput(stderr)
409
410 errorMessage := stderr
411 if errorMessage == "" && execErr != nil {
412 errorMessage = execErr.Error()
413 }
414
415 if interrupted {
416 if errorMessage != "" {
417 errorMessage += "\n"
418 }
419 errorMessage += "Command was aborted before completion"
420 } else if exitCode != 0 {
421 if errorMessage != "" {
422 errorMessage += "\n"
423 }
424 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
425 }
426
427 hasBothOutputs := stdout != "" && stderr != ""
428
429 if hasBothOutputs {
430 stdout += "\n"
431 }
432
433 if errorMessage != "" {
434 stdout += "\n" + errorMessage
435 }
436
437 return stdout
438}
439
440func truncateOutput(content string) string {
441 if len(content) <= MaxOutputLength {
442 return content
443 }
444
445 halfLength := MaxOutputLength / 2
446 start := content[:halfLength]
447 end := content[len(content)-halfLength:]
448
449 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
450 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
451}
452
// countLines reports how many lines s spans. Text after the last newline
// counts as a line even when it is empty, so "a\n" spans two lines. The empty
// string spans zero lines.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
459
460func normalizeWorkingDir(path string) string {
461 if runtime.GOOS == "windows" {
462 cwd, err := os.Getwd()
463 if err != nil {
464 cwd = "C:"
465 }
466 path = strings.ReplaceAll(path, filepath.VolumeName(cwd), "")
467 }
468
469 return filepath.ToSlash(path)
470}