handle errors correctly in the edit tool

Kujtim Hoxha created

Change summary

internal/llm/tools/diagnostics.go |   4 
internal/llm/tools/edit.go        | 152 ++++++++++++++++----------------
internal/llm/tools/view.go        |   2 
internal/llm/tools/write.go       |   2 
4 files changed, 78 insertions(+), 82 deletions(-)

Detailed changes

internal/llm/tools/diagnostics.go 🔗

@@ -82,7 +82,7 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 		waitForLspDiagnostics(ctx, params.FilePath, lsps)
 	}
 
-	output := appendDiagnostics(params.FilePath, lsps)
+	output := getDiagnostics(params.FilePath, lsps)
 
 	return NewTextResponse(output), nil
 }
@@ -154,7 +154,7 @@ func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol
 	return false
 }
 
-func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
+func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
 	fileDiagnostics := []string{}
 	projectDiagnostics := []string{}
 

internal/llm/tools/edit.go 🔗

@@ -131,68 +131,54 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		params.FilePath = filepath.Join(wd, params.FilePath)
 	}
 
+	var response ToolResponse
+	var err error
+
 	if params.OldString == "" {
-		result, err := e.createNewFile(ctx, params.FilePath, params.NewString)
+		response, err = e.createNewFile(ctx, params.FilePath, params.NewString)
 		if err != nil {
-			return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil
+			return response, nil
 		}
-		return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{
-			Additions: result.additions,
-			Removals:  result.removals,
-		}), nil
 	}
 
 	if params.NewString == "" {
-		result, err := e.deleteContent(ctx, params.FilePath, params.OldString)
+		response, err = e.deleteContent(ctx, params.FilePath, params.OldString)
 		if err != nil {
-			return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil
+			return response, nil
 		}
-		return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{
-			Additions: result.additions,
-			Removals:  result.removals,
-		}), nil
 	}
 
-	result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString)
+	response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString)
 	if err != nil {
-		return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil
+		return response, nil
 	}
 
 	waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
-	text := fmt.Sprintf("<result>\n%s\n</result>\n", result.text)
-	text += appendDiagnostics(params.FilePath, e.lspClients)
-	return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{
-		Additions: result.additions,
-		Removals:  result.removals,
-	}), nil
-}
-
-type editResponse struct {
-	text      string
-	additions int
-	removals  int
+	text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
+	text += getDiagnostics(params.FilePath, e.lspClients)
+	response.Content = text
+	return response, nil
 }
 
-func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) {
-	er := editResponse{}
+func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (ToolResponse, error) {
 	fileInfo, err := os.Stat(filePath)
 	if err == nil {
 		if fileInfo.IsDir() {
-			return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
+			return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
 		}
-		return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath)
+		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
 	} else if !os.IsNotExist(err) {
-		return er, fmt.Errorf("failed to access file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
 	}
 
 	dir := filepath.Dir(filePath)
 	if err = os.MkdirAll(dir, 0o755); err != nil {
-		return er, fmt.Errorf("failed to create parent directories: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
 	}
 
 	sessionID, messageID := GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
-		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 	}
 
 	diff, stats, err := git.GenerateGitDiffWithStats(
@@ -201,7 +187,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
 		content,
 	)
 	if err != nil {
-		return er, fmt.Errorf("failed to get file diff: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err)
 	}
 	p := e.permissions.Request(
 		permission.CreatePermissionRequest{
@@ -216,63 +202,67 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
 		},
 	)
 	if !p {
-		return er, fmt.Errorf("permission denied")
+		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
 	err = os.WriteFile(filePath, []byte(content), 0o644)
 	if err != nil {
-		return er, fmt.Errorf("failed to write file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
 	}
 
 	recordFileWrite(filePath)
 	recordFileRead(filePath)
 
-	er.text = "File created: " + filePath
-	er.additions = stats.Additions
-	er.removals = stats.Removals
-	return er, nil
+	return WithResponseMetadata(
+		NewTextResponse("File created: "+filePath),
+		EditResponseMetadata{
+			Additions: stats.Additions,
+			Removals:  stats.Removals,
+		},
+	), nil
 }
 
-func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) {
-	er := editResponse{}
+func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (ToolResponse, error) {
 	fileInfo, err := os.Stat(filePath)
 	if err != nil {
 		if os.IsNotExist(err) {
-			return er, fmt.Errorf("file not found: %s", filePath)
+			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
 		}
-		return er, fmt.Errorf("failed to access file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
 	}
 
 	if fileInfo.IsDir() {
-		return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
+		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
 	}
 
 	if getLastReadTime(filePath).IsZero() {
-		return er, fmt.Errorf("you must read the file before editing it. Use the View tool first")
+		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
 	}
 
 	modTime := fileInfo.ModTime()
 	lastRead := getLastReadTime(filePath)
 	if modTime.After(lastRead) {
-		return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
-			filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
+		return NewTextErrorResponse(
+			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
+				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
+			)), nil
 	}
 
 	content, err := os.ReadFile(filePath)
 	if err != nil {
-		return er, fmt.Errorf("failed to read file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
 	}
 
 	oldContent := string(content)
 
 	index := strings.Index(oldContent, oldString)
 	if index == -1 {
-		return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
+		return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
 	}
 
 	lastIndex := strings.LastIndex(oldContent, oldString)
 	if index != lastIndex {
-		return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
+		return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
 	}
 
 	newContent := oldContent[:index] + oldContent[index+len(oldString):]
@@ -280,7 +270,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
 	sessionID, messageID := GetContextValues(ctx)
 
 	if sessionID == "" || messageID == "" {
-		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 	}
 
 	diff, stats, err := git.GenerateGitDiffWithStats(
@@ -289,7 +279,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
 		newContent,
 	)
 	if err != nil {
-		return er, fmt.Errorf("failed to get file diff: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err)
 	}
 
 	p := e.permissions.Request(
@@ -305,62 +295,66 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
 		},
 	)
 	if !p {
-		return er, fmt.Errorf("permission denied")
+		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
 	err = os.WriteFile(filePath, []byte(newContent), 0o644)
 	if err != nil {
-		return er, fmt.Errorf("failed to write file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
 	}
 	recordFileWrite(filePath)
 	recordFileRead(filePath)
 
-	er.text = "Content deleted from file: " + filePath
-	er.additions = stats.Additions
-	er.removals = stats.Removals
-	return er, nil
+	return WithResponseMetadata(
+		NewTextResponse("Content deleted from file: "+filePath),
+		EditResponseMetadata{
+			Additions: stats.Additions,
+			Removals:  stats.Removals,
+		},
+	), nil
 }
 
-func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) {
-	er := editResponse{}
+func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (ToolResponse, error) {
 	fileInfo, err := os.Stat(filePath)
 	if err != nil {
 		if os.IsNotExist(err) {
-			return er, fmt.Errorf("file not found: %s", filePath)
+			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
 		}
-		return er, fmt.Errorf("failed to access file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
 	}
 
 	if fileInfo.IsDir() {
-		return er, fmt.Errorf("path is a directory, not a file: %s", filePath)
+		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
 	}
 
 	if getLastReadTime(filePath).IsZero() {
-		return er, fmt.Errorf("you must read the file before editing it. Use the View tool first")
+		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
 	}
 
 	modTime := fileInfo.ModTime()
 	lastRead := getLastReadTime(filePath)
 	if modTime.After(lastRead) {
-		return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
-			filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))
+		return NewTextErrorResponse(
+			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
+				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
+			)), nil
 	}
 
 	content, err := os.ReadFile(filePath)
 	if err != nil {
-		return er, fmt.Errorf("failed to read file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
 	}
 
 	oldContent := string(content)
 
 	index := strings.Index(oldContent, oldString)
 	if index == -1 {
-		return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks")
+		return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
 	}
 
 	lastIndex := strings.LastIndex(oldContent, oldString)
 	if index != lastIndex {
-		return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match")
+		return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
 	}
 
 	newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
@@ -368,7 +362,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 	sessionID, messageID := GetContextValues(ctx)
 
 	if sessionID == "" || messageID == "" {
-		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
+		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 	}
 	diff, stats, err := git.GenerateGitDiffWithStats(
 		removeWorkingDirectoryPrefix(filePath),
@@ -376,7 +370,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 		newContent,
 	)
 	if err != nil {
-		return er, fmt.Errorf("failed to get file diff: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err)
 	}
 
 	p := e.permissions.Request(
@@ -393,19 +387,21 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 		},
 	)
 	if !p {
-		return er, fmt.Errorf("permission denied")
+		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
 	err = os.WriteFile(filePath, []byte(newContent), 0o644)
 	if err != nil {
-		return er, fmt.Errorf("failed to write file: %w", err)
+		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
 	}
 
 	recordFileWrite(filePath)
 	recordFileRead(filePath)
-	er.text = "Content replaced in file: " + filePath
-	er.additions = stats.Additions
-	er.removals = stats.Removals
 
-	return er, nil
+	return WithResponseMetadata(
+		NewTextResponse("Content replaced in file: "+filePath),
+		EditResponseMetadata{
+			Additions: stats.Additions,
+			Removals:  stats.Removals,
+		}), nil
 }

internal/llm/tools/view.go 🔗

@@ -177,7 +177,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 			params.Offset+len(strings.Split(content, "\n")))
 	}
 	output += "\n</file>\n"
-	output += appendDiagnostics(filePath, v.lspClients)
+	output += getDiagnostics(filePath, v.lspClients)
 	recordFileRead(filePath)
 	return NewTextResponse(output), nil
 }

internal/llm/tools/write.go 🔗

@@ -183,7 +183,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 
 	result := fmt.Sprintf("File successfully written: %s", filePath)
 	result = fmt.Sprintf("<result>\n%s\n</result>", result)
-	result += appendDiagnostics(filePath, w.lspClients)
+	result += getDiagnostics(filePath, w.lspClients)
 	return WithResponseMetadata(NewTextResponse(result),
 		WriteResponseMetadata{
 			Additions: stats.Additions,