bash.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"fmt"
  8	"html/template"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/permission"
 14	"github.com/charmbracelet/crush/internal/shell"
 15	"github.com/charmbracelet/fantasy/ai"
 16)
 17
// BashParams holds the arguments supplied when the bash tool is invoked.
//
// Command is the command line to run; an empty value is rejected with a text
// error response. Description is an optional summary that is forwarded to the
// permission request and echoed in the response metadata. Timeout is in
// milliseconds: values above MaxTimeout are clamped to MaxTimeout, and zero or
// negative values are replaced by DefaultTimeout.
type BashParams struct {
	Command     string `json:"command" description:"The command to execute"`
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
 23
// BashPermissionsParams is the Params payload of the permission request the
// bash tool issues before running a command that is not recognized as a safe
// read-only command. Timeout is presumably in milliseconds, like
// BashParams.Timeout (TODO confirm with consumers of the permission request).
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	Timeout     int    `json:"timeout"`
}
 29
// BashResponseMetadata is attached to the response of every bash tool call
// whose command actually ran.
//
// StartTime and EndTime are Unix timestamps in milliseconds, taken just before
// execution and after the output has been assembled. Output is the text
// returned to the caller (stdout followed by any error text), without the
// trailing <cwd> element. Description echoes BashParams.Description.
// WorkingDirectory is the shell's working directory read after the command
// finished.
type BashResponseMetadata struct {
	StartTime        int64  `json:"start_time"`
	EndTime          int64  `json:"end_time"`
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
}
 37
 38const (
 39	BashToolName = "bash"
 40
 41	DefaultTimeout  = 1 * 60 * 1000  // 1 minutes in milliseconds
 42	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
 43	MaxOutputLength = 30000
 44	BashNoOutput    = "no output"
 45)
 46
// bashDescriptionTmpl holds the raw contents of bash.gotmpl, embedded at build
// time. It is the source text parsed into bashDescriptionTpl.
//
//go:embed bash.gotmpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed bash tool description template.
// template.Must panics during package initialization if the embedded text
// fails to parse.
//
// NOTE(review): `template` here is html/template (see the import block). That
// package applies HTML contextual escaping to interpolated values, but the
// rendered text is a plain tool description rather than HTML. text/template is
// likely what was intended; confirm against bash.gotmpl before switching.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 54
// bashDescriptionData is the value passed to bashDescriptionTpl.
//
// BannedCommands is bannedCommands joined with ", ". MaxOutputLength is the
// byte limit applied by truncateOutput. Attribution is the attribution
// configuration given to bashDescription.
type bashDescriptionData struct {
	BannedCommands  string
	MaxOutputLength int
	Attribution     config.Attribution
}
 60
// bannedCommands lists command names the bash tool refuses to run at all. It
// is installed on the persistent shell through shell.CommandsBlocker in
// blockFuncs, and is also rendered (joined with ", ", in this order) into the
// tool description by bashDescription. Entries are alphabetical within each
// group.
var bannedCommands = []string{
	// Network/Download tools.
	// "alias" is not a network tool; it sits here only because it sorts first
	// alphabetically.
	"alias",
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
133
// bashDescription renders the bash tool description from the embedded
// template. It fills in the banned-command list (comma-separated), the output
// length limit, and the attribution settings.
//
// A nil attribution is treated as the zero config.Attribution instead of
// triggering a nil-pointer dereference.
//
// It panics if the template fails to execute. The template is embedded at
// build time, so a failure indicates a programming error.
func bashDescription(attribution *config.Attribution) string {
	var attr config.Attribution
	if attribution != nil {
		attr = *attribution
	}

	var out bytes.Buffer
	err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
		BannedCommands:  strings.Join(bannedCommands, ", "),
		MaxOutputLength: MaxOutputLength,
		Attribution:     attr,
	})
	if err != nil {
		// this should never happen.
		panic("failed to execute bash description template: " + err.Error())
	}
	return out.String()
}
147
148func blockFuncs() []shell.BlockFunc {
149	return []shell.BlockFunc{
150		shell.CommandsBlocker(bannedCommands),
151
152		// System package managers
153		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
154		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
155		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
156		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
157		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
158		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
159		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
160		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
161
162		// Language-specific package managers
163		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
164		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
165		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
166		shell.ArgumentsBlocker("go", []string{"install"}, nil),
167		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
168		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
169		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
170		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
171		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
172		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
173		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
174
175		// `go test -exec` can run arbitrary commands
176		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
177	}
178}
179
// NewBashTool returns the "bash" agent tool, which runs commands in the
// persistent shell for workingDir.
//
// Construction installs blockFuncs() on that shell. Each call then does the
// following:
//   - clamps params.Timeout into (0, MaxTimeout], using DefaultTimeout for
//     non-positive values;
//   - rejects an empty command with a text error response;
//   - requires a session ID from ctx (GetSessionFromContext);
//   - asks permissions for approval unless bashCommandIsSafeReadOnly accepts
//     the command;
//   - runs the command under the timeout and returns stdout followed by any
//     stderr or error text, an interruption notice, or the non-zero exit code,
//     then a <cwd> trailer. BashNoOutput is returned when all of that is empty.
//
// A missing session ID, a denied permission (permission.ErrorPermissionDenied),
// or an execution error that is neither an interrupt nor a non-zero exit status
// is returned as a Go error.
func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) ai.AgentTool {
	// Set up command blocking on the persistent shell once, at construction.
	persistentShell := shell.GetPersistentShell(workingDir)
	persistentShell.SetBlockFuncs(blockFuncs())
	return ai.NewAgentTool(
		BashToolName,
		bashDescription(attribution),
		func(ctx context.Context, params BashParams, call ai.ToolCall) (ai.ToolResponse, error) {
			switch {
			case params.Timeout > MaxTimeout:
				params.Timeout = MaxTimeout
			case params.Timeout <= 0:
				params.Timeout = DefaultTimeout
			}

			if params.Command == "" {
				return ai.NewTextErrorResponse("missing command"), nil
			}

			isSafeReadOnly := bashCommandIsSafeReadOnly(params.Command)

			sessionID := GetSessionFromContext(ctx)
			if sessionID == "" {
				return ai.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
			}

			// Named sh so the package identifier `shell` stays usable below.
			sh := shell.GetPersistentShell(workingDir)

			if !isSafeReadOnly {
				granted := permissions.Request(
					permission.CreatePermissionRequest{
						SessionID:   sessionID,
						Path:        sh.GetWorkingDir(),
						ToolCallID:  call.ID,
						ToolName:    BashToolName,
						Action:      "execute",
						Description: fmt.Sprintf("Execute command: %s", params.Command),
						Params: BashPermissionsParams{
							Command:     params.Command,
							Description: params.Description,
							Timeout:     params.Timeout,
						},
					},
				)
				if !granted {
					return ai.ToolResponse{}, permission.ErrorPermissionDenied
				}
			}

			startTime := time.Now()
			// params.Timeout is always positive after the clamp above.
			ctx, cancel := context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
			defer cancel()

			stdout, stderr, err := sh.Exec(ctx, params.Command)

			// Read the working directory after execution so the response
			// reflects any change the command made to it.
			currentWorkingDir := sh.GetWorkingDir()
			interrupted := shell.IsInterrupt(err)
			exitCode := shell.ExitCode(err)
			if exitCode == 0 && !interrupted && err != nil {
				return ai.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
			}

			stdout = truncateOutput(stdout)
			stderr = truncateOutput(stderr)

			// stderr doubles as the error text; fall back to err's message.
			errorMessage := stderr
			if errorMessage == "" && err != nil {
				errorMessage = err.Error()
			}

			var status string
			switch {
			case interrupted:
				status = "Command was aborted before completion"
			case exitCode != 0:
				status = fmt.Sprintf("Exit code %d", exitCode)
			}
			if status != "" {
				if errorMessage != "" {
					errorMessage += "\n"
				}
				errorMessage += status
			}

			// Separate stdout from stderr with a blank line when both exist.
			if stdout != "" && stderr != "" {
				stdout += "\n"
			}
			if errorMessage != "" {
				stdout += "\n" + errorMessage
			}

			metadata := BashResponseMetadata{
				StartTime:        startTime.UnixMilli(),
				EndTime:          time.Now().UnixMilli(),
				Output:           stdout,
				Description:      params.Description,
				WorkingDirectory: currentWorkingDir,
			}
			if stdout == "" {
				return ai.WithResponseMetadata(ai.NewTextResponse(BashNoOutput), metadata), nil
			}
			stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
			return ai.WithResponseMetadata(ai.NewTextResponse(stdout), metadata), nil
		})
}

// bashCommandIsSafeReadOnly reports whether command may run without a
// permission prompt. The command must begin, compared case-insensitively, with
// an entry of safeCommands that is followed by the end of the string, a space,
// or '-'.
//
// A command containing shell control operators, redirection, or command
// substitution is never considered safe. Otherwise a safe prefix could be
// used to chain an arbitrary command after it (e.g. "<safe> && rm -rf ~") or
// to write files (e.g. "<safe> > file") without approval. The check is
// deliberately conservative: a quoted '|' (for example inside a --format
// argument) also forces a prompt.
func bashCommandIsSafeReadOnly(command string) bool {
	// ; & |   sequencing, background jobs, pipes
	// < >     redirection and process substitution
	// `       command substitution
	// \n \r   line breaks that start a new command
	const shellOperators = ";&|<>`\n\r"
	if strings.ContainsAny(command, shellOperators) || strings.Contains(command, "$(") {
		return false
	}

	cmdLower := strings.ToLower(command)
	for _, safe := range safeCommands {
		if !strings.HasPrefix(cmdLower, safe) {
			continue
		}
		if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
			return true
		}
	}
	return false
}
296
// truncateOutput shortens content that exceeds MaxOutputLength bytes. It keeps
// roughly the first and last MaxOutputLength/2 bytes and replaces the middle
// with a marker reporting how many lines were dropped. Content of at most
// MaxOutputLength bytes is returned unchanged.
//
// Both cut points are moved onto UTF-8 rune boundaries: the head cut moves
// backward and the tail cut forward, each by at most three bytes. A
// multi-byte character straddling a cut is therefore dropped whole rather
// than split into invalid UTF-8.
func truncateOutput(content string) string {
	if len(content) <= MaxOutputLength {
		return content
	}

	halfLength := MaxOutputLength / 2
	headEnd := halfLength
	tailStart := len(content) - halfLength

	// A UTF-8 continuation byte has the form 10xxxxxx. A valid rune has at
	// most three of them, which bounds each adjustment.
	for i := 0; i < 3 && headEnd > 0 && content[headEnd]&0xC0 == 0x80; i++ {
		headEnd--
	}
	for i := 0; i < 3 && tailStart < len(content) && content[tailStart]&0xC0 == 0x80; i++ {
		tailStart++
	}

	truncatedLinesCount := countLines(content[headEnd:tailStart])
	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", content[:headEnd], truncatedLinesCount, content[tailStart:])
}

// countLines returns the number of '\n'-separated segments in s: the number of
// newlines plus one, or 0 for the empty string.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	// Same result as len(strings.Split(s, "\n")), without allocating the slice.
	return strings.Count(s, "\n") + 1
}