1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "html/template"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/permission"
15 "github.com/charmbracelet/crush/internal/shell"
16)
17
// BashParams is the JSON input of a bash tool call, as decoded by Run.
type BashParams struct {
	// Command is the shell command line to execute. Run rejects an empty
	// command with a "missing command" error response.
	Command string `json:"command"`
	// Timeout is the execution limit in milliseconds. Run replaces values
	// <= 0 with DefaultTimeout and caps values above MaxTimeout.
	Timeout int `json:"timeout"`
}
22
// BashPermissionsParams is the payload Run attaches to the permission request
// it issues before executing a command that is not considered safe.
type BashPermissionsParams struct {
	// Command is the command line the user is asked to approve.
	Command string `json:"command"`
	// Timeout mirrors BashParams.Timeout.
	// NOTE(review): Run in this file only fills in Command, so this field is
	// always zero in the requests it creates; confirm whether that is intended.
	Timeout int `json:"timeout"`
}
27
// BashResponseMetadata is the metadata Run attaches to every response it
// returns after a command has been executed.
type BashResponseMetadata struct {
	// StartTime is taken just before the command is handed to the shell,
	// as Unix time in milliseconds.
	StartTime int64 `json:"start_time"`
	// EndTime is taken while the response is assembled, as Unix time in
	// milliseconds.
	EndTime int64 `json:"end_time"`
	// Output is the (possibly truncated) stdout followed by the error text
	// built from stderr, an abort notice and/or the exit code. It does not
	// include the trailing <cwd> element added to the response text.
	Output string `json:"output"`
	// WorkingDirectory is the persistent shell's working directory after the
	// command ran.
	WorkingDirectory string `json:"working_directory"`
}
// bashTool is the tool implementation returned by NewBashTool. It runs
// commands in the persistent shell associated with workingDir.
type bashTool struct {
	// permissions is asked for approval before a non-safe command runs.
	permissions permission.Service
	// workingDir selects the persistent shell via shell.GetPersistentShell.
	workingDir string
	// attribution controls the commit/PR attribution text rendered into the
	// tool description; nil enables both attribution lines.
	attribution *config.Attribution
}
39
const (
	// BashToolName is the identifier reported by Name and Info and used as
	// the tool name in permission requests.
	BashToolName = "bash"

	// DefaultTimeout is applied by Run when the caller passes a timeout <= 0.
	DefaultTimeout = 1 * 60 * 1000 // 1 minute in milliseconds
	// MaxTimeout is the upper bound Run clamps the requested timeout to.
	MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
	// MaxOutputLength is the size in bytes above which truncateOutput elides
	// the middle of stdout or stderr; it is also rendered into the tool
	// description.
	MaxOutputLength = 30000
	// BashNoOutput is the response text used when a command produced neither
	// output nor an error message.
	BashNoOutput = "no output"
)
48
// bashDescription holds the raw contents of bash.md, embedded at build time.
//
//go:embed bash.md
var bashDescription []byte

// bashDescriptionTpl is the parsed form of bash.md. template.Must makes the
// package panic during initialization if the embedded file is not a valid
// template.
//
// NOTE(review): the template package imported by this file is html/template,
// which HTML-escapes interpolated string values. The data rendered through it
// contains characters such as '<', '>', '\'' and '"' (the <example> blocks and
// the Co-Authored-By e-mail address), so the description presumably contains
// entities like "&lt;" instead of the literal text; confirm whether
// text/template was intended.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescription)),
)
56
// bashDescriptionData is the value bashDescriptionTpl is executed with.
type bashDescriptionData struct {
	// BannedCommands is the banned command list joined with ", ".
	BannedCommands string
	// MaxOutputLength is the MaxOutputLength constant.
	MaxOutputLength int
	// AttributionStep is the text of step 4 of the commit instructions; it
	// lists the attribution lines when any are enabled.
	AttributionStep string
	// AttributionExample is the <example> block showing a git commit
	// invocation, with or without attribution lines.
	AttributionExample string
	// PRAttribution is the "Generated with Crush" line, or empty when that
	// attribution is disabled.
	PRAttribution string
}
64
// bannedCommands lists the command names that are handed to
// shell.CommandsBlocker in blockFuncs and advertised, comma-separated, in the
// tool description.
var bannedCommands = []string{
	// Network/Download tools
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
137
138func (b *bashTool) bashDescription() string {
139 bannedCommandsStr := strings.Join(bannedCommands, ", ")
140
141 // Build attribution text based on settings
142 var attributionStep, attributionExample, prAttribution string
143
144 // Default to true if attribution is nil (backward compatibility)
145 generatedWith := b.attribution == nil || b.attribution.GeneratedWith
146 coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy
147
148 // Build PR attribution
149 if generatedWith {
150 prAttribution = "💘 Generated with Crush"
151 }
152
153 if generatedWith || coAuthoredBy {
154 var attributionParts []string
155 if generatedWith {
156 attributionParts = append(attributionParts, "💘 Generated with Crush")
157 }
158 if coAuthoredBy {
159 attributionParts = append(attributionParts, "Co-Authored-By: Crush <crush@charm.land>")
160 }
161
162 if len(attributionParts) > 0 {
163 attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n"))
164
165 attributionText := strings.Join(attributionParts, "\n ")
166 attributionExample = fmt.Sprintf(`<example>
167git commit -m "$(cat <<'EOF'
168 Commit message here.
169
170 %s
171 EOF
172)"</example>`, attributionText)
173 }
174 }
175
176 if attributionStep == "" {
177 attributionStep = "4. Create the commit with your commit message."
178 attributionExample = `<example>
179git commit -m "$(cat <<'EOF'
180 Commit message here.
181 EOF
182)"</example>`
183 }
184
185 var out bytes.Buffer
186 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
187 BannedCommands: bannedCommandsStr,
188 MaxOutputLength: MaxOutputLength,
189 AttributionStep: attributionStep,
190 AttributionExample: attributionExample,
191 PRAttribution: prAttribution,
192 }); err != nil {
193 // this should never happen.
194 panic("failed to execute bash description template: " + err.Error())
195 }
196 return out.String()
197}
198
199func blockFuncs() []shell.BlockFunc {
200 return []shell.BlockFunc{
201 shell.CommandsBlocker(bannedCommands),
202
203 // System package managers
204 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
205 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
206 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
207 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
208 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
209 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
210 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
211 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
212
213 // Language-specific package managers
214 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
215 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
216 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
217 shell.ArgumentsBlocker("go", []string{"install"}, nil),
218 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
219 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
220 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
221 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
222 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
223 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
224 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
225
226 // `go test -exec` can run arbitrary commands
227 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
228 }
229}
230
231func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
232 // Set up command blocking on the persistent shell
233 persistentShell := shell.GetPersistentShell(workingDir)
234 persistentShell.SetBlockFuncs(blockFuncs())
235
236 return &bashTool{
237 permissions: permission,
238 workingDir: workingDir,
239 attribution: attribution,
240 }
241}
242
// Name returns the tool's identifier, BashToolName.
func (b *bashTool) Name() string {
	return BashToolName
}
246
247func (b *bashTool) Info() ToolInfo {
248 return ToolInfo{
249 Name: BashToolName,
250 Description: b.bashDescription(),
251 Parameters: map[string]any{
252 "command": map[string]any{
253 "type": "string",
254 "description": "The command to execute",
255 },
256 "timeout": map[string]any{
257 "type": "number",
258 "description": "Optional timeout in milliseconds (max 600000)",
259 },
260 },
261 Required: []string{"command"},
262 }
263}
264
265func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
266 var params BashParams
267 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
268 return NewTextErrorResponse("invalid parameters"), nil
269 }
270
271 if params.Timeout > MaxTimeout {
272 params.Timeout = MaxTimeout
273 } else if params.Timeout <= 0 {
274 params.Timeout = DefaultTimeout
275 }
276
277 if params.Command == "" {
278 return NewTextErrorResponse("missing command"), nil
279 }
280
281 isSafeReadOnly := false
282 cmdLower := strings.ToLower(params.Command)
283
284 for _, safe := range safeCommands {
285 if strings.HasPrefix(cmdLower, safe) {
286 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
287 isSafeReadOnly = true
288 break
289 }
290 }
291 }
292
293 sessionID, messageID := GetContextValues(ctx)
294 if sessionID == "" || messageID == "" {
295 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
296 }
297 if !isSafeReadOnly {
298 shell := shell.GetPersistentShell(b.workingDir)
299 p := b.permissions.Request(
300 permission.CreatePermissionRequest{
301 SessionID: sessionID,
302 Path: shell.GetWorkingDir(),
303 ToolCallID: call.ID,
304 ToolName: BashToolName,
305 Action: "execute",
306 Description: fmt.Sprintf("Execute command: %s", params.Command),
307 Params: BashPermissionsParams{
308 Command: params.Command,
309 },
310 },
311 )
312 if !p {
313 return ToolResponse{}, permission.ErrorPermissionDenied
314 }
315 }
316 startTime := time.Now()
317 if params.Timeout > 0 {
318 var cancel context.CancelFunc
319 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
320 defer cancel()
321 }
322
323 persistentShell := shell.GetPersistentShell(b.workingDir)
324 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
325
326 // Get the current working directory after command execution
327 currentWorkingDir := persistentShell.GetWorkingDir()
328 interrupted := shell.IsInterrupt(err)
329 exitCode := shell.ExitCode(err)
330 if exitCode == 0 && !interrupted && err != nil {
331 return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
332 }
333
334 stdout = truncateOutput(stdout)
335 stderr = truncateOutput(stderr)
336
337 errorMessage := stderr
338 if errorMessage == "" && err != nil {
339 errorMessage = err.Error()
340 }
341
342 if interrupted {
343 if errorMessage != "" {
344 errorMessage += "\n"
345 }
346 errorMessage += "Command was aborted before completion"
347 } else if exitCode != 0 {
348 if errorMessage != "" {
349 errorMessage += "\n"
350 }
351 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
352 }
353
354 hasBothOutputs := stdout != "" && stderr != ""
355
356 if hasBothOutputs {
357 stdout += "\n"
358 }
359
360 if errorMessage != "" {
361 stdout += "\n" + errorMessage
362 }
363
364 metadata := BashResponseMetadata{
365 StartTime: startTime.UnixMilli(),
366 EndTime: time.Now().UnixMilli(),
367 Output: stdout,
368 WorkingDirectory: currentWorkingDir,
369 }
370 if stdout == "" {
371 return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
372 }
373 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
374 return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
375}
376
377func truncateOutput(content string) string {
378 if len(content) <= MaxOutputLength {
379 return content
380 }
381
382 halfLength := MaxOutputLength / 2
383 start := content[:halfLength]
384 end := content[len(content)-halfLength:]
385
386 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
387 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
388}
389
// countLines returns the number of newline-separated segments in s: zero for
// the empty string, otherwise one more than the number of "\n" characters
// (so a trailing newline counts as starting one more, empty, line).
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}