edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3
  4use anyhow::{anyhow, Context, Result};
  5use assistant_tool::Tool;
  6use collections::HashSet;
  7use edit_action::{EditAction, EditActionParser};
  8use futures::StreamExt;
  9use gpui::{App, AsyncApp, Entity, Task};
 10use language_model::{
 11    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 12};
 13use log::{EditToolLog, EditToolRequestId};
 14use project::{search::SearchQuery, Project};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::sync::Arc;
 19use util::paths::PathMatcher;
 20use util::ResultExt;
 21
// Input accepted by the `edit-files` tool, deserialized from the JSON arguments
// of the tool call.
//
// NOTE: the `///` comments on the field below are not only documentation. The
// `JsonSchema` derive turns them into the field's description inside the schema
// produced by `EditFilesTool::input_schema` (via `schemars::schema_for!`), so
// editing them changes the instructions the model receives. Plain `//` comments
// like this one are not included in the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing.
    ///
    /// The description should be concise and clear. We will show this
    /// description to the user as well.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    pub edit_instructions: String,
}
 61
 62pub struct EditFilesTool;
 63
 64impl Tool for EditFilesTool {
 65    fn name(&self) -> String {
 66        "edit-files".into()
 67    }
 68
 69    fn description(&self) -> String {
 70        include_str!("./edit_files_tool/description.md").into()
 71    }
 72
 73    fn input_schema(&self) -> serde_json::Value {
 74        let schema = schemars::schema_for!(EditFilesToolInput);
 75        serde_json::to_value(&schema).unwrap()
 76    }
 77
 78    fn run(
 79        self: Arc<Self>,
 80        input: serde_json::Value,
 81        messages: &[LanguageModelRequestMessage],
 82        project: Entity<Project>,
 83        cx: &mut App,
 84    ) -> Task<Result<String>> {
 85        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 86            Ok(input) => input,
 87            Err(err) => return Task::ready(Err(anyhow!(err))),
 88        };
 89
 90        match EditToolLog::try_global(cx) {
 91            Some(log) => {
 92                let req_id = log.update(cx, |log, cx| {
 93                    log.new_request(input.edit_instructions.clone(), cx)
 94                });
 95
 96                let task =
 97                    EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
 98
 99                cx.spawn(|mut cx| async move {
100                    let result = task.await;
101
102                    let str_result = match &result {
103                        Ok(out) => Ok(out.clone()),
104                        Err(err) => Err(err.to_string()),
105                    };
106
107                    log.update(&mut cx, |log, cx| {
108                        log.set_tool_output(req_id, str_result, cx)
109                    })
110                    .log_err();
111
112                    result
113                })
114            }
115
116            None => EditToolRequest::new(input, messages, project, None, cx),
117        }
118    }
119}
120
/// State for a single `edit-files` invocation while the editor model's
/// response is streamed, parsed and applied.
struct EditToolRequest {
    /// Turns streamed response text into [`EditAction`]s and accumulates the
    /// parse errors reported at the end of the request.
    parser: EditActionParser,
    /// Buffers that had at least one diff applied; each is saved once when the
    /// request finishes.
    changed_buffers: HashSet<Entity<language::Buffer>>,
    /// Replace actions whose search text matched nothing; reported back to the
    /// model when the request finishes.
    bad_searches: Vec<BadSearch>,
    /// Project used to resolve action paths, open buffers and save them.
    project: Entity<Project>,
    /// Optional edit-log entity plus the id of this request's entry; when
    /// present, every streamed response chunk is pushed to it.
    log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
128
/// Outcome of computing the change for a single edit action.
#[derive(Debug)]
enum DiffResult {
    /// A replace action's search text was not found in the target buffer.
    BadSearch(BadSearch),
    /// Edits ready to be applied to the target buffer.
    Diff(language::Diff),
}
134
/// A replace action whose search did not match anything in its file.
#[derive(Debug)]
struct BadSearch {
    /// Display form of the path the action targeted.
    file_path: String,
    /// Text reported to the model as not appearing in `file_path`.
    search: String,
}
140
141impl EditToolRequest {
142    fn new(
143        input: EditFilesToolInput,
144        messages: &[LanguageModelRequestMessage],
145        project: Entity<Project>,
146        log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
147        cx: &mut App,
148    ) -> Task<Result<String>> {
149        let model_registry = LanguageModelRegistry::read_global(cx);
150        let Some(model) = model_registry.editor_model() else {
151            return Task::ready(Err(anyhow!("No editor model configured")));
152        };
153
154        let mut messages = messages.to_vec();
155        if let Some(last_message) = messages.last_mut() {
156            // Strip out tool use from the last message because we're in the middle of executing a tool call.
157            last_message
158                .content
159                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
160        }
161        messages.push(LanguageModelRequestMessage {
162            role: Role::User,
163            content: vec![
164                include_str!("./edit_files_tool/edit_prompt.md").into(),
165                input.edit_instructions.into(),
166            ],
167            cache: false,
168        });
169
170        cx.spawn(|mut cx| async move {
171            let llm_request = LanguageModelRequest {
172                messages,
173                tools: vec![],
174                stop: vec![],
175                temperature: Some(0.0),
176            };
177
178            let stream = model.stream_completion_text(llm_request, &cx);
179            let mut chunks = stream.await?;
180
181            let mut request = Self {
182                parser: EditActionParser::new(),
183                changed_buffers: HashSet::default(),
184                bad_searches: Vec::new(),
185                project,
186                log,
187            };
188
189            while let Some(chunk) = chunks.stream.next().await {
190                request.process_response_chunk(&chunk?, &mut cx).await?;
191            }
192
193            request.finalize(&mut cx).await
194        })
195    }
196
197    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
198        let new_actions = self.parser.parse_chunk(chunk);
199
200        if let Some((ref log, req_id)) = self.log {
201            log.update(cx, |log, cx| {
202                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
203            })
204            .log_err();
205        }
206
207        for action in new_actions {
208            self.apply_action(action, cx).await?;
209        }
210
211        Ok(())
212    }
213
214    async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
215        let project_path = self.project.read_with(cx, |project, cx| {
216            project
217                .find_project_path(action.file_path(), cx)
218                .context("Path not found in project")
219        })??;
220
221        let buffer = self
222            .project
223            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
224            .await?;
225
226        let result = match action {
227            EditAction::Replace {
228                old,
229                new,
230                file_path,
231            } => {
232                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
233
234                cx.background_executor()
235                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
236                    .await
237            }
238            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
239                buffer
240                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
241                    .await,
242            )),
243        }?;
244
245        match result {
246            DiffResult::BadSearch(invalid_replace) => {
247                self.bad_searches.push(invalid_replace);
248            }
249            DiffResult::Diff(diff) => {
250                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
251
252                self.changed_buffers.insert(buffer);
253            }
254        }
255
256        Ok(())
257    }
258
259    async fn replace_diff(
260        old: String,
261        new: String,
262        file_path: std::path::PathBuf,
263        snapshot: language::BufferSnapshot,
264    ) -> Result<DiffResult> {
265        let query = SearchQuery::text(
266            old.clone(),
267            false,
268            true,
269            true,
270            PathMatcher::new(&[])?,
271            PathMatcher::new(&[])?,
272            None,
273        )?;
274
275        let matches = query.search(&snapshot, None).await;
276
277        if matches.is_empty() {
278            return Ok(DiffResult::BadSearch(BadSearch {
279                search: new.clone(),
280                file_path: file_path.display().to_string(),
281            }));
282        }
283
284        let edit_range = matches[0].clone();
285        let diff = language::text_diff(&old, &new);
286
287        let edits = diff
288            .into_iter()
289            .map(|(old_range, text)| {
290                let start = edit_range.start + old_range.start;
291                let end = edit_range.start + old_range.end;
292                (start..end, text)
293            })
294            .collect::<Vec<_>>();
295
296        let diff = language::Diff {
297            base_version: snapshot.version().clone(),
298            line_ending: snapshot.line_ending(),
299            edits,
300        };
301
302        anyhow::Ok(DiffResult::Diff(diff))
303    }
304
305    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
306        let mut answer = match self.changed_buffers.len() {
307            0 => "No files were edited.".to_string(),
308            1 => "Successfully edited ".to_string(),
309            _ => "Successfully edited these files:\n\n".to_string(),
310        };
311
312        // Save each buffer once at the end
313        for buffer in self.changed_buffers {
314            let (path, save_task) = self.project.update(cx, |project, cx| {
315                let path = buffer
316                    .read(cx)
317                    .file()
318                    .map(|file| file.path().display().to_string());
319
320                let task = project.save_buffer(buffer.clone(), cx);
321
322                (path, task)
323            })?;
324
325            save_task.await?;
326
327            if let Some(path) = path {
328                writeln!(&mut answer, "{}", path)?;
329            }
330        }
331
332        let errors = self.parser.errors();
333
334        if errors.is_empty() && self.bad_searches.is_empty() {
335            Ok(answer.trim_end().to_string())
336        } else {
337            if !self.bad_searches.is_empty() {
338                writeln!(
339                    &mut answer,
340                    "\nThese searches failed because they didn't match any strings:"
341                )?;
342
343                for replace in self.bad_searches {
344                    writeln!(
345                        &mut answer,
346                        "- '{}' does not appear in `{}`",
347                        replace.search.replace("\r", "\\r").replace("\n", "\\n"),
348                        replace.file_path
349                    )?;
350                }
351
352                writeln!(&mut answer, "Make sure to use exact searches.")?;
353            }
354
355            if !errors.is_empty() {
356                writeln!(
357                    &mut answer,
358                    "\nThese SEARCH/REPLACE blocks failed to parse:"
359                )?;
360
361                for error in errors {
362                    writeln!(&mut answer, "- {}", error)?;
363                }
364            }
365
366            writeln!(
367                &mut answer,
368                "\nYou can fix errors by running the tool again. You can include instructions,\
369                but errors are part of the conversation so you don't need to repeat them."
370            )?;
371
372            Err(anyhow!(answer))
373        }
374    }
375}