1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "fmt"
8 "html/template"
9 "os"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "time"
14
15 "charm.land/fantasy"
16 "git.secluded.site/crush/internal/config"
17 "git.secluded.site/crush/internal/permission"
18 "git.secluded.site/crush/internal/shell"
19)
20
type (
	// BashParams is the input accepted by the bash tool. The `description`
	// struct tags presumably feed the tool's parameter schema shown to the
	// model — TODO confirm against fantasy.NewAgentTool.
	BashParams struct {
		Command     string `json:"command" description:"The command to execute"`
		Description string `json:"description,omitempty" description:"A brief description of what the command does"`
		Timeout     int    `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
	}

	// BashPermissionsParams is the payload attached to the permission
	// request raised before a command that is not considered read-only runs.
	BashPermissionsParams struct {
		Command     string `json:"command"`
		Description string `json:"description"`
		Timeout     int    `json:"timeout"`
	}

	// BashResponseMetadata accompanies every bash tool response.
	// StartTime and EndTime are Unix timestamps in milliseconds; Output is
	// the combined (possibly truncated) stdout/error text; WorkingDirectory
	// is the shell's directory after the command finished.
	BashResponseMetadata struct {
		StartTime        int64  `json:"start_time"`
		EndTime          int64  `json:"end_time"`
		Output           string `json:"output"`
		Description      string `json:"description"`
		WorkingDirectory string `json:"working_directory"`
	}
)
40
// BashToolName is the name the bash tool is registered under.
const BashToolName = "bash"

// Limits and sentinel text used by the bash tool. Timeouts are expressed in
// milliseconds, matching BashParams.Timeout.
const (
	// DefaultTimeout is used when the caller supplies no timeout (or a
	// non-positive one): one minute.
	DefaultTimeout = 60 * 1000

	// MaxTimeout caps any requested timeout: ten minutes.
	MaxTimeout = 10 * 60 * 1000

	// MaxOutputLength is the length in bytes above which stdout and stderr
	// are truncated before being returned.
	MaxOutputLength = 30000

	// BashNoOutput is the response text used when a command printed nothing.
	BashNoOutput = "no output"
)
49
50//go:embed bash.tpl
51var bashDescriptionTmpl []byte
52
53var bashDescriptionTpl = template.Must(
54 template.New("bashDescription").
55 Parse(string(bashDescriptionTmpl)),
56)
57
58type bashDescriptionData struct {
59 BannedCommands string
60 MaxOutputLength int
61 Attribution config.Attribution
62 ModelName string
63}
64
// bannedCommands lists executables the bash tool refuses to run at all.
// blockFuncs passes it to shell.CommandsBlocker, and bashDescription joins it
// with ", " into the tool description, so the order of entries is visible to
// the model.
var bannedCommands = []string{
	// Shell builtin — presumably banned so an alias cannot be used to
	// disguise another command (TODO confirm intent).
	"alias",
	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",

	// Privilege escalation / system administration
	"doas",
	"su",
	"sudo",

	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",

	// System modification (scheduling, disks, services)
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",

	// Network configuration and inspection
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
137
// bashDescription renders the bash tool's description from the embedded
// bash.tpl template. It fills in the banned-command list (joined with ", "),
// the output truncation limit, the attribution settings and the model name.
//
// A nil attribution is treated as the zero config.Attribution instead of
// being dereferenced (which previously panicked). What the template renders
// for a zero Attribution depends on bash.tpl — presumably no attribution
// text; confirm.
//
// It panics only if executing the template fails, which would indicate a
// bug in bash.tpl rather than a runtime condition.
func bashDescription(attribution *config.Attribution, modelName string) string {
	var attr config.Attribution
	if attribution != nil {
		attr = *attribution
	}

	data := bashDescriptionData{
		BannedCommands:  strings.Join(bannedCommands, ", "),
		MaxOutputLength: MaxOutputLength,
		Attribution:     attr,
		ModelName:       modelName,
	}

	var out bytes.Buffer
	if err := bashDescriptionTpl.Execute(&out, data); err != nil {
		// this should never happen: the template is embedded and parsed at init.
		panic("failed to execute bash description template: " + err.Error())
	}
	return out.String()
}
152
153func blockFuncs() []shell.BlockFunc {
154 return []shell.BlockFunc{
155 shell.CommandsBlocker(bannedCommands),
156
157 // System package managers
158 shell.ArgumentsBlocker("apk", []string{"add"}, nil),
159 shell.ArgumentsBlocker("apt", []string{"install"}, nil),
160 shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
161 shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
162 shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
163 shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
164 shell.ArgumentsBlocker("yum", []string{"install"}, nil),
165 shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
166
167 // Language-specific package managers
168 shell.ArgumentsBlocker("brew", []string{"install"}, nil),
169 shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
170 shell.ArgumentsBlocker("gem", []string{"install"}, nil),
171 shell.ArgumentsBlocker("go", []string{"install"}, nil),
172 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
173 shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
174 shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
175 shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
176 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
177 shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
178 shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
179
180 // `go test -exec` can run arbitrary commands
181 shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
182 }
183}
184
185func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution, modelName string) fantasy.AgentTool {
186 // Set up command blocking on the persistent shell
187 persistentShell := shell.GetPersistentShell(workingDir)
188 persistentShell.SetBlockFuncs(blockFuncs())
189 return fantasy.NewAgentTool(
190 BashToolName,
191 string(bashDescription(attribution, modelName)),
192 func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
193 if params.Timeout > MaxTimeout {
194 params.Timeout = MaxTimeout
195 } else if params.Timeout <= 0 {
196 params.Timeout = DefaultTimeout
197 }
198
199 if params.Command == "" {
200 return fantasy.NewTextErrorResponse("missing command"), nil
201 }
202
203 isSafeReadOnly := false
204 cmdLower := strings.ToLower(params.Command)
205
206 for _, safe := range safeCommands {
207 if strings.HasPrefix(cmdLower, safe) {
208 if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
209 isSafeReadOnly = true
210 break
211 }
212 }
213 }
214
215 sessionID := GetSessionFromContext(ctx)
216 if sessionID == "" {
217 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
218 }
219 if !isSafeReadOnly {
220 shell := shell.GetPersistentShell(workingDir)
221 p := permissions.Request(
222 permission.CreatePermissionRequest{
223 SessionID: sessionID,
224 Path: shell.GetWorkingDir(),
225 ToolCallID: call.ID,
226 ToolName: BashToolName,
227 Action: "execute",
228 Description: fmt.Sprintf("Execute command: %s", params.Command),
229 Params: BashPermissionsParams{
230 Command: params.Command,
231 Description: params.Description,
232 },
233 },
234 )
235 if !p {
236 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
237 }
238 }
239 startTime := time.Now()
240 if params.Timeout > 0 {
241 var cancel context.CancelFunc
242 ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
243 defer cancel()
244 }
245
246 persistentShell := shell.GetPersistentShell(workingDir)
247 stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
248
249 // Get the current working directory after command execution
250 currentWorkingDir := persistentShell.GetWorkingDir()
251 interrupted := shell.IsInterrupt(err)
252 exitCode := shell.ExitCode(err)
253 if exitCode == 0 && !interrupted && err != nil {
254 return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
255 }
256
257 stdout = truncateOutput(stdout)
258 stderr = truncateOutput(stderr)
259
260 errorMessage := stderr
261 if errorMessage == "" && err != nil {
262 errorMessage = err.Error()
263 }
264
265 if interrupted {
266 if errorMessage != "" {
267 errorMessage += "\n"
268 }
269 errorMessage += "Command was aborted before completion"
270 } else if exitCode != 0 {
271 if errorMessage != "" {
272 errorMessage += "\n"
273 }
274 errorMessage += fmt.Sprintf("Exit code %d", exitCode)
275 }
276
277 hasBothOutputs := stdout != "" && stderr != ""
278
279 if hasBothOutputs {
280 stdout += "\n"
281 }
282
283 if errorMessage != "" {
284 stdout += "\n" + errorMessage
285 }
286
287 metadata := BashResponseMetadata{
288 StartTime: startTime.UnixMilli(),
289 EndTime: time.Now().UnixMilli(),
290 Output: stdout,
291 Description: params.Description,
292 WorkingDirectory: currentWorkingDir,
293 }
294 if stdout == "" {
295 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
296 }
297 stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
298 return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
299 })
300}
301
// truncateOutput caps content at roughly MaxOutputLength bytes. Content
// within the limit is returned unchanged. Longer content keeps its first and
// last MaxOutputLength/2 bytes, and the middle is replaced by a marker
// reporting how many lines were dropped.
//
// The cut points are moved onto UTF-8 rune boundaries, so a multi-byte
// character is never split into an invalid byte sequence. The head cut moves
// earlier and the tail cut moves later, each by at most 3 bytes (the most
// continuation bytes a valid rune can have).
func truncateOutput(content string) string {
	if len(content) <= MaxOutputLength {
		return content
	}

	// isContinuation reports whether b is a UTF-8 continuation byte (10xxxxxx).
	isContinuation := func(b byte) bool { return b&0xC0 == 0x80 }

	halfLength := MaxOutputLength / 2

	headEnd := halfLength
	for i := 0; i < 3 && headEnd > 0 && isContinuation(content[headEnd]); i++ {
		headEnd--
	}

	tailStart := len(content) - halfLength
	for i := 0; i < 3 && tailStart < len(content) && isContinuation(content[tailStart]); i++ {
		tailStart++
	}

	// len(content) > MaxOutputLength >= 2*halfLength guarantees headEnd <= tailStart.
	start := content[:headEnd]
	end := content[tailStart:]
	truncatedLinesCount := countLines(content[headEnd:tailStart])
	return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
}
314
// countLines returns the number of "\n"-separated segments in s, and 0 for
// the empty string. A trailing "\n" therefore counts as starting one more
// (empty) line. The result equals len(strings.Split(s, "\n")) for non-empty
// s, but is computed without allocating the split slice.
func countLines(s string) int {
	if s == "" {
		return 0
	}
	return strings.Count(s, "\n") + 1
}
321
// normalizeWorkingDir formats path for display to the model by converting
// OS separators to forward slashes. On Windows it also removes a leading
// volume name (e.g. "C:") that matches the process's current working
// directory. The match is case-insensitive, so the result looks like a
// rooted POSIX path. If the working directory cannot be determined, "C:" is
// assumed.
//
// Only a leading volume is stripped; the previous ReplaceAll also removed
// occurrences elsewhere in the path and missed case variants such as "c:".
func normalizeWorkingDir(path string) string {
	if runtime.GOOS == "windows" {
		cwd, err := os.Getwd()
		if err != nil {
			cwd = "C:"
		}
		vol := filepath.VolumeName(cwd)
		// Windows drive letters are case-insensitive ("c:" and "C:" are the same volume).
		if vol != "" && len(path) >= len(vol) && strings.EqualFold(path[:len(vol)], vol) {
			path = path[len(vol):]
		}
	}

	return filepath.ToSlash(path)
}