edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3
  4use anyhow::{anyhow, Context, Result};
  5use assistant_tool::{ActionLog, Tool};
  6use collections::HashSet;
  7use edit_action::{EditAction, EditActionParser};
  8use futures::StreamExt;
  9use gpui::{App, AsyncApp, Entity, Task};
 10use language_model::{
 11    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 12};
 13use log::{EditToolLog, EditToolRequestId};
 14use project::{search::SearchQuery, Project};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::sync::Arc;
 19use util::paths::PathMatcher;
 20use util::ResultExt;
 21
// Input accepted by the `edit-files` tool; `EditFilesTool::run` deserializes it from the
// raw tool-call JSON.
//
// NOTE: `input_schema` builds the tool's JSON schema with `schemars::schema_for!` on this
// type. schemars turns the `///` doc comments below into schema `description` text, so
// that text is model-facing prompt, not just developer documentation — edit it as such.
//
// NOTE(review): nothing in this file rejects instructions containing code blocks, despite
// what the field docs tell the model. Confirm whether that is enforced elsewhere or is
// only a prompt-level deterrent.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing.
    ///
    /// The description should be concise and clear. We will show this
    /// description to the user as well.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    pub edit_instructions: String,
}
 61
/// The `edit-files` tool: forwards natural-language edit instructions to the
/// configured editor model and applies the edits it streams back.
pub struct EditFilesTool;
 63
 64impl Tool for EditFilesTool {
 65    fn name(&self) -> String {
 66        "edit-files".into()
 67    }
 68
 69    fn description(&self) -> String {
 70        include_str!("./edit_files_tool/description.md").into()
 71    }
 72
 73    fn input_schema(&self) -> serde_json::Value {
 74        let schema = schemars::schema_for!(EditFilesToolInput);
 75        serde_json::to_value(&schema).unwrap()
 76    }
 77
 78    fn run(
 79        self: Arc<Self>,
 80        input: serde_json::Value,
 81        messages: &[LanguageModelRequestMessage],
 82        project: Entity<Project>,
 83        action_log: Entity<ActionLog>,
 84        cx: &mut App,
 85    ) -> Task<Result<String>> {
 86        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 87            Ok(input) => input,
 88            Err(err) => return Task::ready(Err(anyhow!(err))),
 89        };
 90
 91        match EditToolLog::try_global(cx) {
 92            Some(log) => {
 93                let req_id = log.update(cx, |log, cx| {
 94                    log.new_request(input.edit_instructions.clone(), cx)
 95                });
 96
 97                let task = EditToolRequest::new(
 98                    input,
 99                    messages,
100                    project,
101                    action_log,
102                    Some((log.clone(), req_id)),
103                    cx,
104                );
105
106                cx.spawn(async move |cx| {
107                    let result = task.await;
108
109                    let str_result = match &result {
110                        Ok(out) => Ok(out.clone()),
111                        Err(err) => Err(err.to_string()),
112                    };
113
114                    log.update(cx, |log, cx| log.set_tool_output(req_id, str_result, cx))
115                        .log_err();
116
117                    result
118                })
119            }
120
121            None => EditToolRequest::new(input, messages, project, action_log, None, cx),
122        }
123    }
124}
125
/// State for one `edit-files` invocation.
///
/// It parses the editor model's streamed response into edit actions, applies
/// them to project buffers, and accumulates the text reported back as the
/// tool result.
struct EditToolRequest {
    /// Incremental parser turning response chunks into `EditAction`s.
    parser: EditActionParser,
    /// Tool output text. It starts with `SUCCESS_OUTPUT_HEADER`, and the source
    /// of each successfully applied action is appended to it.
    output: String,
    /// Buffers modified by applied actions; each is saved once in `finalize`.
    changed_buffers: HashSet<Entity<language::Buffer>>,
    /// Replace actions whose search text was not found in the target buffer.
    bad_searches: Vec<BadSearch>,
    /// Project used to resolve paths, open buffers and save them.
    project: Entity<Project>,
    /// Receives `changed_buffers` via `buffer_edited` in `finalize`.
    action_log: Entity<ActionLog>,
    /// Optional edit-tool log (plus this request's id in it) that receives
    /// every response chunk and the actions parsed from it.
    tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
135
/// Outcome of computing the edit for a single action.
#[derive(Debug)]
enum DiffResult {
    /// The search text of a replace action matched nothing in the buffer.
    BadSearch(BadSearch),
    /// A diff ready to be applied to the buffer with `apply_diff`.
    Diff(language::Diff),
}
141
/// A replace action whose search text could not be found; `finalize` reports
/// these back to the model.
#[derive(Debug)]
struct BadSearch {
    /// Display form of the path named by the action.
    file_path: String,
    /// The search (`old`) text that failed to match.
    search: String,
}
147
148impl EditToolRequest {
149    fn new(
150        input: EditFilesToolInput,
151        messages: &[LanguageModelRequestMessage],
152        project: Entity<Project>,
153        action_log: Entity<ActionLog>,
154        tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
155        cx: &mut App,
156    ) -> Task<Result<String>> {
157        let model_registry = LanguageModelRegistry::read_global(cx);
158        let Some(model) = model_registry.editor_model() else {
159            return Task::ready(Err(anyhow!("No editor model configured")));
160        };
161
162        let mut messages = messages.to_vec();
163        // Remove the last tool use (this run) to prevent an invalid request
164        'outer: for message in messages.iter_mut().rev() {
165            for (index, content) in message.content.iter().enumerate().rev() {
166                match content {
167                    MessageContent::ToolUse(_) => {
168                        message.content.remove(index);
169                        break 'outer;
170                    }
171                    MessageContent::ToolResult(_) => {
172                        // If we find any tool results before a tool use, the request is already valid
173                        break 'outer;
174                    }
175                    MessageContent::Text(_) | MessageContent::Image(_) => {}
176                }
177            }
178        }
179
180        messages.push(LanguageModelRequestMessage {
181            role: Role::User,
182            content: vec![
183                include_str!("./edit_files_tool/edit_prompt.md").into(),
184                input.edit_instructions.into(),
185            ],
186            cache: false,
187        });
188
189        cx.spawn(async move |cx| {
190            let llm_request = LanguageModelRequest {
191                messages,
192                tools: vec![],
193                stop: vec![],
194                temperature: Some(0.0),
195            };
196
197            let stream = model.stream_completion_text(llm_request, &cx);
198            let mut chunks = stream.await?;
199
200            let mut request = Self {
201                parser: EditActionParser::new(),
202                // we start with the success header so we don't need to shift the output in the common case
203                output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
204                changed_buffers: HashSet::default(),
205                bad_searches: Vec::new(),
206                action_log,
207                project,
208                tool_log,
209            };
210
211            while let Some(chunk) = chunks.stream.next().await {
212                request.process_response_chunk(&chunk?, cx).await?;
213            }
214
215            request.finalize(cx).await
216        })
217    }
218
219    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
220        let new_actions = self.parser.parse_chunk(chunk);
221
222        if let Some((ref log, req_id)) = self.tool_log {
223            log.update(cx, |log, cx| {
224                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
225            })
226            .log_err();
227        }
228
229        for action in new_actions {
230            self.apply_action(action, cx).await?;
231        }
232
233        Ok(())
234    }
235
236    async fn apply_action(
237        &mut self,
238        (action, source): (EditAction, String),
239        cx: &mut AsyncApp,
240    ) -> Result<()> {
241        let project_path = self.project.read_with(cx, |project, cx| {
242            project
243                .find_project_path(action.file_path(), cx)
244                .context("Path not found in project")
245        })??;
246
247        let buffer = self
248            .project
249            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
250            .await?;
251
252        let result = match action {
253            EditAction::Replace {
254                old,
255                new,
256                file_path,
257            } => {
258                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
259
260                cx.background_executor()
261                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
262                    .await
263            }
264            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
265                buffer
266                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
267                    .await,
268            )),
269        }?;
270
271        match result {
272            DiffResult::BadSearch(invalid_replace) => {
273                self.bad_searches.push(invalid_replace);
274            }
275            DiffResult::Diff(diff) => {
276                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
277
278                write!(&mut self.output, "\n\n{}", source)?;
279                self.changed_buffers.insert(buffer);
280            }
281        }
282
283        Ok(())
284    }
285
286    async fn replace_diff(
287        old: String,
288        new: String,
289        file_path: std::path::PathBuf,
290        snapshot: language::BufferSnapshot,
291    ) -> Result<DiffResult> {
292        let query = SearchQuery::text(
293            old.clone(),
294            false,
295            true,
296            true,
297            PathMatcher::new(&[])?,
298            PathMatcher::new(&[])?,
299            None,
300        )?;
301
302        let matches = query.search(&snapshot, None).await;
303
304        if matches.is_empty() {
305            return Ok(DiffResult::BadSearch(BadSearch {
306                search: old.clone(),
307                file_path: file_path.display().to_string(),
308            }));
309        }
310
311        let edit_range = matches[0].clone();
312        let diff = language::text_diff(&old, &new);
313
314        let edits = diff
315            .into_iter()
316            .map(|(old_range, text)| {
317                let start = edit_range.start + old_range.start;
318                let end = edit_range.start + old_range.end;
319                (start..end, text)
320            })
321            .collect::<Vec<_>>();
322
323        let diff = language::Diff {
324            base_version: snapshot.version().clone(),
325            line_ending: snapshot.line_ending(),
326            edits,
327        };
328
329        anyhow::Ok(DiffResult::Diff(diff))
330    }
331
332    const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
333    const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!";
334    const ERROR_OUTPUT_HEADER_WITH_EDITS: &str =
335        "Errors occurred. First, here's a list of the edits we managed to apply:";
336
337    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
338        let changed_buffer_count = self.changed_buffers.len();
339
340        // Save each buffer once at the end
341        for buffer in &self.changed_buffers {
342            self.project
343                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
344                .await?;
345        }
346
347        self.action_log
348            .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
349            .log_err();
350
351        let errors = self.parser.errors();
352
353        if errors.is_empty() && self.bad_searches.is_empty() {
354            if changed_buffer_count == 0 {
355                return Err(anyhow!(
356                    "The instructions didn't lead to any changes. You might need to consult the file contents first."
357                ));
358            }
359
360            Ok(self.output)
361        } else {
362            let mut output = self.output;
363
364            if output.is_empty() {
365                output.replace_range(
366                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
367                    Self::ERROR_OUTPUT_HEADER_NO_EDITS,
368                );
369            } else {
370                output.replace_range(
371                    0..Self::SUCCESS_OUTPUT_HEADER.len(),
372                    Self::ERROR_OUTPUT_HEADER_WITH_EDITS,
373                );
374            }
375
376            if !self.bad_searches.is_empty() {
377                writeln!(
378                    &mut output,
379                    "\n\n# {} SEARCH/REPLACE block(s) failed to match:\n",
380                    self.bad_searches.len()
381                )?;
382
383                for replace in self.bad_searches {
384                    writeln!(
385                        &mut output,
386                        "## No exact match in: {}\n```\n{}\n```\n",
387                        replace.file_path, replace.search,
388                    )?;
389                }
390
391                write!(&mut output,
392                    "The SEARCH section must exactly match an existing block of lines including all white \
393                    space, comments, indentation, docstrings, etc."
394                )?;
395            }
396
397            if !errors.is_empty() {
398                writeln!(
399                    &mut output,
400                    "\n\n# {} SEARCH/REPLACE blocks failed to parse:",
401                    errors.len()
402                )?;
403
404                for error in errors {
405                    writeln!(&mut output, "- {}", error)?;
406                }
407            }
408
409            if changed_buffer_count > 0 {
410                writeln!(
411                    &mut output,
412                    "\n\nThe other SEARCH/REPLACE blocks were applied successfully. Do not re-send them!",
413                )?;
414            }
415
416            writeln!(
417                &mut output,
418                "{}You can fix errors by running the tool again. You can include instructions, \
419                but errors are part of the conversation so you don't need to repeat them.",
420                if changed_buffer_count == 0 {
421                    "\n\n"
422                } else {
423                    ""
424                }
425            )?;
426
427            Err(anyhow!(output))
428        }
429    }
430}