handle errors correctly in the other tools

Kujtim Hoxha created

Change summary

internal/llm/tools/edit.go        |  8 ++++++--
internal/llm/tools/fetch.go       |  7 ++++---
internal/llm/tools/glob.go        | 15 +++++++++++++--
internal/llm/tools/grep.go        | 15 +++++++++++++--
internal/llm/tools/ls.go          | 15 +++++++++++++--
internal/llm/tools/sourcegraph.go | 15 ++++++++++-----
internal/llm/tools/view.go        |  5 +++--
internal/llm/tools/write.go       | 18 ++++++++++--------
8 files changed, 72 insertions(+), 26 deletions(-)

Detailed changes

internal/llm/tools/edit.go 🔗

@@ -27,8 +27,9 @@ type EditPermissionsParams struct {
 }
 
 type EditResponseMetadata struct {
-	Additions int `json:"additions"`
-	Removals  int `json:"removals"`
+	Diff      string `json:"diff"`
+	Additions int    `json:"additions"`
+	Removals  int    `json:"removals"`
 }
 
 type editTool struct {
@@ -216,6 +217,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
 	return WithResponseMetadata(
 		NewTextResponse("File created: "+filePath),
 		EditResponseMetadata{
+			Diff:      diff,
 			Additions: stats.Additions,
 			Removals:  stats.Removals,
 		},
@@ -308,6 +310,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
 	return WithResponseMetadata(
 		NewTextResponse("Content deleted from file: "+filePath),
 		EditResponseMetadata{
+			Diff:      diff,
 			Additions: stats.Additions,
 			Removals:  stats.Removals,
 		},
@@ -401,6 +404,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 	return WithResponseMetadata(
 		NewTextResponse("Content replaced in file: "+filePath),
 		EditResponseMetadata{
+			Diff:      diff,
 			Additions: stats.Additions,
 			Removals:  stats.Removals,
 		}), nil

internal/llm/tools/fetch.go 🔗

@@ -86,6 +86,7 @@ func (t *fetchTool) Info() ToolInfo {
 			"format": map[string]any{
 				"type":        "string",
 				"description": "The format to return the content in (text, markdown, or html)",
+				"enum":        []string{"text", "markdown", "html"},
 			},
 			"timeout": map[string]any{
 				"type":        "number",
@@ -126,7 +127,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 	)
 
 	if !p {
-		return NewTextErrorResponse("Permission denied to fetch from URL: " + params.URL), nil
+		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
 	client := t.client
@@ -142,14 +143,14 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 
 	req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil)
 	if err != nil {
-		return NewTextErrorResponse("Failed to create request: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 	}
 
 	req.Header.Set("User-Agent", "termai/1.0")
 
 	resp, err := client.Do(req)
 	if err != nil {
-		return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}
 	defer resp.Body.Close()
 

internal/llm/tools/glob.go 🔗

@@ -63,6 +63,11 @@ type GlobParams struct {
 	Path    string `json:"path"`
 }
 
+type GlobMetadata struct {
+	NumberOfFiles int  `json:"number_of_files"`
+	Truncated     bool `json:"truncated"`
+}
+
 type globTool struct{}
 
 func NewGlobTool() BaseTool {
@@ -104,7 +109,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 
 	files, truncated, err := globFiles(params.Pattern, searchPath, 100)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("error performing glob search: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error finding files: %w", err)
 	}
 
 	var output string
@@ -117,7 +122,13 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		}
 	}
 
-	return NewTextResponse(output), nil
+	return WithResponseMetadata(
+		NewTextResponse(output),
+		GlobMetadata{
+			NumberOfFiles: len(files),
+			Truncated:     truncated,
+		},
+	), nil
 }
 
 func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) {

internal/llm/tools/grep.go 🔗

@@ -27,6 +27,11 @@ type grepMatch struct {
 	modTime time.Time
 }
 
+type GrepMetadata struct {
+	NumberOfMatches int  `json:"number_of_matches"`
+	Truncated       bool `json:"truncated"`
+}
+
 type grepTool struct{}
 
 const (
@@ -110,7 +115,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 
 	matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("error searching files: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
 	}
 
 	var output string
@@ -127,7 +132,13 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		}
 	}
 
-	return NewTextResponse(output), nil
+	return WithResponseMetadata(
+		NewTextResponse(output),
+		GrepMetadata{
+			NumberOfMatches: len(matches),
+			Truncated:       truncated,
+		},
+	), nil
 }
 
 func pluralize(count int) string {

internal/llm/tools/ls.go 🔗

@@ -23,6 +23,11 @@ type TreeNode struct {
 	Children []*TreeNode `json:"children,omitempty"`
 }
 
+type LSMetadata struct {
+	NumberOfFiles int  `json:"number_of_files"`
+	Truncated     bool `json:"truncated"`
+}
+
 type lsTool struct{}
 
 const (
@@ -104,7 +109,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 
 	files, truncated, err := listDirectory(searchPath, params.Ignore, MaxLSFiles)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("error listing directory: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error listing directory: %w", err)
 	}
 
 	tree := createFileTree(files)
@@ -114,7 +119,13 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 		output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
 	}
 
-	return NewTextResponse(output), nil
+	return WithResponseMetadata(
+		NewTextResponse(output),
+		LSMetadata{
+			NumberOfFiles: len(files),
+			Truncated:     truncated,
+		},
+	), nil
 }
 
 func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]string, bool, error) {

internal/llm/tools/sourcegraph.go 🔗

@@ -18,6 +18,11 @@ type SourcegraphParams struct {
 	Timeout       int    `json:"timeout,omitempty"`
 }
 
+type SourcegraphMetadata struct {
+	NumberOfMatches int  `json:"number_of_matches"`
+	Truncated       bool `json:"truncated"`
+}
+
 type sourcegraphTool struct {
 	client *http.Client
 }
@@ -198,7 +203,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 
 	graphqlQueryBytes, err := json.Marshal(request)
 	if err != nil {
-		return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
 	}
 	graphqlQuery := string(graphqlQueryBytes)
 
@@ -209,7 +214,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 		bytes.NewBuffer([]byte(graphqlQuery)),
 	)
 	if err != nil {
-		return NewTextErrorResponse("Failed to create request: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 	}
 
 	req.Header.Set("Content-Type", "application/json")
@@ -217,7 +222,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 
 	resp, err := client.Do(req)
 	if err != nil {
-		return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}
 	defer resp.Body.Close()
 
@@ -231,12 +236,12 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	}
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
 	}
 
 	var result map[string]any
 	if err = json.Unmarshal(body, &result); err != nil {
-		return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil
+		return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
 	}
 
 	formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)

internal/llm/tools/view.go 🔗

@@ -135,7 +135,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 
 			return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil
 		}
-		return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error accessing file: %w", err)
 	}
 
 	// Check if it's a directory
@@ -156,6 +156,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 
 	// Check if it's an image file
 	isImage, imageType := isImageFile(filePath)
+	// TODO: handle images
 	if isImage {
 		return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil
 	}
@@ -163,7 +164,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 	// Read the file content
 	content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error reading file: %w", err)
 	}
 
 	notifyLspOpenFile(ctx, filePath, v.lspClients)

internal/llm/tools/write.go 🔗

@@ -30,8 +30,9 @@ type writeTool struct {
 }
 
 type WriteResponseMetadata struct {
-	Additions int `json:"additions"`
-	Removals  int `json:"removals"`
+	Diff      string `json:"diff"`
+	Additions int    `json:"additions"`
+	Removals  int    `json:"removals"`
 }
 
 const (
@@ -128,12 +129,12 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 			return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
 		}
 	} else if !os.IsNotExist(err) {
-		return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
 	}
 
 	dir := filepath.Dir(filePath)
 	if err = os.MkdirAll(dir, 0o755); err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
 	}
 
 	oldContent := ""
@@ -146,7 +147,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 
 	sessionID, messageID := GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
-		return NewTextErrorResponse("session ID or message ID is missing"), nil
+		return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
 	}
 	diff, stats, err := git.GenerateGitDiffWithStats(
 		removeWorkingDirectoryPrefix(filePath),
@@ -154,7 +155,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 		params.Content,
 	)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error generating diff: %w", err)
 	}
 	p := w.permissions.Request(
 		permission.CreatePermissionRequest{
@@ -169,12 +170,12 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 		},
 	)
 	if !p {
-		return NewTextErrorResponse(fmt.Sprintf("Permission denied to create file: %s", filePath)), nil
+		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
 	err = os.WriteFile(filePath, []byte(params.Content), 0o644)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("Failed to write file: %s", err)), nil
+		return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
 	}
 
 	recordFileWrite(filePath)
@@ -186,6 +187,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 	result += getDiagnostics(filePath, w.lspClients)
 	return WithResponseMetadata(NewTextResponse(result),
 		WriteResponseMetadata{
+			Diff:      diff,
 			Additions: stats.Additions,
 			Removals:  stats.Removals,
 		},