1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "fmt"
8 "html/template"
9 "os"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/permission"
18 "github.com/charmbracelet/crush/internal/shell"
19)
20
// BashParams is the input accepted by the bash tool. The json tags name the
// fields; the description tags are presumably consumed by fantasy when it
// builds the tool's input schema (NOTE(review): confirm).
type BashParams struct {
	// Command is the shell command line to run. An empty value is rejected
	// with a "missing command" error response.
	Command     string `json:"command" description:"The command to execute"`
	// Description is a short summary of what the command does; it is copied
	// into the permission request and the response metadata.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// Timeout is in milliseconds. Zero or negative values fall back to
	// DefaultTimeout; values above MaxTimeout are clamped to MaxTimeout.
	Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
26
// BashPermissionsParams is the Params payload of the permission request the
// bash tool issues before running a command that is not recognised as a safe
// read-only command.
type BashPermissionsParams struct {
	// Command is the command line awaiting approval.
	Command     string `json:"command"`
	// Description is the caller-supplied description of the command.
	Description string `json:"description"`
	// Timeout is the command timeout in milliseconds (as in BashParams).
	// NOTE(review): callers that do not populate it leave it at zero.
	Timeout     int    `json:"timeout"`
}
32
// BashResponseMetadata is attached to the response of every command the bash
// tool actually executes.
type BashResponseMetadata struct {
	// StartTime is a Unix timestamp in milliseconds taken before the command runs.
	StartTime        int64  `json:"start_time"`
	// EndTime is a Unix timestamp in milliseconds taken when the response is built.
	EndTime          int64  `json:"end_time"`
	// Output is the text shown to the model: truncated stdout followed by the
	// truncated stderr / error details, without the trailing <cwd> tag.
	Output           string `json:"output"`
	// Description echoes BashParams.Description.
	Description      string `json:"description"`
	// WorkingDirectory is the shell's working directory after the command ran.
	WorkingDirectory string `json:"working_directory"`
}
40
// Name, limits and sentinel text used by the bash tool.
const (
	// BashToolName is the name under which the bash tool is registered.
	BashToolName = "bash"

	// DefaultTimeout is the timeout in milliseconds (one minute) used when a
	// request does not supply a positive timeout.
	DefaultTimeout = 60 * 1000
	// MaxTimeout is the largest timeout in milliseconds (ten minutes) a
	// request may ask for; larger values are clamped to it.
	MaxTimeout = 600 * 1000
	// MaxOutputLength is the byte length above which truncateOutput elides
	// the middle of a command's stdout or stderr.
	MaxOutputLength = 30000
	// BashNoOutput is the response text used when a command prints nothing.
	BashNoOutput = "no output"
)
49
50//go:embed bash.tpl
51var bashDescriptionTmpl []byte
52
53var bashDescriptionTpl = template.Must(
54 template.New("bashDescription").
55 Parse(string(bashDescriptionTmpl)),
56)
57
// bashDescriptionData is the data passed to bashDescriptionTpl when rendering
// the bash tool description; bash.tpl presumably refers to these fields by
// name, so renaming them requires updating the template.
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands  string
	// MaxOutputLength is the output truncation threshold in bytes.
	MaxOutputLength int
	// Attribution carries the attribution settings from config.
	Attribution     config.Attribution
}
63
// bannedCommands lists commands the bash tool refuses to run at all.
// blockFuncs hands it to shell.CommandsBlocker, and bashDescription joins it
// with ", " into the tool description, so the slice order is reflected in the
// rendered description text.
var bannedCommands = []string{
	// Network/Download tools
	// NOTE(review): "alias" is not a network tool. It is presumably banned so
	// that commands cannot be renamed to slip past the blockers; consider
	// moving it into a group of its own.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
136
137func bashDescription(attribution *config.Attribution) string {
138 bannedCommandsStr := strings.Join(bannedCommands, ", ")
139 var out bytes.Buffer
140 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
141 BannedCommands: bannedCommandsStr,
142 MaxOutputLength: MaxOutputLength,
143 Attribution: *attribution,
144 }); err != nil {
145 // this should never happen.
146 panic("failed to execute bash description template: " + err.Error())
147 }
148 return out.String()
149}
150
151func blockFuncs() []shell.BlockFunc {
152 return []shell.BlockFunc{
153 shell.CommandsBlocker(bannedCommands),
154
155 // System package managers
156 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
157 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
158 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
159 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
160 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
161 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
162 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
163 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
164
165 // Language-specific package managers
166 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
167 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
168 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
169 shell.ArgumentsBlocker("go", []string{"install"}, nil),
170 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
171 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
172 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
173 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
174 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
175 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
176 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
177
178 // `go test -exec` can run arbitrary commands
179 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
180 }
181}
182
183func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
184 // Set up command blocking on the persistent shell
185 persistentShell := shell.GetPersistentShell(workingDir)
186 persistentShell.SetBlockFuncs(blockFuncs())
187 return fantasy.NewAgentTool(
188 BashToolName,
189 string(bashDescription(attribution)),
190 func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
191 if params.Timeout > MaxTimeout {
192 params.Timeout = MaxTimeout
193 } else if params.Timeout <= 0 {
194 params.Timeout = DefaultTimeout
195 }
196
197 if params.Command == "" {
198 return fantasy.NewTextErrorResponse("missing command"), nil
199 }
200
201 isSafeReadOnly := false
202 cmdLower := strings.ToLower(params.Command)
203
204 for _, safe := range safeCommands {
205 if strings.HasPrefix(cmdLower, safe) {
206 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
207 isSafeReadOnly = true
208 break
209 }
210 }
211 }
212
213 sessionID := GetSessionFromContext(ctx)
214 if sessionID == "" {
215 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
216 }
217 if !isSafeReadOnly {
218 shell := shell.GetPersistentShell(workingDir)
219 p := permissions.Request(
220 permission.CreatePermissionRequest{
221 SessionID: sessionID,
222 Path: shell.GetWorkingDir(),
223 ToolCallID: call.ID,
224 ToolName: BashToolName,
225 Action: "execute",
226 Description: fmt.Sprintf("Execute command: %s", params.Command),
227 Params: BashPermissionsParams{
228 Command: params.Command,
229 Description: params.Description,
230 },
231 },
232 )
233 if !p {
234 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
235 }
236 }
237 startTime := time.Now()
238 if params.Timeout > 0 {
239 var cancel context.CancelFunc
240 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
241 defer cancel()
242 }
243
244 persistentShell := shell.GetPersistentShell(workingDir)
245 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
246
247 // Get the current working directory after command execution
248 currentWorkingDir := persistentShell.GetWorkingDir()
249 interrupted := shell.IsInterrupt(err)
250 exitCode := shell.ExitCode(err)
251 if exitCode == 0 && !interrupted && err != nil {
252 return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
253 }
254
255 stdout = truncateOutput(stdout)
256 stderr = truncateOutput(stderr)
257
258 errorMessage := stderr
259 if errorMessage == "" && err != nil {
260 errorMessage = err.Error()
261 }
262
263 if interrupted {
264 if errorMessage != "" {
265 errorMessage += "\n"
266 }
267 errorMessage += "Command was aborted before completion"
268 } else if exitCode != 0 {
269 if errorMessage != "" {
270 errorMessage += "\n"
271 }
272 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
273 }
274
275 hasBothOutputs := stdout != "" && stderr != ""
276
277 if hasBothOutputs {
278 stdout += "\n"
279 }
280
281 if errorMessage != "" {
282 stdout += "\n" + errorMessage
283 }
284
285 metadata := BashResponseMetadata{
286 StartTime: startTime.UnixMilli(),
287 EndTime: time.Now().UnixMilli(),
288 Output: stdout,
289 Description: params.Description,
290 WorkingDirectory: currentWorkingDir,
291 }
292 if stdout == "" {
293 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
294 }
295 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
296 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
297 })
298}
299
300func truncateOutput(content string) string {
301 if len(content) <= MaxOutputLength {
302 return content
303 }
304
305 halfLength := MaxOutputLength / 2
306 start := content[:halfLength]
307 end := content[len(content)-halfLength:]
308
309 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
310 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
311}
312
// countLines returns the number of newline-separated segments in s: 0 for the
// empty string, otherwise one more than the number of '\n' characters. A
// trailing newline therefore counts as starting an (empty) extra line.
//
// strings.Count gives the same result as len(strings.Split(s, "\n")) without
// allocating a slice holding every line.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
319
// normalizeWorkingDir formats a working-directory path for display to the
// model, converting separators to forward slashes.
//
// On Windows it also strips a leading volume name (e.g. "C:") when it matches
// the volume of the process's current directory, falling back to "C:" if that
// cannot be determined. Drive letters are case-insensitive on Windows, so the
// prefix is compared with strings.EqualFold. The volume is only removed as a
// prefix, which is the only place it can appear in a valid Windows path.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS == "windows" {
		cwd, err := os.Getwd()
		if err != nil {
			cwd = "C:"
		}
		vol := filepath.VolumeName(cwd)
		if vol != "" && len(path) >= len(vol) && strings.EqualFold(path[:len(vol)], vol) {
			path = path[len(vol):]
		}
	}

	return filepath.ToSlash(path)
}