1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "html/template"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/permission"
15 "github.com/charmbracelet/crush/internal/proto"
16 "github.com/charmbracelet/crush/internal/shell"
17)
18
19type (
20 BashParams = proto.BashParams
21 BashPermissionsParams = proto.BashPermissionsParams
22 BashResponseMetadata = proto.BashResponseMetadata
23)
24
25type bashTool struct {
26 permissions permission.Service
27 workingDir string
28 attribution *config.Attribution
29}
30
31const (
32 BashToolName = proto.BashToolName
33
34 DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
35 MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
36 MaxOutputLength = 30000
37 BashNoOutput = "no output"
38)
39
40//go:embed bash.md
41var bashDescription []byte
42
43var bashDescriptionTpl = template.Must(
44 template.New("bashDescription").
45 Parse(string(bashDescription)),
46)
47
// bashDescriptionData is the value passed to bashDescriptionTpl when the tool
// description is rendered.
type bashDescriptionData struct {
	// BannedCommands is the comma-separated list of banned command names.
	BannedCommands string
	// MaxOutputLength is the output truncation threshold shown to the model.
	MaxOutputLength int
	// AttributionStep is the commit-creation instruction, including any
	// attribution lines the commit message should end with.
	AttributionStep string
	// AttributionExample is the example git commit invocation.
	AttributionExample string
	// PRAttribution is the attribution line for pull requests, or empty when
	// that attribution is disabled.
	PRAttribution string
}
55
// bannedCommands lists the command names that are blocked outright: the list
// is handed to shell.CommandsBlocker in blockFuncs and is also rendered,
// comma-separated and in this order, into the tool description.
var bannedCommands = []string{
	// Network access, download tools, browsers and remote shells (plus the
	// `alias` builtin).
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// Privilege escalation.
	"doas",
	"su",
	"sudo",

	// System package managers.
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// Job scheduling, services, disks and filesystems.
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network and firewall configuration.
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
128
129func (b *bashTool) bashDescription() string {
130 bannedCommandsStr := strings.Join(bannedCommands, ", ")
131
132 // Build attribution text based on settings
133 var attributionStep, attributionExample, prAttribution string
134
135 // Default to true if attribution is nil (backward compatibility)
136 generatedWith := b.attribution == nil || b.attribution.GeneratedWith
137 coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy
138
139 // Build PR attribution
140 if generatedWith {
141 prAttribution = "💘 Generated with Crush"
142 }
143
144 if generatedWith || coAuthoredBy {
145 var attributionParts []string
146 if generatedWith {
147 attributionParts = append(attributionParts, "💘 Generated with Crush")
148 }
149 if coAuthoredBy {
150 attributionParts = append(attributionParts, "Co-Authored-By: Crush <crush@charm.land>")
151 }
152
153 if len(attributionParts) > 0 {
154 attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n"))
155
156 attributionText := strings.Join(attributionParts, "\n ")
157 attributionExample = fmt.Sprintf(`<example>
158git commit -m "$(cat <<'EOF'
159 Commit message here.
160
161 %s
162 EOF
163)"</example>`, attributionText)
164 }
165 }
166
167 if attributionStep == "" {
168 attributionStep = "4. Create the commit with your commit message."
169 attributionExample = `<example>
170git commit -m "$(cat <<'EOF'
171 Commit message here.
172 EOF
173)"</example>`
174 }
175
176 var out bytes.Buffer
177 if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
178 BannedCommands: bannedCommandsStr,
179 MaxOutputLength: MaxOutputLength,
180 AttributionStep: attributionStep,
181 AttributionExample: attributionExample,
182 PRAttribution: prAttribution,
183 }); err != nil {
184 // this should never happen.
185 panic("failed to execute bash description template: " + err.Error())
186 }
187 return out.String()
188}
189
190func blockFuncs() []shell.BlockFunc {
191 return []shell.BlockFunc{
192 shell.CommandsBlocker(bannedCommands),
193
194 // System package managers
195 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
196 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
197 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
198 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
199 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
200 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
201 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
202 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
203
204 // Language-specific package managers
205 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
206 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
207 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
208 shell.ArgumentsBlocker("go", []string{"install"}, nil),
209 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
210 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
211 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
212 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
213 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
214 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
215 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
216
217 // `go test -exec` can run arbitrary commands
218 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
219 }
220}
221
222func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
223 // Set up command blocking on the persistent shell
224 persistentShell := shell.GetPersistentShell(workingDir)
225 persistentShell.SetBlockFuncs(blockFuncs())
226
227 return &bashTool{
228 permissions: permission,
229 workingDir: workingDir,
230 attribution: attribution,
231 }
232}
233
234func (b *bashTool) Name() string {
235 return BashToolName
236}
237
238func (b *bashTool) Info() ToolInfo {
239 return ToolInfo{
240 Name: BashToolName,
241 Description: b.bashDescription(),
242 Parameters: map[string]any{
243 "command": map[string]any{
244 "type": "string",
245 "description": "The command to execute",
246 },
247 "timeout": map[string]any{
248 "type": "number",
249 "description": "Optional timeout in milliseconds (max 600000)",
250 },
251 },
252 Required: []string{"command"},
253 }
254}
255
256func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
257 var params BashParams
258 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
259 return NewTextErrorResponse("invalid parameters"), nil
260 }
261
262 if params.Timeout > MaxTimeout {
263 params.Timeout = MaxTimeout
264 } else if params.Timeout <= 0 {
265 params.Timeout = DefaultTimeout
266 }
267
268 if params.Command == "" {
269 return NewTextErrorResponse("missing command"), nil
270 }
271
272 isSafeReadOnly := false
273 cmdLower := strings.ToLower(params.Command)
274
275 for _, safe := range safeCommands {
276 if strings.HasPrefix(cmdLower, safe) {
277 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
278 isSafeReadOnly = true
279 break
280 }
281 }
282 }
283
284 sessionID, messageID := GetContextValues(ctx)
285 if sessionID == "" || messageID == "" {
286 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
287 }
288 if !isSafeReadOnly {
289 shell := shell.GetPersistentShell(b.workingDir)
290 p := b.permissions.Request(
291 permission.CreatePermissionRequest{
292 SessionID: sessionID,
293 Path: shell.GetWorkingDir(),
294 ToolCallID: call.ID,
295 ToolName: BashToolName,
296 Action: "execute",
297 Description: fmt.Sprintf("Execute command: %s", params.Command),
298 Params: BashPermissionsParams{
299 Command: params.Command,
300 },
301 },
302 )
303 if !p {
304 return ToolResponse{}, permission.ErrorPermissionDenied
305 }
306 }
307 startTime := time.Now()
308 if params.Timeout > 0 {
309 var cancel context.CancelFunc
310 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
311 defer cancel()
312 }
313
314 persistentShell := shell.GetPersistentShell(b.workingDir)
315 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
316
317 // Get the current working directory after command execution
318 currentWorkingDir := persistentShell.GetWorkingDir()
319 interrupted := shell.IsInterrupt(err)
320 exitCode := shell.ExitCode(err)
321 if exitCode == 0 && !interrupted && err != nil {
322 return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
323 }
324
325 stdout = truncateOutput(stdout)
326 stderr = truncateOutput(stderr)
327
328 errorMessage := stderr
329 if errorMessage == "" && err != nil {
330 errorMessage = err.Error()
331 }
332
333 if interrupted {
334 if errorMessage != "" {
335 errorMessage += "\n"
336 }
337 errorMessage += "Command was aborted before completion"
338 } else if exitCode != 0 {
339 if errorMessage != "" {
340 errorMessage += "\n"
341 }
342 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
343 }
344
345 hasBothOutputs := stdout != "" && stderr != ""
346
347 if hasBothOutputs {
348 stdout += "\n"
349 }
350
351 if errorMessage != "" {
352 stdout += "\n" + errorMessage
353 }
354
355 metadata := BashResponseMetadata{
356 StartTime: startTime.UnixMilli(),
357 EndTime: time.Now().UnixMilli(),
358 Output: stdout,
359 WorkingDirectory: currentWorkingDir,
360 }
361 if stdout == "" {
362 return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
363 }
364 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
365 return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
366}
367
368func truncateOutput(content string) string {
369 if len(content) <= MaxOutputLength {
370 return content
371 }
372
373 halfLength := MaxOutputLength / 2
374 start := content[:halfLength]
375 end := content[len(content)-halfLength:]
376
377 truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
378 return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
379}
380
// countLines returns the number of newline-separated segments in s: zero for
// the empty string, otherwise one more than the number of "\n" it contains
// (so a trailing newline counts as starting one more, empty, line).
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}