1package tools
2
3import (
4 "bufio"
5 "context"
6 _ "embed"
7 "fmt"
8 "io"
9 "os"
10 "path/filepath"
11 "strings"
12 "unicode/utf8"
13
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/lsp"
16 "github.com/charmbracelet/crush/internal/permission"
17 "github.com/charmbracelet/fantasy/ai"
18)
19
// viewDescription holds the contents of view.md, embedded at build time. It is
// passed as the tool description to ai.NewAgentTool in NewViewTool.
//go:embed view.md
var viewDescription []byte
22
// ViewParams holds the arguments of a view tool call, decoded from the JSON
// keys in the struct tags. The description tags carry per-field help text;
// presumably ai.NewAgentTool uses them to build the tool's parameter schema —
// confirm against the fantasy/ai package.
//
// NewViewTool converts this type directly to ViewPermissionsParams, so both
// structs must keep identical field names, types and order.
type ViewParams struct {
	// FilePath is the file to read. Relative paths are resolved against the
	// tool's working directory.
	FilePath string `json:"file_path" description:"The path to the file to read"`
	// Offset is the number of leading lines to skip before reading (0-based).
	Offset   int    `json:"offset" description:"The line number to start reading from (0-based)"`
	// Limit is the maximum number of lines to return. NewViewTool replaces
	// values <= 0 with DefaultReadLimit.
	Limit    int    `json:"limit" description:"The number of lines to read (defaults to 2000)"`
}
28
// ViewPermissionsParams is the Params payload attached to the permission
// request NewViewTool issues before reading a file outside the working
// directory.
//
// It is built with a direct type conversion, ViewPermissionsParams(params), so
// its field names, types and order must stay identical to ViewParams.
type ViewPermissionsParams struct {
	FilePath string `json:"file_path"`
	Offset   int    `json:"offset"`
	Limit    int    `json:"limit"`
}
34
// viewTool groups the dependencies of the view tool.
//
// NOTE(review): this type is not referenced anywhere in this file. NewViewTool
// captures the same values in a closure instead. Confirm whether other files
// in the package use it before removing it.
type viewTool struct {
	lspClients  *csync.Map[string, *lsp.Client]
	workingDir  string
	permissions permission.Service
}
40
// ViewResponseMetadata is attached to successful view responses through
// ai.WithResponseMetadata.
type ViewResponseMetadata struct {
	// FilePath is the path that was read (relative inputs are joined with the
	// working directory).
	FilePath string `json:"file_path"`
	// Content is the selected lines joined with "\n", without line-number
	// prefixes; over-long lines are already truncated.
	Content  string `json:"content"`
}
45
const (
	// ViewToolName is the name the tool is registered under. It is also the
	// ToolName reported in permission requests.
	ViewToolName     = "view"
	// MaxReadSize is the largest file size, in bytes (250 KiB), the tool will
	// read. Larger files are rejected with an error response.
	MaxReadSize      = 250 * 1024
	// DefaultReadLimit is the number of lines returned when the caller's
	// limit is <= 0.
	DefaultReadLimit = 2000
	// MaxLineLength is the byte length above which a line is truncated and
	// suffixed with "...".
	MaxLineLength    = 2000
)
52
53func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string) ai.AgentTool {
54 return ai.NewAgentTool(
55 ViewToolName,
56 string(viewDescription),
57 func(ctx context.Context, params ViewParams, call ai.ToolCall) (ai.ToolResponse, error) {
58 if params.FilePath == "" {
59 return ai.NewTextErrorResponse("file_path is required"), nil
60 }
61
62 // Handle relative paths
63 filePath := params.FilePath
64 if !filepath.IsAbs(filePath) {
65 filePath = filepath.Join(workingDir, filePath)
66 }
67
68 // Check if file is outside working directory and request permission if needed
69 absWorkingDir, err := filepath.Abs(workingDir)
70 if err != nil {
71 return ai.ToolResponse{}, fmt.Errorf("error resolving working directory: %w", err)
72 }
73
74 absFilePath, err := filepath.Abs(filePath)
75 if err != nil {
76 return ai.ToolResponse{}, fmt.Errorf("error resolving file path: %w", err)
77 }
78
79 relPath, err := filepath.Rel(absWorkingDir, absFilePath)
80 if err != nil || strings.HasPrefix(relPath, "..") {
81 // File is outside working directory, request permission
82 sessionID := GetSessionFromContext(ctx)
83 if sessionID == "" {
84 return ai.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory")
85 }
86
87 granted := permissions.Request(
88 permission.CreatePermissionRequest{
89 SessionID: sessionID,
90 Path: absFilePath,
91 ToolCallID: call.ID,
92 ToolName: ViewToolName,
93 Action: "read",
94 Description: fmt.Sprintf("Read file outside working directory: %s", absFilePath),
95 Params: ViewPermissionsParams(params),
96 },
97 )
98
99 if !granted {
100 return ai.ToolResponse{}, permission.ErrorPermissionDenied
101 }
102 }
103
104 // Check if file exists
105 fileInfo, err := os.Stat(filePath)
106 if err != nil {
107 if os.IsNotExist(err) {
108 // Try to offer suggestions for similarly named files
109 dir := filepath.Dir(filePath)
110 base := filepath.Base(filePath)
111
112 dirEntries, dirErr := os.ReadDir(dir)
113 if dirErr == nil {
114 var suggestions []string
115 for _, entry := range dirEntries {
116 if strings.Contains(strings.ToLower(entry.Name()), strings.ToLower(base)) ||
117 strings.Contains(strings.ToLower(base), strings.ToLower(entry.Name())) {
118 suggestions = append(suggestions, filepath.Join(dir, entry.Name()))
119 if len(suggestions) >= 3 {
120 break
121 }
122 }
123 }
124
125 if len(suggestions) > 0 {
126 return ai.NewTextErrorResponse(fmt.Sprintf("File not found: %s\n\nDid you mean one of these?\n%s",
127 filePath, strings.Join(suggestions, "\n"))), nil
128 }
129 }
130
131 return ai.NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil
132 }
133 return ai.ToolResponse{}, fmt.Errorf("error accessing file: %w", err)
134 }
135
136 // Check if it's a directory
137 if fileInfo.IsDir() {
138 return ai.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
139 }
140
141 // Check file size
142 if fileInfo.Size() > MaxReadSize {
143 return ai.NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). Maximum size is %d bytes",
144 fileInfo.Size(), MaxReadSize)), nil
145 }
146
147 // Set default limit if not provided
148 if params.Limit <= 0 {
149 params.Limit = DefaultReadLimit
150 }
151
152 // Check if it's an image file
153 isImage, imageType := isImageFile(filePath)
154 // TODO: handle images
155 if isImage {
156 return ai.NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\n", imageType)), nil
157 }
158
159 // Read the file content
160 content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit)
161 isValidUt8 := utf8.ValidString(content)
162 if !isValidUt8 {
163 return ai.NewTextErrorResponse("File content is not valid UTF-8"), nil
164 }
165 if err != nil {
166 return ai.ToolResponse{}, fmt.Errorf("error reading file: %w", err)
167 }
168
169 notifyLSPs(ctx, lspClients, filePath)
170 output := "<file>\n"
171 // Format the output with line numbers
172 output += addLineNumbers(content, params.Offset+1)
173
174 // Add a note if the content was truncated
175 if lineCount > params.Offset+len(strings.Split(content, "\n")) {
176 output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
177 params.Offset+len(strings.Split(content, "\n")))
178 }
179 output += "\n</file>\n"
180 output += getDiagnostics(filePath, lspClients)
181 recordFileRead(filePath)
182 return ai.WithResponseMetadata(
183 ai.NewTextResponse(output),
184 ViewResponseMetadata{
185 FilePath: filePath,
186 Content: content,
187 },
188 ), nil
189 })
190}
191
// addLineNumbers prefixes each "\n"-separated line of content with its line
// number. Numbering begins at startLine. The number is right-aligned in a
// six-character column and followed by "|"; numbers wider than six characters
// are printed in full. A trailing carriage return is removed from every line so
// CRLF text renders cleanly. Empty content yields an empty string.
func addLineNumbers(content string, startLine int) string {
	if len(content) == 0 {
		return ""
	}

	var sb strings.Builder
	for idx, raw := range strings.Split(content, "\n") {
		if idx > 0 {
			sb.WriteByte('\n')
		}
		fmt.Fprintf(&sb, "%6d|%s", startLine+idx, strings.TrimSuffix(raw, "\r"))
	}
	return sb.String()
}
216
217func readTextFile(filePath string, offset, limit int) (string, int, error) {
218 file, err := os.Open(filePath)
219 if err != nil {
220 return "", 0, err
221 }
222 defer file.Close()
223
224 lineCount := 0
225
226 scanner := NewLineScanner(file)
227 if offset > 0 {
228 for lineCount < offset && scanner.Scan() {
229 lineCount++
230 }
231 if err = scanner.Err(); err != nil {
232 return "", 0, err
233 }
234 }
235
236 if offset == 0 {
237 _, err = file.Seek(0, io.SeekStart)
238 if err != nil {
239 return "", 0, err
240 }
241 }
242
243 // Pre-allocate slice with expected capacity
244 lines := make([]string, 0, limit)
245 lineCount = offset
246
247 for scanner.Scan() && len(lines) < limit {
248 lineCount++
249 lineText := scanner.Text()
250 if len(lineText) > MaxLineLength {
251 lineText = lineText[:MaxLineLength] + "..."
252 }
253 lines = append(lines, lineText)
254 }
255
256 // Continue scanning to get total line count
257 for scanner.Scan() {
258 lineCount++
259 }
260
261 if err := scanner.Err(); err != nil {
262 return "", 0, err
263 }
264
265 return strings.Join(lines, "\n"), lineCount, nil
266}
267
// isImageFile reports whether filePath ends in a recognized image extension
// (matched case-insensitively). When it does, it also returns a human-readable
// format name; otherwise it returns false and "".
func isImageFile(filePath string) (bool, string) {
	formatByExt := map[string]string{
		".jpg":  "JPEG",
		".jpeg": "JPEG",
		".png":  "PNG",
		".gif":  "GIF",
		".bmp":  "BMP",
		".svg":  "SVG",
		".webp": "WebP",
	}
	format, found := formatByExt[strings.ToLower(filepath.Ext(filePath))]
	return found, format
}
287
// LineScanner reads an io.Reader one line at a time. It wraps a bufio.Scanner
// that uses the default line-splitting behavior, with a maximum line length
// raised above bufio's 64 KiB default.
type LineScanner struct {
	scanner *bufio.Scanner
}

// NewLineScanner returns a LineScanner that reads lines from r.
//
// A default bufio.Scanner fails with bufio.ErrTooLong on any line longer than
// bufio.MaxScanTokenSize (64 KiB). Single-line files such as minified sources
// can easily exceed that, so the limit is raised to 1 MiB. The buffer still
// starts small and grows only as needed.
func NewLineScanner(r io.Reader) *LineScanner {
	const (
		initialBufSize = 4 * 1024
		maxLineSize    = 1024 * 1024
	)
	s := bufio.NewScanner(r)
	s.Buffer(make([]byte, 0, initialBufSize), maxLineSize)
	return &LineScanner{scanner: s}
}

// Scan advances to the next line. It returns false at end of input or on
// error; call Err to tell the two apart.
func (s *LineScanner) Scan() bool {
	return s.scanner.Scan()
}

// Text returns the line most recently read by Scan, without its end-of-line
// marker.
func (s *LineScanner) Text() string {
	return s.scanner.Text()
}

// Err returns the first non-EOF error encountered while scanning, or nil.
func (s *LineScanner) Err() error {
	return s.scanner.Err()
}