edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3
  4use anyhow::{anyhow, Context, Result};
  5use assistant_tool::{ActionLog, Tool};
  6use collections::HashSet;
  7use edit_action::{EditAction, EditActionParser};
  8use futures::StreamExt;
  9use gpui::{App, AsyncApp, Entity, Task};
 10use language_model::{
 11    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 12};
 13use log::{EditToolLog, EditToolRequestId};
 14use project::{search::SearchQuery, Project};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::sync::Arc;
 19use util::paths::PathMatcher;
 20use util::ResultExt;
 21
 // Input accepted by the `edit-files` tool; `EditFilesTool::run` deserializes
 // it from the model's JSON tool-call arguments.
 //
 // NOTE: the `///` doc comments on the field below are runtime text, not just
 // developer docs. The `JsonSchema` derive copies doc comments into the
 // schema's `description`, which `input_schema` hands to the model. Edit them
 // as prompt text; put developer-only notes in `//` comments like this one.
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct EditFilesToolInput {
    // NOTE(review): the text below says the tool will reject instructions
    // containing code blocks or snippets. No such check is visible in this
    // file — confirm it is enforced elsewhere, or treat it as prompt-only.
 24    /// High-level edit instructions. These will be interpreted by a smaller
 25    /// model, so explain the changes you want that model to make and which
 26    /// file paths need changing.
 27    ///
 28    /// The description should be concise and clear. We will show this
 29    /// description to the user as well.
 30    ///
 31    /// WARNING: When specifying which file paths need changing, you MUST
 32    /// start each path with one of the project's root directories.
 33    ///
 34    /// WARNING: NEVER include code blocks or snippets in edit instructions.
 35    /// Only provide natural language descriptions of the changes needed! The tool will
 36    /// reject any instructions that contain code blocks or snippets.
 37    ///
 38    /// The following examples assume we have two root directories in the project:
 39    /// - root-1
 40    /// - root-2
 41    ///
 42    /// <example>
 43    /// If you want to introduce a new quit function to kill the process, your
 44    /// instructions should be: "Add a new `quit` function to
 45    /// `root-1/src/main.rs` to kill the process".
 46    ///
 47    /// Notice how the file path starts with root-1. Without that, the path
 48    /// would be ambiguous and the call would fail!
 49    /// </example>
 50    ///
 51    /// <example>
 52    /// If you want to change documentation to always start with a capital
 53    /// letter, your instructions should be: "In `root-2/db.js`,
 54    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
 55    /// to start with a capital letter".
 56    ///
 57    /// Notice how we never specify code snippets in the instructions!
 58    /// </example>
 59    pub edit_instructions: String,
 60}
 61
 62pub struct EditFilesTool;
 63
 64impl Tool for EditFilesTool {
 65    fn name(&self) -> String {
 66        "edit-files".into()
 67    }
 68
 69    fn description(&self) -> String {
 70        include_str!("./edit_files_tool/description.md").into()
 71    }
 72
 73    fn input_schema(&self) -> serde_json::Value {
 74        let schema = schemars::schema_for!(EditFilesToolInput);
 75        serde_json::to_value(&schema).unwrap()
 76    }
 77
 78    fn run(
 79        self: Arc<Self>,
 80        input: serde_json::Value,
 81        messages: &[LanguageModelRequestMessage],
 82        project: Entity<Project>,
 83        action_log: Entity<ActionLog>,
 84        cx: &mut App,
 85    ) -> Task<Result<String>> {
 86        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 87            Ok(input) => input,
 88            Err(err) => return Task::ready(Err(anyhow!(err))),
 89        };
 90
 91        match EditToolLog::try_global(cx) {
 92            Some(log) => {
 93                let req_id = log.update(cx, |log, cx| {
 94                    log.new_request(input.edit_instructions.clone(), cx)
 95                });
 96
 97                let task = EditToolRequest::new(
 98                    input,
 99                    messages,
100                    project,
101                    action_log,
102                    Some((log.clone(), req_id)),
103                    cx,
104                );
105
106                cx.spawn(|mut cx| async move {
107                    let result = task.await;
108
109                    let str_result = match &result {
110                        Ok(out) => Ok(out.clone()),
111                        Err(err) => Err(err.to_string()),
112                    };
113
114                    log.update(&mut cx, |log, cx| {
115                        log.set_tool_output(req_id, str_result, cx)
116                    })
117                    .log_err();
118
119                    result
120                })
121            }
122
123            None => EditToolRequest::new(input, messages, project, action_log, None, cx),
124        }
125    }
126}
127
/// State for a single run of the `edit-files` tool. It parses the editor
/// model's streamed response, applies each parsed action to a project buffer,
/// and collects what happened for the final summary.
128struct EditToolRequest {
    /// Incremental parser fed each streamed response chunk; its accumulated
    /// parse errors are reported when the request is finalized.
129    parser: EditActionParser,
    /// Buffers that had at least one diff applied. Each is saved once at the
    /// end and then reported to the action log.
130    changed_buffers: HashSet<Entity<language::Buffer>>,
    /// Replace actions whose search text matched nothing in the target buffer.
131    bad_searches: Vec<BadSearch>,
    /// Project used to resolve action paths and to open and save buffers.
132    project: Entity<Project>,
    /// Notified of the changed buffers after they are saved.
133    action_log: Entity<ActionLog>,
    /// Optional debug log plus this request's id. Response chunks and the
    /// actions parsed from them are pushed to it as they arrive.
134    tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
135}
136
/// Outcome of turning one parsed edit action into buffer edits.
137#[derive(Debug)]
138enum DiffResult {
    /// The action's search text was not found; nothing is applied and the
    /// failure is recorded for the final report.
139    BadSearch(BadSearch),
    /// Edits to apply to the target buffer with `Buffer::apply_diff`.
140    Diff(language::Diff),
141}
142
/// A Replace action that could not be applied because its search found no
/// match in the target file.
143#[derive(Debug)]
144struct BadSearch {
    /// Display form of the action's file path, as shown in the failure report.
145    file_path: String,
    /// Text shown in the failure report as not appearing in `file_path`.
146    search: String,
147}
148
149impl EditToolRequest {
150    fn new(
151        input: EditFilesToolInput,
152        messages: &[LanguageModelRequestMessage],
153        project: Entity<Project>,
154        action_log: Entity<ActionLog>,
155        tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
156        cx: &mut App,
157    ) -> Task<Result<String>> {
158        let model_registry = LanguageModelRegistry::read_global(cx);
159        let Some(model) = model_registry.editor_model() else {
160            return Task::ready(Err(anyhow!("No editor model configured")));
161        };
162
163        let mut messages = messages.to_vec();
164        // Remove the last tool use (this run) to prevent an invalid request
165        'outer: for message in messages.iter_mut().rev() {
166            for (index, content) in message.content.iter().enumerate().rev() {
167                match content {
168                    MessageContent::ToolUse(_) => {
169                        message.content.remove(index);
170                        break 'outer;
171                    }
172                    MessageContent::ToolResult(_) => {
173                        // If we find any tool results before a tool use, the request is already valid
174                        break 'outer;
175                    }
176                    MessageContent::Text(_) | MessageContent::Image(_) => {}
177                }
178            }
179        }
180
181        messages.push(LanguageModelRequestMessage {
182            role: Role::User,
183            content: vec![
184                include_str!("./edit_files_tool/edit_prompt.md").into(),
185                input.edit_instructions.into(),
186            ],
187            cache: false,
188        });
189
190        cx.spawn(|mut cx| async move {
191            let llm_request = LanguageModelRequest {
192                messages,
193                tools: vec![],
194                stop: vec![],
195                temperature: Some(0.0),
196            };
197
198            let stream = model.stream_completion_text(llm_request, &cx);
199            let mut chunks = stream.await?;
200
201            let mut request = Self {
202                parser: EditActionParser::new(),
203                changed_buffers: HashSet::default(),
204                bad_searches: Vec::new(),
205                action_log,
206                project,
207                tool_log,
208            };
209
210            while let Some(chunk) = chunks.stream.next().await {
211                request.process_response_chunk(&chunk?, &mut cx).await?;
212            }
213
214            request.finalize(&mut cx).await
215        })
216    }
217
218    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
219        let new_actions = self.parser.parse_chunk(chunk);
220
221        if let Some((ref log, req_id)) = self.tool_log {
222            log.update(cx, |log, cx| {
223                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
224            })
225            .log_err();
226        }
227
228        for action in new_actions {
229            self.apply_action(action, cx).await?;
230        }
231
232        Ok(())
233    }
234
235    async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
236        let project_path = self.project.read_with(cx, |project, cx| {
237            project
238                .find_project_path(action.file_path(), cx)
239                .context("Path not found in project")
240        })??;
241
242        let buffer = self
243            .project
244            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
245            .await?;
246
247        let result = match action {
248            EditAction::Replace {
249                old,
250                new,
251                file_path,
252            } => {
253                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
254
255                cx.background_executor()
256                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
257                    .await
258            }
259            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
260                buffer
261                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
262                    .await,
263            )),
264        }?;
265
266        match result {
267            DiffResult::BadSearch(invalid_replace) => {
268                self.bad_searches.push(invalid_replace);
269            }
270            DiffResult::Diff(diff) => {
271                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
272
273                self.changed_buffers.insert(buffer);
274            }
275        }
276
277        Ok(())
278    }
279
280    async fn replace_diff(
281        old: String,
282        new: String,
283        file_path: std::path::PathBuf,
284        snapshot: language::BufferSnapshot,
285    ) -> Result<DiffResult> {
286        let query = SearchQuery::text(
287            old.clone(),
288            false,
289            true,
290            true,
291            PathMatcher::new(&[])?,
292            PathMatcher::new(&[])?,
293            None,
294        )?;
295
296        let matches = query.search(&snapshot, None).await;
297
298        if matches.is_empty() {
299            return Ok(DiffResult::BadSearch(BadSearch {
300                search: new.clone(),
301                file_path: file_path.display().to_string(),
302            }));
303        }
304
305        let edit_range = matches[0].clone();
306        let diff = language::text_diff(&old, &new);
307
308        let edits = diff
309            .into_iter()
310            .map(|(old_range, text)| {
311                let start = edit_range.start + old_range.start;
312                let end = edit_range.start + old_range.end;
313                (start..end, text)
314            })
315            .collect::<Vec<_>>();
316
317        let diff = language::Diff {
318            base_version: snapshot.version().clone(),
319            line_ending: snapshot.line_ending(),
320            edits,
321        };
322
323        anyhow::Ok(DiffResult::Diff(diff))
324    }
325
326    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
327        let mut answer = match self.changed_buffers.len() {
328            0 => "No files were edited.".to_string(),
329            1 => "Successfully edited ".to_string(),
330            _ => "Successfully edited these files:\n\n".to_string(),
331        };
332
333        // Save each buffer once at the end
334        for buffer in &self.changed_buffers {
335            let (path, save_task) = self.project.update(cx, |project, cx| {
336                let path = buffer
337                    .read(cx)
338                    .file()
339                    .map(|file| file.path().display().to_string());
340
341                let task = project.save_buffer(buffer.clone(), cx);
342
343                (path, task)
344            })?;
345
346            save_task.await?;
347
348            if let Some(path) = path {
349                writeln!(&mut answer, "{}", path)?;
350            }
351        }
352
353        self.action_log
354            .update(cx, |log, cx| {
355                log.notify_buffers_changed(self.changed_buffers, cx)
356            })
357            .log_err();
358
359        let errors = self.parser.errors();
360
361        if errors.is_empty() && self.bad_searches.is_empty() {
362            let answer = answer.trim_end().to_string();
363            Ok(answer)
364        } else {
365            if !self.bad_searches.is_empty() {
366                writeln!(
367                    &mut answer,
368                    "\nThese searches failed because they didn't match any strings:"
369                )?;
370
371                for replace in self.bad_searches {
372                    writeln!(
373                        &mut answer,
374                        "- '{}' does not appear in `{}`",
375                        replace.search.replace("\r", "\\r").replace("\n", "\\n"),
376                        replace.file_path
377                    )?;
378                }
379
380                writeln!(&mut answer, "Make sure to use exact searches.")?;
381            }
382
383            if !errors.is_empty() {
384                writeln!(
385                    &mut answer,
386                    "\nThese SEARCH/REPLACE blocks failed to parse:"
387                )?;
388
389                for error in errors {
390                    writeln!(&mut answer, "- {}", error)?;
391                }
392            }
393
394            writeln!(
395                &mut answer,
396                "\nYou can fix errors by running the tool again. You can include instructions,\
397                but errors are part of the conversation so you don't need to repeat them."
398            )?;
399
400            Err(anyhow!(answer.trim_end().to_string()))
401        }
402    }
403}