bash.go

  1package tools
  2
import (
	"bytes"
	"context"
	_ "embed"
	"fmt"
	"html/template"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"git.secluded.site/crush/internal/config"
	"git.secluded.site/crush/internal/permission"
	"git.secluded.site/crush/internal/shell"
)
 20
// BashParams is the input accepted by the bash tool. The description struct
// tags presumably become the parameter descriptions shown to the model —
// confirm against fantasy.NewAgentTool.
type BashParams struct {
	// Command is the shell command line to execute; an empty command is
	// rejected with a "missing command" error response.
	Command     string `json:"command" description:"The command to execute"`
	// Description is a free-form summary that is forwarded to the permission
	// request and the response metadata.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// Timeout is in milliseconds. Values above MaxTimeout are clamped to it,
	// and values <= 0 are replaced with DefaultTimeout.
	Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
 26
// BashPermissionsParams is attached as Params to the permission request that
// the bash tool raises before running a command that is not classified as
// safe and read-only.
//
// NOTE(review): the bash tool fills only Command and Description; Timeout is
// left at its zero value — confirm whether it should be populated.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	Timeout     int    `json:"timeout"`
}
 32
// BashResponseMetadata is attached to the response of every command the bash
// tool actually executes.
type BashResponseMetadata struct {
	// StartTime is a Unix timestamp in milliseconds, taken just before the
	// command is started.
	StartTime        int64  `json:"start_time"`
	// EndTime is a Unix timestamp in milliseconds, taken when the response is
	// assembled.
	EndTime          int64  `json:"end_time"`
	// Output is the truncated stdout followed by stderr and any abort/exit-code
	// message. It excludes the trailing <cwd> tag that the text response gets.
	Output           string `json:"output"`
	// Description is copied from BashParams.Description.
	Description      string `json:"description"`
	// WorkingDirectory is the persistent shell's working directory after the
	// command ran.
	WorkingDirectory string `json:"working_directory"`
}
 40
const (
	// BashToolName is the name the tool is registered under; it is also the
	// ToolName used in permission requests.
	BashToolName = "bash"

	// DefaultTimeout applies when the caller gives no positive timeout.
	DefaultTimeout  = 1 * 60 * 1000  // 1 minute in milliseconds
	// MaxTimeout caps any requested timeout.
	MaxTimeout      = 10 * 60 * 1000 // 10 minutes in milliseconds
	// MaxOutputLength is the maximum length, in bytes, of stdout or stderr kept
	// before truncateOutput elides the middle.
	MaxOutputLength = 30000
	// BashNoOutput is the response text used when a command produces no output.
	BashNoOutput    = "no output"
)
 49
 50//go:embed bash.tpl
 51var bashDescriptionTmpl []byte
 52
 53var bashDescriptionTpl = template.Must(
 54	template.New("bashDescription").
 55		Parse(string(bashDescriptionTmpl)),
 56)
 57
// bashDescriptionData is the value bashDescription executes the description
// template with. The field names are presumably referenced from bash.tpl, so
// renaming them would break the template — verify against bash.tpl.
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands  string
	// MaxOutputLength is the per-stream output limit in bytes.
	MaxOutputLength int
	Attribution     config.Attribution
	ModelName       string
}
 64
// bannedCommands lists executables the bash tool refuses to run outright:
// blockFuncs passes it to shell.CommandsBlocker, and bashDescription joins it
// with ", " into the rendered tool description. Entry order is preserved in
// that description.
var bannedCommands = []string{
	// Shell builtins
	"alias",

	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// System administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification (including job scheduling: at, batch, crontab)
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
137
138func bashDescription(attribution *config.Attribution, modelName string) string {
139	bannedCommandsStr := strings.Join(bannedCommands, ", ")
140	var out bytes.Buffer
141	if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
142		BannedCommands:  bannedCommandsStr,
143		MaxOutputLength: MaxOutputLength,
144		Attribution:     *attribution,
145		ModelName:       modelName,
146	}); err != nil {
147		// this should never happen.
148		panic("failed to execute bash description template: " + err.Error())
149	}
150	return out.String()
151}
152
153func blockFuncs() []shell.BlockFunc {
154	return []shell.BlockFunc{
155		shell.CommandsBlocker(bannedCommands),
156
157		// System package managers
158		shell.ArgumentsBlocker("apk", []string{"add"}, nil),
159		shell.ArgumentsBlocker("apt", []string{"install"}, nil),
160		shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
161		shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
162		shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
163		shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
164		shell.ArgumentsBlocker("yum", []string{"install"}, nil),
165		shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
166
167		// Language-specific package managers
168		shell.ArgumentsBlocker("brew", []string{"install"}, nil),
169		shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
170		shell.ArgumentsBlocker("gem", []string{"install"}, nil),
171		shell.ArgumentsBlocker("go", []string{"install"}, nil),
172		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
173		shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
174		shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
175		shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
176		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
177		shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
178		shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
179
180		// `go test -exec` can run arbitrary commands
181		shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
182	}
183}
184
185func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution, modelName string) fantasy.AgentTool {
186	// Set up command blocking on the persistent shell
187	persistentShell := shell.GetPersistentShell(workingDir)
188	persistentShell.SetBlockFuncs(blockFuncs())
189	return fantasy.NewAgentTool(
190		BashToolName,
191		string(bashDescription(attribution, modelName)),
192		func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
193			if params.Timeout > MaxTimeout {
194				params.Timeout = MaxTimeout
195			} else if params.Timeout <= 0 {
196				params.Timeout = DefaultTimeout
197			}
198
199			if params.Command == "" {
200				return fantasy.NewTextErrorResponse("missing command"), nil
201			}
202
203			isSafeReadOnly := false
204			cmdLower := strings.ToLower(params.Command)
205
206			for _, safe := range safeCommands {
207				if strings.HasPrefix(cmdLower, safe) {
208					if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
209						isSafeReadOnly = true
210						break
211					}
212				}
213			}
214
215			sessionID := GetSessionFromContext(ctx)
216			if sessionID == "" {
217				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
218			}
219			if !isSafeReadOnly {
220				shell := shell.GetPersistentShell(workingDir)
221				p := permissions.Request(
222					permission.CreatePermissionRequest{
223						SessionID:   sessionID,
224						Path:        shell.GetWorkingDir(),
225						ToolCallID:  call.ID,
226						ToolName:    BashToolName,
227						Action:      "execute",
228						Description: fmt.Sprintf("Execute command: %s", params.Command),
229						Params: BashPermissionsParams{
230							Command:     params.Command,
231							Description: params.Description,
232						},
233					},
234				)
235				if !p {
236					return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
237				}
238			}
239			startTime := time.Now()
240			if params.Timeout > 0 {
241				var cancel context.CancelFunc
242				ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
243				defer cancel()
244			}
245
246			persistentShell := shell.GetPersistentShell(workingDir)
247			stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
248
249			// Get the current working directory after command execution
250			currentWorkingDir := persistentShell.GetWorkingDir()
251			interrupted := shell.IsInterrupt(err)
252			exitCode := shell.ExitCode(err)
253			if exitCode == 0 && !interrupted && err != nil {
254				return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
255			}
256
257			stdout = truncateOutput(stdout)
258			stderr = truncateOutput(stderr)
259
260			errorMessage := stderr
261			if errorMessage == "" && err != nil {
262				errorMessage = err.Error()
263			}
264
265			if interrupted {
266				if errorMessage != "" {
267					errorMessage += "\n"
268				}
269				errorMessage += "Command was aborted before completion"
270			} else if exitCode != 0 {
271				if errorMessage != "" {
272					errorMessage += "\n"
273				}
274				errorMessage += fmt.Sprintf("Exit code %d", exitCode)
275			}
276
277			hasBothOutputs := stdout != "" && stderr != ""
278
279			if hasBothOutputs {
280				stdout += "\n"
281			}
282
283			if errorMessage != "" {
284				stdout += "\n" + errorMessage
285			}
286
287			metadata := BashResponseMetadata{
288				StartTime:        startTime.UnixMilli(),
289				EndTime:          time.Now().UnixMilli(),
290				Output:           stdout,
291				Description:      params.Description,
292				WorkingDirectory: currentWorkingDir,
293			}
294			if stdout == "" {
295				return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
296			}
297			stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
298			return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
299		})
300}
301
302func truncateOutput(content string) string {
303	if len(content) <= MaxOutputLength {
304		return content
305	}
306
307	halfLength := MaxOutputLength / 2
308	start := content[:halfLength]
309	end := content[len(content)-halfLength:]
310
311	truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
312	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
313}
314
// countLines reports how many lines s spans.
//
// The empty string spans zero lines. Any other string spans one more line than
// it has '\n' characters, so a trailing newline counts as opening a final,
// empty line.
func countLines(s string) int {
	if len(s) == 0 {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
321
// normalizeWorkingDir formats a directory path for the <cwd> tag appended to
// bash tool output. The path is converted to forward slashes.
//
// On Windows, every occurrence of the volume name of the process's own working
// directory is first removed. "C:" is used when that directory cannot be read.
func normalizeWorkingDir(path string) string {
	if runtime.GOOS != "windows" {
		return filepath.ToSlash(path)
	}

	volume := filepath.VolumeName("C:")
	if cwd, err := os.Getwd(); err == nil {
		volume = filepath.VolumeName(cwd)
	}
	return filepath.ToSlash(strings.ReplaceAll(path, volume, ""))
}