1package tools
2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
	"github.com/charmbracelet/fantasy/ai"
)
17
// BashParams is the argument payload the model supplies when calling the bash
// tool. NOTE(review): the `description` struct tags are presumably consumed by
// ai.NewAgentTool to build the parameter schema — confirm before editing them.
type BashParams struct {
	// Command is the shell command line to execute.
	Command string `json:"command" description:"The command to execute"`
	// Description is an optional human-readable summary of the command; it is
	// forwarded in the permission request and in the response metadata.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// Timeout is in milliseconds. NewBashTool clamps it to MaxTimeout and
	// substitutes DefaultTimeout for zero or negative values.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
23
// BashPermissionsParams is attached as Params to the permission request that
// NewBashTool raises before running a command not recognised as safe.
type BashPermissionsParams struct {
	// Command is the command line awaiting approval.
	Command string `json:"command"`
	// Description is the caller-supplied summary of the command.
	Description string `json:"description"`
	// Timeout is the execution timeout in milliseconds.
	Timeout int `json:"timeout"`
}
29
// BashResponseMetadata is attached to the response of every command that was
// actually executed (including ones that exited non-zero or were aborted).
type BashResponseMetadata struct {
	// StartTime is Unix milliseconds, taken just before execution begins.
	StartTime int64 `json:"start_time"`
	// EndTime is Unix milliseconds, taken once the output has been assembled.
	EndTime int64 `json:"end_time"`
	// Output is the combined, truncated stdout/stderr/error text.
	Output string `json:"output"`
	// Description echoes BashParams.Description.
	Description string `json:"description"`
	// WorkingDirectory is the persistent shell's working directory after the
	// command ran.
	WorkingDirectory string `json:"working_directory"`
}
37
// BashToolName is the name the bash tool is registered under.
const BashToolName = "bash"

// Limits and defaults for the bash tool. Timeouts are in milliseconds and
// output lengths in bytes.
const (
	// DefaultTimeout (one minute) applies when the caller requests no timeout
	// or a non-positive one.
	DefaultTimeout = 60 * 1000
	// MaxTimeout (ten minutes) is the upper bound on any requested timeout.
	MaxTimeout = 600 * 1000
	// MaxOutputLength is the size above which stdout/stderr get truncated.
	MaxOutputLength = 30000
	// BashNoOutput is the response text used when a command printed nothing.
	BashNoOutput = "no output"
)
46
47//go:embed bash.tpl
48var bashDescriptionTmpl []byte
49
50var bashDescriptionTpl = template.Must(
51 template.New("bashDescription").
52 Parse(string(bashDescriptionTmpl)),
53)
54
55type bashDescriptionData struct {
56 BannedCommands string
57 MaxOutputLength int
58 Attribution config.Attribution
59}
60
// bannedCommands lists commands the bash tool refuses to run at all. The list
// is joined with ", " into the tool description (bashDescription) and enforced
// on the persistent shell through shell.CommandsBlocker (blockFuncs). Because
// the description renders this slice in order, reordering entries changes the
// prompt text.
var bannedCommands = []string{
	// Network/Download tools
	// NOTE(review): "alias" is a shell builtin, not a network tool; presumably it
	// is banned so the model cannot redefine commands — TODO confirm and consider
	// moving it to its own category.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration (privilege escalation)
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification (scheduling, disks, services)
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
133
134func bashDescription(attribution *config.Attribution) string {
135 bannedCommandsStr := strings.Join(bannedCommands, ", ")
136 var out bytes.Buffer
137 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
138 BannedCommands: bannedCommandsStr,
139 MaxOutputLength: MaxOutputLength,
140 Attribution: *attribution,
141 }); err != nil {
142 // this should never happen.
143 panic("failed to execute bash description template: " + err.Error())
144 }
145 return out.String()
146}
147
148func blockFuncs() []shell.BlockFunc {
149 return []shell.BlockFunc{
150 shell.CommandsBlocker(bannedCommands),
151
152 // System package managers
153 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
154 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
155 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
156 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
157 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
158 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
159 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
160 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
161
162 // Language-specific package managers
163 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
164 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
165 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
166 shell.ArgumentsBlocker("go", []string{"install"}, nil),
167 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
168 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
169 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
170 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
171 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
172 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
173 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
174
175 // `go test -exec` can run arbitrary commands
176 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
177 }
178}
179
180func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) ai.AgentTool {
181 // Set up command blocking on the persistent shell
182 persistentShell := shell.GetPersistentShell(workingDir)
183 persistentShell.SetBlockFuncs(blockFuncs())
184 return ai.NewAgentTool(
185 BashToolName,
186 string(bashDescription(attribution)),
187 func(ctx context.Context, params BashParams, call ai.ToolCall) (ai.ToolResponse, error) {
188 if params.Timeout > MaxTimeout {
189 params.Timeout = MaxTimeout
190 } else if params.Timeout <= 0 {
191 params.Timeout = DefaultTimeout
192 }
193
194 if params.Command == "" {
195 return ai.NewTextErrorResponse("missing command"), nil
196 }
197
198 isSafeReadOnly := false
199 cmdLower := strings.ToLower(params.Command)
200
201 for _, safe := range safeCommands {
202 if strings.HasPrefix(cmdLower, safe) {
203 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
204 isSafeReadOnly = true
205 break
206 }
207 }
208 }
209
210 sessionID := GetSessionFromContext(ctx)
211 if sessionID == "" {
212 return ai.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
213 }
214 if !isSafeReadOnly {
215 shell := shell.GetPersistentShell(workingDir)
216 p := permissions.Request(
217 permission.CreatePermissionRequest{
218 SessionID: sessionID,
219 Path: shell.GetWorkingDir(),
220 ToolCallID: call.ID,
221 ToolName: BashToolName,
222 Action: "execute",
223 Description: fmt.Sprintf("Execute command: %s", params.Command),
224 Params: BashPermissionsParams{
225 Command: params.Command,
226 Description: params.Description,
227 },
228 },
229 )
230 if !p {
231 return ai.ToolResponse{}, permission.ErrorPermissionDenied
232 }
233 }
234 startTime := time.Now()
235 if params.Timeout > 0 {
236 var cancel context.CancelFunc
237 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
238 defer cancel()
239 }
240
241 persistentShell := shell.GetPersistentShell(workingDir)
242 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
243
244 // Get the current working directory after command execution
245 currentWorkingDir := persistentShell.GetWorkingDir()
246 interrupted := shell.IsInterrupt(err)
247 exitCode := shell.ExitCode(err)
248 if exitCode == 0 && !interrupted && err != nil {
249 return ai.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
250 }
251
252 stdout = truncateOutput(stdout)
253 stderr = truncateOutput(stderr)
254
255 errorMessage := stderr
256 if errorMessage == "" && err != nil {
257 errorMessage = err.Error()
258 }
259
260 if interrupted {
261 if errorMessage != "" {
262 errorMessage += "\n"
263 }
264 errorMessage += "Command was aborted before completion"
265 } else if exitCode != 0 {
266 if errorMessage != "" {
267 errorMessage += "\n"
268 }
269 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
270 }
271
272 hasBothOutputs := stdout != "" && stderr != ""
273
274 if hasBothOutputs {
275 stdout += "\n"
276 }
277
278 if errorMessage != "" {
279 stdout += "\n" + errorMessage
280 }
281
282 metadata := BashResponseMetadata{
283 StartTime: startTime.UnixMilli(),
284 EndTime: time.Now().UnixMilli(),
285 Output: stdout,
286 Description: params.Description,
287 WorkingDirectory: currentWorkingDir,
288 }
289 if stdout == "" {
290 return ai.WithResponseMetadata(ai.NewTextResponse(BashNoOutput), metadata), nil
291 }
292 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
293 return ai.WithResponseMetadata(ai.NewTextResponse(stdout), metadata), nil
294 })
295}
296
297func truncateOutput(content string) string {
298 if len(content) <= MaxOutputLength {
299 return content
300 }
301
302 halfLength := MaxOutputLength / 2
303 start := content[:halfLength]
304 end := content[len(content)-halfLength:]
305
306 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
307 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
308}
309
// countLines returns the number of newline-separated segments in s: 0 for the
// empty string, otherwise the number of '\n' characters plus one. A trailing
// newline therefore counts as starting an extra (empty) line, matching
// len(strings.Split(s, "\n")). Unlike Split, it does not allocate a slice of
// every line just to count them.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	return strings.Count(s, "\n") + 1
}