edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3mod replace;
  4
  5use anyhow::{anyhow, Context, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use collections::HashSet;
  8use edit_action::{EditAction, EditActionParser};
  9use futures::StreamExt;
 10use gpui::{App, AsyncApp, Entity, Task};
 11use language_model::{
 12    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 13};
 14use log::{EditToolLog, EditToolRequestId};
 15use project::Project;
 16use replace::{replace_exact, replace_with_flexible_indent};
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use std::fmt::Write;
 20use std::sync::Arc;
 21use util::ResultExt;
 22
// Input accepted by the `edit-files` tool, deserialized from the JSON value
// passed to the tool (see `ui_text` and `run` on `EditFilesTool`).
//
// NOTE(review): this type derives `JsonSchema` and `EditFilesTool::input_schema`
// serializes that schema. The `///` docs on the fields below are presumably
// emitted as schema descriptions and therefore read by the calling model, so
// treat them as prompt text rather than ordinary comments (confirm against the
// schemars documentation before rewording them).
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing. The description should be concise and clear.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    pub edit_instructions: String,

    /// A user-friendly description of what changes are being made.
    /// This will be shown to the user in the UI to describe the edit operation. The screen real estate for this UI will be extremely
    /// constrained, so make the description extremely terse.
    ///
    /// <example>
    /// For fixing a broken authentication system:
    /// "Fix auth bug in login flow"
    /// </example>
    ///
    /// <example>
    /// For adding unit tests to a module:
    /// "Add tests for user profile logic"
    /// </example>
    pub display_description: String,
}
 74
/// The `edit-files` tool.
///
/// Its [`Tool`] implementation forwards the edit instructions to the
/// configured editor model and applies the edit actions parsed from the
/// model's streamed response to buffers of the project.
pub struct EditFilesTool;
 76
 77impl Tool for EditFilesTool {
 78    fn name(&self) -> String {
 79        "edit-files".into()
 80    }
 81
 82    fn description(&self) -> String {
 83        include_str!("./edit_files_tool/description.md").into()
 84    }
 85
 86    fn input_schema(&self) -> serde_json::Value {
 87        let schema = schemars::schema_for!(EditFilesToolInput);
 88        serde_json::to_value(&schema).unwrap()
 89    }
 90
 91    fn ui_text(&self, input: &serde_json::Value) -> String {
 92        match serde_json::from_value::<EditFilesToolInput>(input.clone()) {
 93            Ok(input) => input.display_description,
 94            Err(_) => "Edit files".to_string(),
 95        }
 96    }
 97
 98    fn run(
 99        self: Arc<Self>,
100        input: serde_json::Value,
101        messages: &[LanguageModelRequestMessage],
102        project: Entity<Project>,
103        action_log: Entity<ActionLog>,
104        cx: &mut App,
105    ) -> Task<Result<String>> {
106        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
107            Ok(input) => input,
108            Err(err) => return Task::ready(Err(anyhow!(err))),
109        };
110
111        match EditToolLog::try_global(cx) {
112            Some(log) => {
113                let req_id = log.update(cx, |log, cx| {
114                    log.new_request(input.edit_instructions.clone(), cx)
115                });
116
117                let task = EditToolRequest::new(
118                    input,
119                    messages,
120                    project,
121                    action_log,
122                    Some((log.clone(), req_id)),
123                    cx,
124                );
125
126                cx.spawn(async move |cx| {
127                    let result = task.await;
128
129                    let str_result = match &result {
130                        Ok(out) => Ok(out.clone()),
131                        Err(err) => Err(err.to_string()),
132                    };
133
134                    log.update(cx, |log, cx| log.set_tool_output(req_id, str_result, cx))
135                        .log_err();
136
137                    result
138                })
139            }
140
141            None => EditToolRequest::new(input, messages, project, action_log, None, cx),
142        }
143    }
144}
145
/// State of a single `edit-files` run while the editor model's response is
/// streamed, parsed and applied.
struct EditToolRequest {
    /// Incremental parser fed with each response chunk; it yields the edit
    /// actions found so far and collects parse errors.
    parser: EditActionParser,
    /// Text returned to the caller. Starts with the success header and gets
    /// the source of every successfully applied action appended.
    output: String,
    /// Buffers that had at least one diff applied; they are saved once at the end.
    changed_buffers: HashSet<Entity<language::Buffer>>,
    /// Replace actions whose search text could not be matched.
    bad_searches: Vec<BadSearch>,
    /// Project used to resolve file paths and to open and save buffers.
    project: Entity<Project>,
    /// Action log notified about the edited buffers when the request finishes.
    action_log: Entity<ActionLog>,
    /// Optional tool log together with this request's id; when present, every
    /// response chunk and its parsed actions are pushed to it.
    tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
155
/// Outcome of computing the change for a single edit action.
#[derive(Debug)]
enum DiffResult {
    /// The action's search text could not be matched; no diff was produced.
    BadSearch(BadSearch),
    /// A diff that can be applied to the target buffer.
    Diff(language::Diff),
}
161
/// Describes why the search text of a replace action could not be applied.
/// Each entry is reported back in the tool's error output.
#[derive(Debug)]
enum BadSearch {
    /// Neither the exact match nor the indentation-flexible match found
    /// `search` in the buffer.
    NoMatch {
        /// Path of the file named by the action, rendered for display.
        file_path: String,
        /// The search text that failed to match.
        search: String,
    },
    /// The buffer snapshot was empty, so there was nothing to search in.
    EmptyBuffer {
        /// Path of the file named by the action, rendered for display.
        file_path: String,
        /// The search text that could not be matched.
        search: String,
        /// Whether the snapshot has a file whose disk state reports that it exists.
        exists: bool,
    },
}
174
175impl EditToolRequest {
176    fn new(
177        input: EditFilesToolInput,
178        messages: &[LanguageModelRequestMessage],
179        project: Entity<Project>,
180        action_log: Entity<ActionLog>,
181        tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
182        cx: &mut App,
183    ) -> Task<Result<String>> {
184        let model_registry = LanguageModelRegistry::read_global(cx);
185        let Some(model) = model_registry.editor_model() else {
186            return Task::ready(Err(anyhow!("No editor model configured")));
187        };
188
189        let mut messages = messages.to_vec();
190        // Remove the last tool use (this run) to prevent an invalid request
191        'outer: for message in messages.iter_mut().rev() {
192            for (index, content) in message.content.iter().enumerate().rev() {
193                match content {
194                    MessageContent::ToolUse(_) => {
195                        message.content.remove(index);
196                        break 'outer;
197                    }
198                    MessageContent::ToolResult(_) => {
199                        // If we find any tool results before a tool use, the request is already valid
200                        break 'outer;
201                    }
202                    MessageContent::Text(_) | MessageContent::Image(_) => {}
203                }
204            }
205        }
206
207        messages.push(LanguageModelRequestMessage {
208            role: Role::User,
209            content: vec![
210                include_str!("./edit_files_tool/edit_prompt.md").into(),
211                input.edit_instructions.into(),
212            ],
213            cache: false,
214        });
215
216        cx.spawn(async move |cx| {
217            let llm_request = LanguageModelRequest {
218                messages,
219                tools: vec![],
220                stop: vec![],
221                temperature: Some(0.0),
222            };
223
224            let stream = model.stream_completion_text(llm_request, &cx);
225            let mut chunks = stream.await?;
226
227            let mut request = Self {
228                parser: EditActionParser::new(),
229                // we start with the success header so we don't need to shift the output in the common case
230                output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
231                changed_buffers: HashSet::default(),
232                bad_searches: Vec::new(),
233                action_log,
234                project,
235                tool_log,
236            };
237
238            while let Some(chunk) = chunks.stream.next().await {
239                request.process_response_chunk(&chunk?, cx).await?;
240            }
241
242            request.finalize(cx).await
243        })
244    }
245
246    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
247        let new_actions = self.parser.parse_chunk(chunk);
248
249        if let Some((ref log, req_id)) = self.tool_log {
250            log.update(cx, |log, cx| {
251                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
252            })
253            .log_err();
254        }
255
256        for action in new_actions {
257            self.apply_action(action, cx).await?;
258        }
259
260        Ok(())
261    }
262
263    async fn apply_action(
264        &mut self,
265        (action, source): (EditAction, String),
266        cx: &mut AsyncApp,
267    ) -> Result<()> {
268        let project_path = self.project.read_with(cx, |project, cx| {
269            project
270                .find_project_path(action.file_path(), cx)
271                .context("Path not found in project")
272        })??;
273
274        let buffer = self
275            .project
276            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
277            .await?;
278
279        let result = match action {
280            EditAction::Replace {
281                old,
282                new,
283                file_path,
284            } => {
285                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
286
287                cx.background_executor()
288                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
289                    .await
290            }
291            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
292                buffer
293                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
294                    .await,
295            )),
296        }?;
297
298        match result {
299            DiffResult::BadSearch(invalid_replace) => {
300                self.bad_searches.push(invalid_replace);
301            }
302            DiffResult::Diff(diff) => {
303                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
304
305                write!(&mut self.output, "\n\n{}", source)?;
306                self.changed_buffers.insert(buffer);
307            }
308        }
309
310        Ok(())
311    }
312
313    async fn replace_diff(
314        old: String,
315        new: String,
316        file_path: std::path::PathBuf,
317        snapshot: language::BufferSnapshot,
318    ) -> Result<DiffResult> {
319        if snapshot.is_empty() {
320            let exists = snapshot
321                .file()
322                .map_or(false, |file| file.disk_state().exists());
323
324            return Ok(DiffResult::BadSearch(BadSearch::EmptyBuffer {
325                file_path: file_path.display().to_string(),
326                exists,
327                search: old,
328            }));
329        }
330
331        let result =
332            // Try to match exactly
333            replace_exact(&old, &new, &snapshot)
334            .await
335            // If that fails, try being flexible about indentation
336            .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot));
337
338        let Some(diff) = result else {
339            return anyhow::Ok(DiffResult::BadSearch(BadSearch::NoMatch {
340                search: old,
341                file_path: file_path.display().to_string(),
342            }));
343        };
344
345        anyhow::Ok(DiffResult::Diff(diff))
346    }
347
348    const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
349    const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!";
350    const ERROR_OUTPUT_HEADER_WITH_EDITS: &str =
351        "Errors occurred. First, here's a list of the edits we managed to apply:";
352
353    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
354        let changed_buffer_count = self.changed_buffers.len();
355
356        // Save each buffer once at the end
357        for buffer in &self.changed_buffers {
358            self.project
359                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
360                .await?;
361        }
362
363        self.action_log
364            .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
365            .log_err();
366
367        let errors = self.parser.errors();
368
369        if errors.is_empty() && self.bad_searches.is_empty() {
370            if changed_buffer_count == 0 {
371                return Err(anyhow!(
372                    "The instructions didn't lead to any changes. You might need to consult the file contents first."
373                ));
374            }
375
376            Ok(self.output)
377        } else {
378            let mut output = self.output;
379
380            if output.is_empty() {
381                output.replace_range(
382                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
383                    Self::ERROR_OUTPUT_HEADER_NO_EDITS,
384                );
385            } else {
386                output.replace_range(
387                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
388                    Self::ERROR_OUTPUT_HEADER_WITH_EDITS,
389                );
390            }
391
392            if !self.bad_searches.is_empty() {
393                writeln!(
394                    &mut output,
395                    "\n\n# {} SEARCH/REPLACE block(s) failed to match:\n",
396                    self.bad_searches.len()
397                )?;
398
399                for bad_search in self.bad_searches {
400                    match bad_search {
401                        BadSearch::NoMatch { file_path, search } => {
402                            writeln!(
403                                &mut output,
404                                "## No exact match in: `{}`\n```\n{}\n```\n",
405                                file_path, search,
406                            )?;
407                        }
408                        BadSearch::EmptyBuffer {
409                            file_path,
410                            exists: true,
411                            search,
412                        } => {
413                            writeln!(
414                                &mut output,
415                                "## No match because `{}` is empty:\n```\n{}\n```\n",
416                                file_path, search,
417                            )?;
418                        }
419                        BadSearch::EmptyBuffer {
420                            file_path,
421                            exists: false,
422                            search,
423                        } => {
424                            writeln!(
425                                &mut output,
426                                "## No match because `{}` does not exist:\n```\n{}\n```\n",
427                                file_path, search,
428                            )?;
429                        }
430                    }
431                }
432
433                write!(&mut output,
434                    "The SEARCH section must exactly match an existing block of lines including all white \
435                    space, comments, indentation, docstrings, etc."
436                )?;
437            }
438
439            if !errors.is_empty() {
440                writeln!(
441                    &mut output,
442                    "\n\n# {} SEARCH/REPLACE blocks failed to parse:",
443                    errors.len()
444                )?;
445
446                for error in errors {
447                    writeln!(&mut output, "- {}", error)?;
448                }
449            }
450
451            if changed_buffer_count > 0 {
452                writeln!(
453                    &mut output,
454                    "\n\nThe other SEARCH/REPLACE blocks were applied successfully. Do not re-send them!",
455                )?;
456            }
457
458            writeln!(
459                &mut output,
460                "{}You can fix errors by running the tool again. You can include instructions, \
461                but errors are part of the conversation so you don't need to repeat them.",
462                if changed_buffer_count == 0 {
463                    "\n\n"
464                } else {
465                    ""
466                }
467            )?;
468
469            Err(anyhow!(output))
470        }
471    }
472}