edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3mod replace;
  4
  5use anyhow::{anyhow, Context, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use collections::HashSet;
  8use edit_action::{EditAction, EditActionParser};
  9use futures::StreamExt;
 10use gpui::{App, AsyncApp, Entity, Task};
 11use language_model::{
 12    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 13};
 14use log::{EditToolLog, EditToolRequestId};
 15use project::Project;
 16use replace::{replace_exact, replace_with_flexible_indent};
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use std::fmt::Write;
 20use std::sync::Arc;
 21use util::ResultExt;
 22
// Input accepted by the `edit-files` tool (deserialized in `EditFilesTool::run`).
//
// NOTE: the `///` doc comment on `edit_instructions` is runtime text, not just
// documentation: the `JsonSchema` derive turns doc comments into the generated
// schema's `description`, and that schema is what `EditFilesTool::input_schema`
// hands to the model. Keep maintainer notes in plain `//` comments like this one
// so they don't leak into the schema.
//
// NOTE(review): the description says instructions containing code blocks will be
// rejected, but nothing in this file checks for that — the text is forwarded to
// the editor model as-is. Confirm whether it is enforced elsewhere.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing.
    ///
    /// The description should be concise and clear. We will show this
    /// description to the user as well.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    pub edit_instructions: String,
}
 62
 63pub struct EditFilesTool;
 64
 65impl Tool for EditFilesTool {
 66    fn name(&self) -> String {
 67        "edit-files".into()
 68    }
 69
 70    fn description(&self) -> String {
 71        include_str!("./edit_files_tool/description.md").into()
 72    }
 73
 74    fn input_schema(&self) -> serde_json::Value {
 75        let schema = schemars::schema_for!(EditFilesToolInput);
 76        serde_json::to_value(&schema).unwrap()
 77    }
 78
 79    fn run(
 80        self: Arc<Self>,
 81        input: serde_json::Value,
 82        messages: &[LanguageModelRequestMessage],
 83        project: Entity<Project>,
 84        action_log: Entity<ActionLog>,
 85        cx: &mut App,
 86    ) -> Task<Result<String>> {
 87        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 88            Ok(input) => input,
 89            Err(err) => return Task::ready(Err(anyhow!(err))),
 90        };
 91
 92        match EditToolLog::try_global(cx) {
 93            Some(log) => {
 94                let req_id = log.update(cx, |log, cx| {
 95                    log.new_request(input.edit_instructions.clone(), cx)
 96                });
 97
 98                let task = EditToolRequest::new(
 99                    input,
100                    messages,
101                    project,
102                    action_log,
103                    Some((log.clone(), req_id)),
104                    cx,
105                );
106
107                cx.spawn(async move |cx| {
108                    let result = task.await;
109
110                    let str_result = match &result {
111                        Ok(out) => Ok(out.clone()),
112                        Err(err) => Err(err.to_string()),
113                    };
114
115                    log.update(cx, |log, cx| log.set_tool_output(req_id, str_result, cx))
116                        .log_err();
117
118                    result
119                })
120            }
121
122            None => EditToolRequest::new(input, messages, project, action_log, None, cx),
123        }
124    }
125}
126
/// State of a single `edit-files` run while the editor model's response is
/// streamed in and its edit actions are applied.
struct EditToolRequest {
    /// Incrementally turns streamed response chunks into edit actions.
    parser: EditActionParser,
    /// Report handed back to the caller. It starts as `SUCCESS_OUTPUT_HEADER`,
    /// and the source text of every applied action is appended to it.
    output: String,
    /// Buffers touched by applied edits; each one is saved once in `finalize`.
    changed_buffers: HashSet<Entity<language::Buffer>>,
    /// `Replace` actions whose search text could not be found in the file.
    bad_searches: Vec<BadSearch>,
    /// Used to resolve paths and to open and save buffers.
    project: Entity<Project>,
    /// Informed of each applied edit via `buffer_edited`.
    action_log: Entity<ActionLog>,
    /// Debug-log entry for this run, present only when a global `EditToolLog` exists.
    tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
136
/// Result of computing the change for a single edit action.
#[derive(Debug)]
enum DiffResult {
    /// A `Replace` action whose search text did not match the buffer.
    BadSearch(BadSearch),
    /// A diff ready to be applied to the buffer.
    Diff(language::Diff),
}
142
/// A `Replace` action whose search text could not be located. Collected during
/// the run and listed in the error report that `finalize` builds.
#[derive(Debug)]
struct BadSearch {
    /// Target file path, rendered with `Path::display` for the report.
    file_path: String,
    /// The search text that failed to match.
    search: String,
}
148
149impl EditToolRequest {
150    fn new(
151        input: EditFilesToolInput,
152        messages: &[LanguageModelRequestMessage],
153        project: Entity<Project>,
154        action_log: Entity<ActionLog>,
155        tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
156        cx: &mut App,
157    ) -> Task<Result<String>> {
158        let model_registry = LanguageModelRegistry::read_global(cx);
159        let Some(model) = model_registry.editor_model() else {
160            return Task::ready(Err(anyhow!("No editor model configured")));
161        };
162
163        let mut messages = messages.to_vec();
164        // Remove the last tool use (this run) to prevent an invalid request
165        'outer: for message in messages.iter_mut().rev() {
166            for (index, content) in message.content.iter().enumerate().rev() {
167                match content {
168                    MessageContent::ToolUse(_) => {
169                        message.content.remove(index);
170                        break 'outer;
171                    }
172                    MessageContent::ToolResult(_) => {
173                        // If we find any tool results before a tool use, the request is already valid
174                        break 'outer;
175                    }
176                    MessageContent::Text(_) | MessageContent::Image(_) => {}
177                }
178            }
179        }
180
181        messages.push(LanguageModelRequestMessage {
182            role: Role::User,
183            content: vec![
184                include_str!("./edit_files_tool/edit_prompt.md").into(),
185                input.edit_instructions.into(),
186            ],
187            cache: false,
188        });
189
190        cx.spawn(async move |cx| {
191            let llm_request = LanguageModelRequest {
192                messages,
193                tools: vec![],
194                stop: vec![],
195                temperature: Some(0.0),
196            };
197
198            let stream = model.stream_completion_text(llm_request, &cx);
199            let mut chunks = stream.await?;
200
201            let mut request = Self {
202                parser: EditActionParser::new(),
203                // we start with the success header so we don't need to shift the output in the common case
204                output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
205                changed_buffers: HashSet::default(),
206                bad_searches: Vec::new(),
207                action_log,
208                project,
209                tool_log,
210            };
211
212            while let Some(chunk) = chunks.stream.next().await {
213                request.process_response_chunk(&chunk?, cx).await?;
214            }
215
216            request.finalize(cx).await
217        })
218    }
219
220    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
221        let new_actions = self.parser.parse_chunk(chunk);
222
223        if let Some((ref log, req_id)) = self.tool_log {
224            log.update(cx, |log, cx| {
225                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
226            })
227            .log_err();
228        }
229
230        for action in new_actions {
231            self.apply_action(action, cx).await?;
232        }
233
234        Ok(())
235    }
236
237    async fn apply_action(
238        &mut self,
239        (action, source): (EditAction, String),
240        cx: &mut AsyncApp,
241    ) -> Result<()> {
242        let project_path = self.project.read_with(cx, |project, cx| {
243            project
244                .find_project_path(action.file_path(), cx)
245                .context("Path not found in project")
246        })??;
247
248        let buffer = self
249            .project
250            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
251            .await?;
252
253        let result = match action {
254            EditAction::Replace {
255                old,
256                new,
257                file_path,
258            } => {
259                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
260
261                cx.background_executor()
262                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
263                    .await
264            }
265            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
266                buffer
267                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
268                    .await,
269            )),
270        }?;
271
272        match result {
273            DiffResult::BadSearch(invalid_replace) => {
274                self.bad_searches.push(invalid_replace);
275            }
276            DiffResult::Diff(diff) => {
277                let edit_ids = buffer.update(cx, |buffer, cx| {
278                    buffer.finalize_last_transaction();
279                    buffer.apply_diff(diff, cx);
280                    let transaction = buffer.finalize_last_transaction();
281                    transaction.map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
282                })?;
283                self.action_log
284                    .update(cx, |log, cx| {
285                        log.buffer_edited(buffer.clone(), edit_ids, cx)
286                    })?
287                    .await?;
288
289                write!(&mut self.output, "\n\n{}", source)?;
290                self.changed_buffers.insert(buffer);
291            }
292        }
293
294        Ok(())
295    }
296
297    async fn replace_diff(
298        old: String,
299        new: String,
300        file_path: std::path::PathBuf,
301        snapshot: language::BufferSnapshot,
302    ) -> Result<DiffResult> {
303        let result =
304            // Try to match exactly
305            replace_exact(&old, &new, &snapshot)
306            .await
307            // If that fails, try being flexible about indentation
308            .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot));
309
310        let Some(diff) = result else {
311            return anyhow::Ok(DiffResult::BadSearch(BadSearch {
312                search: old,
313                file_path: file_path.display().to_string(),
314            }));
315        };
316
317        anyhow::Ok(DiffResult::Diff(diff))
318    }
319
320    const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
321    const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!";
322    const ERROR_OUTPUT_HEADER_WITH_EDITS: &str =
323        "Errors occurred. First, here's a list of the edits we managed to apply:";
324
325    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
326        let changed_buffer_count = self.changed_buffers.len();
327
328        // Save each buffer once at the end
329        for buffer in &self.changed_buffers {
330            self.project
331                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
332                .await?;
333        }
334
335        let errors = self.parser.errors();
336
337        if errors.is_empty() && self.bad_searches.is_empty() {
338            if changed_buffer_count == 0 {
339                return Err(anyhow!(
340                    "The instructions didn't lead to any changes. You might need to consult the file contents first."
341                ));
342            }
343
344            Ok(self.output)
345        } else {
346            let mut output = self.output;
347
348            if output.is_empty() {
349                output.replace_range(
350                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
351                    Self::ERROR_OUTPUT_HEADER_NO_EDITS,
352                );
353            } else {
354                output.replace_range(
355                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
356                    Self::ERROR_OUTPUT_HEADER_WITH_EDITS,
357                );
358            }
359
360            if !self.bad_searches.is_empty() {
361                writeln!(
362                    &mut output,
363                    "\n\n# {} SEARCH/REPLACE block(s) failed to match:\n",
364                    self.bad_searches.len()
365                )?;
366
367                for replace in self.bad_searches {
368                    writeln!(
369                        &mut output,
370                        "## No exact match in: {}\n```\n{}\n```\n",
371                        replace.file_path, replace.search,
372                    )?;
373                }
374
375                write!(&mut output,
376                    "The SEARCH section must exactly match an existing block of lines including all white \
377                    space, comments, indentation, docstrings, etc."
378                )?;
379            }
380
381            if !errors.is_empty() {
382                writeln!(
383                    &mut output,
384                    "\n\n# {} SEARCH/REPLACE blocks failed to parse:",
385                    errors.len()
386                )?;
387
388                for error in errors {
389                    writeln!(&mut output, "- {}", error)?;
390                }
391            }
392
393            if changed_buffer_count > 0 {
394                writeln!(
395                    &mut output,
396                    "\n\nThe other SEARCH/REPLACE blocks were applied successfully. Do not re-send them!",
397                )?;
398            }
399
400            writeln!(
401                &mut output,
402                "{}You can fix errors by running the tool again. You can include instructions, \
403                but errors are part of the conversation so you don't need to repeat them.",
404                if changed_buffer_count == 0 {
405                    "\n\n"
406                } else {
407                    ""
408                }
409            )?;
410
411            Err(anyhow!(output))
412        }
413    }
414}