1package tools
2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
20
// BashParams holds the arguments the model supplies when it calls the bash
// tool. NOTE(review): the json and description tags are presumably consumed by
// fantasy.NewAgentTool to build the tool's input schema — confirm there.
type BashParams struct {
	// Command is the shell command line to run; an empty value is rejected.
	Command string `json:"command" description:"The command to execute"`
	// Description is a short summary of the command; it is echoed into the
	// permission request and the response metadata.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// Timeout is in milliseconds. NewBashTool clamps values above MaxTimeout
	// and replaces non-positive values with DefaultTimeout; it only bounds
	// foreground commands.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
	// Background starts the command in a new background shell and returns the
	// shell's ID instead of waiting for output.
	Background bool `json:"background,omitempty" description:"Run the command in a background shell. Returns a shell ID for managing the process."`
}
27
// BashPermissionsParams is the Params payload NewBashTool attaches to the
// permission request it issues for commands not recognized as read-only.
// Its fields mirror BashParams.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	// Timeout is in milliseconds, like BashParams.Timeout.
	Timeout    int  `json:"timeout"`
	Background bool `json:"background"`
}
34
// BashResponseMetadata is attached (via fantasy.WithResponseMetadata) to the
// responses the bash tool returns after running or starting a command.
type BashResponseMetadata struct {
	// StartTime and EndTime are Unix timestamps in milliseconds.
	StartTime int64 `json:"start_time"`
	EndTime   int64 `json:"end_time"`
	// Output is the combined stdout/stderr/status text of a foreground
	// command, captured before the <cwd> suffix is appended to the response
	// text. It is left empty for background commands.
	Output string `json:"output"`
	// Description echoes BashParams.Description.
	Description string `json:"description"`
	// WorkingDirectory is the persistent shell's directory after a foreground
	// command, or the background shell's working directory.
	WorkingDirectory string `json:"working_directory"`
	// Background and ShellID are only set for background commands.
	Background bool   `json:"background,omitempty"`
	ShellID    string `json:"shell_id,omitempty"`
}
44
// Identity, limits and sentinel text for the bash tool.
const (
	// BashToolName is the name the tool is registered under.
	BashToolName = "bash"

	// DefaultTimeout replaces a missing or non-positive timeout: one minute,
	// expressed in milliseconds.
	DefaultTimeout = 60 * 1000
	// MaxTimeout is the upper bound applied to any requested timeout: ten
	// minutes, expressed in milliseconds.
	MaxTimeout = 10 * 60 * 1000
	// MaxOutputLength is the byte length above which stdout/stderr are
	// truncated before being returned.
	MaxOutputLength = 30_000
	// BashNoOutput is returned as the response text when a command produces
	// nothing to report.
	BashNoOutput = "no output"
)
53
// bashDescriptionTmpl is the raw text of bash.tpl, embedded at build time and
// used as the template for the bash tool's description (see bashDescription).
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed description template. template.Must makes
// package initialization panic if bash.tpl does not parse.
//
// NOTE(review): this is html/template, which contextually HTML-escapes
// interpolated values (e.g. '&', '<', quotes) at execution time. The rendered
// description is plain prompt text, not HTML, so text/template may be what
// was intended — confirm no interpolated value can contain such characters.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
61
// bashDescriptionData is the value bashDescription executes bashDescriptionTpl
// with. NOTE(review): field names are presumably referenced by bash.tpl (not
// visible here), so renaming them likely requires a template change — confirm.
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands string
	// MaxOutputLength is the MaxOutputLength constant.
	MaxOutputLength int
	// Attribution is a copy of the attribution settings passed to bashDescription.
	Attribution config.Attribution
}
67
// bannedCommands lists executables the bash tool refuses to run at all.
// blockFuncs installs them through shell.CommandsBlocker, and bashDescription
// joins them with ", " into the tool description, so the slice order is also
// the order shown in the description.
var bannedCommands = []string{
	// Network/Download tools
	// NOTE(review): "alias" is a shell builtin rather than a network tool and
	// looks misfiled under this heading; moving it would change the order of
	// the rendered description.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
140
141func bashDescription(attribution *config.Attribution) string {
142 bannedCommandsStr := strings.Join(bannedCommands, ", ")
143 var out bytes.Buffer
144 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
145 BannedCommands: bannedCommandsStr,
146 MaxOutputLength: MaxOutputLength,
147 Attribution: *attribution,
148 }); err != nil {
149 // this should never happen.
150 panic("failed to execute bash description template: " + err.Error())
151 }
152 return out.String()
153}
154
155func blockFuncs() []shell.BlockFunc {
156 return []shell.BlockFunc{
157 shell.CommandsBlocker(bannedCommands),
158
159 // System package managers
160 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
161 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
162 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
163 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
164 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
165 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
166 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
167 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
168
169 // Language-specific package managers
170 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
171 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
172 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
173 shell.ArgumentsBlocker("go", []string{"install"}, nil),
174 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
175 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
176 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
177 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
178 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
179 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
180 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
181
182 // `go test -exec` can run arbitrary commands
183 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
184 }
185}
186
187func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
188 // Set up command blocking on the persistent shell
189 persistentShell := shell.GetPersistentShell(workingDir)
190 persistentShell.SetBlockFuncs(blockFuncs())
191 return fantasy.NewAgentTool(
192 BashToolName,
193 string(bashDescription(attribution)),
194 func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
195 if params.Timeout > MaxTimeout {
196 params.Timeout = MaxTimeout
197 } else if params.Timeout <= 0 {
198 params.Timeout = DefaultTimeout
199 }
200
201 if params.Command == "" {
202 return fantasy.NewTextErrorResponse("missing command"), nil
203 }
204
205 isSafeReadOnly := false
206 cmdLower := strings.ToLower(params.Command)
207
208 for _, safe := range safeCommands {
209 if strings.HasPrefix(cmdLower, safe) {
210 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
211 isSafeReadOnly = true
212 break
213 }
214 }
215 }
216
217 sessionID := GetSessionFromContext(ctx)
218 if sessionID == "" {
219 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
220 }
221 if !isSafeReadOnly {
222 var shellDir string
223 if params.Background {
224 shellDir = workingDir
225 } else {
226 shellDir = shell.GetPersistentShell(workingDir).GetWorkingDir()
227 }
228 p := permissions.Request(
229 permission.CreatePermissionRequest{
230 SessionID: sessionID,
231 Path: shellDir,
232 ToolCallID: call.ID,
233 ToolName: BashToolName,
234 Action: "execute",
235 Description: fmt.Sprintf("Execute command: %s", params.Command),
236 Params: BashPermissionsParams{
237 Command: params.Command,
238 Description: params.Description,
239 Background: params.Background,
240 },
241 },
242 )
243 if !p {
244 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
245 }
246 }
247
248 if params.Background {
249 startTime := time.Now()
250 bgManager := shell.GetBackgroundShellManager()
251 bgShell, err := bgManager.Start(ctx, workingDir, blockFuncs(), params.Command)
252 if err != nil {
253 return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
254 }
255
256 metadata := BashResponseMetadata{
257 StartTime: startTime.UnixMilli(),
258 EndTime: time.Now().UnixMilli(),
259 Description: params.Description,
260 WorkingDirectory: bgShell.GetWorkingDir(),
261 Background: true,
262 ShellID: bgShell.ID,
263 }
264 response := fmt.Sprintf("Background shell started with ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
265 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
266 }
267
268 startTime := time.Now()
269 if params.Timeout > 0 {
270 var cancel context.CancelFunc
271 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
272 defer cancel()
273 }
274
275 persistentShell := shell.GetPersistentShell(workingDir)
276 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
277
278 currentWorkingDir := persistentShell.GetWorkingDir()
279 interrupted := shell.IsInterrupt(err)
280 exitCode := shell.ExitCode(err)
281 if exitCode == 0 && !interrupted && err != nil {
282 return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
283 }
284
285 stdout = truncateOutput(stdout)
286 stderr = truncateOutput(stderr)
287
288 errorMessage := stderr
289 if errorMessage == "" && err != nil {
290 errorMessage = err.Error()
291 }
292
293 if interrupted {
294 if errorMessage != "" {
295 errorMessage += "\n"
296 }
297 errorMessage += "Command was aborted before completion"
298 } else if exitCode != 0 {
299 if errorMessage != "" {
300 errorMessage += "\n"
301 }
302 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
303 }
304
305 hasBothOutputs := stdout != "" && stderr != ""
306
307 if hasBothOutputs {
308 stdout += "\n"
309 }
310
311 if errorMessage != "" {
312 stdout += "\n" + errorMessage
313 }
314
315 metadata := BashResponseMetadata{
316 StartTime: startTime.UnixMilli(),
317 EndTime: time.Now().UnixMilli(),
318 Output: stdout,
319 Description: params.Description,
320 WorkingDirectory: currentWorkingDir,
321 }
322 if stdout == "" {
323 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
324 }
325 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
326 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
327 })
328}
329
330func truncateOutput(content string) string {
331 if len(content) <= MaxOutputLength {
332 return content
333 }
334
335 halfLength := MaxOutputLength / 2
336 start := content[:halfLength]
337 end := content[len(content)-halfLength:]
338
339 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
340 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
341}
342
// countLines returns the number of newline-separated segments in s: the
// number of '\n' characters plus one, or 0 for the empty string. A trailing
// newline therefore counts as one more (empty) line, matching the segment
// count strings.Split(s, "\n") would produce.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	// Same result as len(strings.Split(s, "\n")) without allocating a slice
	// the size of s's line count.
	return strings.Count(s, "\n") + 1
}
349
// normalizeWorkingDir converts path to forward-slash form for the <cwd> suffix
// of bash tool output. On Windows it first removes every occurrence of the
// current process directory's volume name from path, falling back to the
// volume of "C:" when the current directory cannot be determined.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	volume := filepath.VolumeName("C:")
	if wd, err := os.Getwd(); err == nil {
		volume = filepath.VolumeName(wd)
	}
	return filepath.ToSlash(strings.ReplaceAll(path, volume, ""))
}