edit_file_tool.rs

  1use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{App, AppContext, AsyncApp, Entity, Task};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{path::PathBuf, sync::Arc};
 10use ui::IconName;
 11
 12use crate::replace::replace_exact;
 13
 14#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 15pub struct EditFileToolInput {
 16    /// The full path of the file to modify in the project.
 17    ///
 18    /// WARNING: When specifying which file path need changing, you MUST
 19    /// start each path with one of the project's root directories.
 20    ///
 21    /// The following examples assume we have two root directories in the project:
 22    /// - backend
 23    /// - frontend
 24    ///
 25    /// <example>
 26    /// `backend/src/main.rs`
 27    ///
 28    /// Notice how the file path starts with root-1. Without that, the path
 29    /// would be ambiguous and the call would fail!
 30    /// </example>
 31    ///
 32    /// <example>
 33    /// `frontend/db.js`
 34    /// </example>
 35    pub path: PathBuf,
 36
 37    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 38    ///
 39    /// <example>Fix API endpoint URLs</example>
 40    /// <example>Update copyright year in `page_footer`</example>
 41    pub display_description: String,
 42
 43    /// The text to replace.
 44    pub old_string: String,
 45
 46    /// The text to replace it with.
 47    pub new_string: String,
 48}
 49
// Lenient view of `EditFileToolInput` used by `still_streaming_ui_text`
// while the tool call is still streaming in and may be incomplete.
//
// Every field is `#[serde(default)]`, so keys that have not arrived yet
// deserialize as empty strings instead of causing an error. `path` is a
// `String` here (not a `PathBuf`) so it can be trimmed and checked for
// emptiness directly.
//
// Plain `//` comments are used on purpose: `///` doc comments would be
// picked up by the `JsonSchema` derive as schema descriptions.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
    #[serde(default)]
    old_string: String,
    #[serde(default)]
    new_string: String,
}
 61
/// Tool that edits a project file by replacing `old_string` with `new_string`
/// and reporting the resulting unified diff.
pub struct EditFileTool;

/// Fallback UI label used when no more specific text can be derived from the
/// tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
 65
 66impl Tool for EditFileTool {
 67    fn name(&self) -> String {
 68        "edit_file".into()
 69    }
 70
 71    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 72        false
 73    }
 74
 75    fn description(&self) -> String {
 76        include_str!("edit_file_tool/description.md").to_string()
 77    }
 78
 79    fn icon(&self) -> IconName {
 80        IconName::Pencil
 81    }
 82
 83    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 84        json_schema_for::<EditFileToolInput>(format)
 85    }
 86
 87    fn ui_text(&self, input: &serde_json::Value) -> String {
 88        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
 89            Ok(input) => input.display_description,
 90            Err(_) => "Editing file".to_string(),
 91        }
 92    }
 93
 94    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
 95        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
 96            let description = input.display_description.trim();
 97            if !description.is_empty() {
 98                return description.to_string();
 99            }
100
101            let path = input.path.trim();
102            if !path.is_empty() {
103                return path.to_string();
104            }
105        }
106
107        DEFAULT_UI_TEXT.to_string()
108    }
109
110    fn run(
111        self: Arc<Self>,
112        input: serde_json::Value,
113        _messages: &[LanguageModelRequestMessage],
114        project: Entity<Project>,
115        action_log: Entity<ActionLog>,
116        cx: &mut App,
117    ) -> ToolResult {
118        let input = match serde_json::from_value::<EditFileToolInput>(input) {
119            Ok(input) => input,
120            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
121        };
122
123        cx.spawn(async move |cx: &mut AsyncApp| {
124            let project_path = project.read_with(cx, |project, cx| {
125                project
126                    .find_project_path(&input.path, cx)
127                    .context("Path not found in project")
128            })??;
129
130            let buffer = project
131                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
132                .await?;
133
134            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
135
136            if input.old_string.is_empty() {
137                return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
138            }
139
140            if input.old_string == input.new_string {
141                return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
142            }
143
144            let result = cx
145                .background_spawn(async move {
146                    // Try to match exactly
147                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
148                    .await
149                    // If that fails, try being flexible about indentation
150                    .or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
151
152                    if diff.edits.is_empty() {
153                        return None;
154                    }
155
156                    let old_text = snapshot.text();
157
158                    Some((old_text, diff))
159                })
160                .await;
161
162            let Some((old_text, diff)) = result else {
163                let err = buffer.read_with(cx, |buffer, _cx| {
164                    let file_exists = buffer
165                        .file()
166                        .map_or(false, |file| file.disk_state().exists());
167
168                    if !file_exists {
169                        anyhow!("{} does not exist", input.path.display())
170                    } else if buffer.is_empty() {
171                        anyhow!(
172                            "{} is empty, so the provided `old_string` wasn't found.",
173                            input.path.display()
174                        )
175                    } else {
176                        anyhow!("Failed to match the provided `old_string`")
177                    }
178                })?;
179
180                return Err(err)
181            };
182
183            let snapshot = cx.update(|cx| {
184                action_log.update(cx, |log, cx| {
185                    log.buffer_read(buffer.clone(), cx)
186                });
187                let snapshot = buffer.update(cx, |buffer, cx| {
188                    buffer.finalize_last_transaction();
189                    buffer.apply_diff(diff, cx);
190                    buffer.finalize_last_transaction();
191                    buffer.snapshot()
192                });
193                action_log.update(cx, |log, cx| {
194                    log.buffer_edited(buffer.clone(), cx)
195                });
196                snapshot
197            })?;
198
199            project.update( cx, |project, cx| {
200                project.save_buffer(buffer, cx)
201            })?.await?;
202
203            let diff_str = cx.background_spawn(async move {
204                let new_text = snapshot.text();
205                language::unified_diff(&old_text, &new_text)
206            }).await;
207
208
209            Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
210
211        }).into()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool input with the given `path` and `display_description`,
    /// using fixed placeholder `old_string` / `new_string` values.
    fn input_with(path: &str, display_description: &str) -> serde_json::Value {
        json!({
            "path": path,
            "display_description": display_description,
            "old_string": "old code",
            "new_string": "new code"
        })
    }

    /// Returns the label `EditFileTool` shows while `input` is still streaming.
    fn streaming_label(input: &serde_json::Value) -> String {
        EditFileTool.still_streaming_ui_text(input)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        let label = streaming_label(&input_with("src/main.rs", ""));
        assert_eq!(label, "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let label = streaming_label(&input_with("", "Fix error handling"));
        assert_eq!(label, "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description wins over the path when both are present.
        let label = streaming_label(&input_with("src/main.rs", "Fix error handling"));
        assert_eq!(label, "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let label = streaming_label(&input_with("", ""));
        assert_eq!(label, DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let label = streaming_label(&serde_json::Value::Null);
        assert_eq!(label, DEFAULT_UI_TEXT);
    }
}