bash.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"fmt"
  8	"html/template"
  9	"os"
 10	"path/filepath"
 11	"runtime"
 12	"strings"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/permission"
 18	"github.com/charmbracelet/crush/internal/shell"
 19)
 20
// BashParams is the input accepted by the bash tool. The json tags define the
// field names; the description tags presumably feed the parameter schema shown
// to the model — confirm against the tool schema generator.
//
// Command must be non-empty (an empty command is rejected). Timeout is in
// milliseconds: values above MaxTimeout are clamped to MaxTimeout and
// non-positive values are replaced by DefaultTimeout. The timeout is only
// applied to foreground commands. When Background is set, the command is
// started in a background shell and the response carries that shell's ID
// instead of the command's output.
type BashParams struct {
	Command     string `json:"command" description:"The command to execute"`
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
	Background  bool   `json:"background,omitempty" description:"Run the command in a background shell. Returns a shell ID for managing the process."`
}
 27
// BashPermissionsParams is the Params payload attached to the permission
// request the bash tool raises before running a command it does not consider
// safe read-only. Its fields mirror those of BashParams.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	Timeout     int    `json:"timeout"`
	Background  bool   `json:"background"`
}
 34
// BashResponseMetadata is attached to every successful response of the bash
// tool.
//
// StartTime and EndTime are Unix timestamps in milliseconds.
//
// For foreground commands:
//   - Output holds the reported text, without the trailing <cwd> tag that is
//     appended to the response body.
//   - WorkingDirectory is the persistent shell's directory after the command
//     ran.
//
// For background commands:
//   - Output is left empty.
//   - WorkingDirectory is the background shell's directory.
//   - Background and ShellID identify the started shell.
type BashResponseMetadata struct {
	StartTime        int64  `json:"start_time"`
	EndTime          int64  `json:"end_time"`
	Output           string `json:"output"`
	Description      string `json:"description"`
	WorkingDirectory string `json:"working_directory"`
	Background       bool   `json:"background,omitempty"`
	ShellID          string `json:"shell_id,omitempty"`
}
 44
const (
	// BashToolName is the name the bash tool is registered under.
	BashToolName = "bash"

	// DefaultTimeout is used when a call supplies no positive timeout.
	DefaultTimeout  = 1 * 60 * 1000  // 1 minute in milliseconds
	// MaxTimeout caps the timeout a call may request.
	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
	// MaxOutputLength is the length in bytes above which stdout and stderr
	// are each truncated before being returned.
	MaxOutputLength = 30000
	// BashNoOutput is the response text when a command produced nothing to
	// report.
	BashNoOutput    = "no output"
)
 53
// bashDescriptionTmpl is the raw text of bash.tpl, embedded at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is bash.tpl parsed once during package initialisation.
// template.Must panics at start-up if the template fails to parse.
//
// NOTE(review): this is html/template, which contextually HTML-escapes
// interpolated values. The tool description is plain text for the model, so
// text/template is likely the intended package — confirm before switching.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
 61
// bashDescriptionData is the value bash.tpl is executed with.
//
// The fields are:
//   - BannedCommands: bannedCommands joined with ", ".
//   - MaxOutputLength: the MaxOutputLength constant.
//   - Attribution: a copy of the attribution settings given to
//     bashDescription.
type bashDescriptionData struct {
	BannedCommands  string
	MaxOutputLength int
	Attribution     config.Attribution
}
 67
// bannedCommands lists commands the bash tool refuses to run at all.
//
// The list is enforced through shell.CommandsBlocker in blockFuncs. It is
// also joined with ", " into the tool description by bashDescription, so the
// order here is the order shown to the model.
var bannedCommands = []string{
	// Shell builtins
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
140
141func bashDescription(attribution *config.Attribution) string {
142	bannedCommandsStr := strings.Join(bannedCommands, ", ")
143	var out bytes.Buffer
144	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
145		BannedCommands:  bannedCommandsStr,
146		MaxOutputLength: MaxOutputLength,
147		Attribution:     *attribution,
148	}); err != nil {
149		// this should never happen.
150		panic("failed to execute bash description template: " + err.Error())
151	}
152	return out.String()
153}
154
155func blockFuncs() []shell.BlockFunc {
156	return []shell.BlockFunc{
157		shell.CommandsBlocker(bannedCommands),
158
159		// System package managers
160		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
161		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
162		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
163		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
164		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
165		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
166		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
167		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
168
169		// Language-specific package managers
170		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
171		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
172		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
173		shell.ArgumentsBlocker("go", []string{"install"}, nil),
174		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
175		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
176		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
177		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
178		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
179		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
180		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
181
182		// `go test -exec` can run arbitrary commands
183		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
184	}
185}
186
187func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
188	// Set up command blocking on the persistent shell
189	persistentShell := shell.GetPersistentShell(workingDir)
190	persistentShell.SetBlockFuncs(blockFuncs())
191	return fantasy.NewAgentTool(
192		BashToolName,
193		string(bashDescription(attribution)),
194		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
195			if params.Timeout > MaxTimeout {
196				params.Timeout = MaxTimeout
197			} else if params.Timeout <= 0 {
198				params.Timeout = DefaultTimeout
199			}
200
201			if params.Command == "" {
202				return fantasy.NewTextErrorResponse("missing command"), nil
203			}
204
205			isSafeReadOnly := false
206			cmdLower := strings.ToLower(params.Command)
207
208			for _, safe := range safeCommands {
209				if strings.HasPrefix(cmdLower, safe) {
210					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
211						isSafeReadOnly = true
212						break
213					}
214				}
215			}
216
217			sessionID := GetSessionFromContext(ctx)
218			if sessionID == "" {
219				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
220			}
221			if !isSafeReadOnly {
222				var shellDir string
223				if params.Background {
224					shellDir = workingDir
225				} else {
226					shellDir = shell.GetPersistentShell(workingDir).GetWorkingDir()
227				}
228				p := permissions.Request(
229					permission.CreatePermissionRequest{
230						SessionID:   sessionID,
231						Path:        shellDir,
232						ToolCallID:  call.ID,
233						ToolName:    BashToolName,
234						Action:      "execute",
235						Description: fmt.Sprintf("Execute command: %s", params.Command),
236						Params: BashPermissionsParams{
237							Command:     params.Command,
238							Description: params.Description,
239							Background:  params.Background,
240						},
241					},
242				)
243				if !p {
244					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
245				}
246			}
247
248			if params.Background {
249				startTime := time.Now()
250				bgManager := shell.GetBackgroundShellManager()
251				bgShell, err := bgManager.Start(ctx, workingDir, blockFuncs(), params.Command)
252				if err != nil {
253					return fantasy.ToolResponse{}, fmt.Errorf("error starting background shell: %w", err)
254				}
255
256				metadata := BashResponseMetadata{
257					StartTime:        startTime.UnixMilli(),
258					EndTime:          time.Now().UnixMilli(),
259					Description:      params.Description,
260					WorkingDirectory: bgShell.GetWorkingDir(),
261					Background:       true,
262					ShellID:          bgShell.ID,
263				}
264				response := fmt.Sprintf("Background shell started with ID: %s\n\nUse bash_output tool to view output or bash_kill to terminate.", bgShell.ID)
265				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(response), metadata), nil
266			}
267
268			startTime := time.Now()
269			if params.Timeout > 0 {
270				var cancel context.CancelFunc
271				ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
272				defer cancel()
273			}
274
275			persistentShell := shell.GetPersistentShell(workingDir)
276			stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
277
278			currentWorkingDir := persistentShell.GetWorkingDir()
279			interrupted := shell.IsInterrupt(err)
280			exitCode := shell.ExitCode(err)
281			if exitCode == 0 && !interrupted && err != nil {
282				return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
283			}
284
285			stdout = truncateOutput(stdout)
286			stderr = truncateOutput(stderr)
287
288			errorMessage := stderr
289			if errorMessage == "" && err != nil {
290				errorMessage = err.Error()
291			}
292
293			if interrupted {
294				if errorMessage != "" {
295					errorMessage += "\n"
296				}
297				errorMessage += "Command was aborted before completion"
298			} else if exitCode != 0 {
299				if errorMessage != "" {
300					errorMessage += "\n"
301				}
302				errorMessage += fmt.Sprintf("Exit code %d", exitCode)
303			}
304
305			hasBothOutputs := stdout != "" && stderr != ""
306
307			if hasBothOutputs {
308				stdout += "\n"
309			}
310
311			if errorMessage != "" {
312				stdout += "\n" + errorMessage
313			}
314
315			metadata := BashResponseMetadata{
316				StartTime:        startTime.UnixMilli(),
317				EndTime:          time.Now().UnixMilli(),
318				Output:           stdout,
319				Description:      params.Description,
320				WorkingDirectory: currentWorkingDir,
321			}
322			if stdout == "" {
323				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
324			}
325			stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
326			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
327		})
328}
329
330func truncateOutput(content string) string {
331	if len(content) <= MaxOutputLength {
332		return content
333	}
334
335	halfLength := MaxOutputLength / 2
336	start := content[:halfLength]
337	end := content[len(content)-halfLength:]
338
339	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
340	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
341}
342
// countLines returns how many newline-separated segments s contains. It
// returns 0 for the empty string, and otherwise one more than the number of
// '\n' bytes, so a trailing newline counts as starting one more (empty)
// segment.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	// Same result as len(strings.Split(s, "\n")) without allocating a slice
	// holding every line, which matters for large truncated outputs.
	return strings.Count(s, "\n") + 1
}
349
// normalizeWorkingDir renders path in the form reported to the model inside
// the <cwd> tag: path separators are converted to forward slashes.
//
// On Windows, every occurrence of the volume name of the process's current
// directory is removed first. If the current directory cannot be
// determined, the volume of "C:" is used instead.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	dir := "C:"
	if wd, err := os.Getwd(); err == nil {
		dir = wd
	}
	stripped := strings.ReplaceAll(path, filepath.VolumeName(dir), "")
	return filepath.ToSlash(stripped)
}