edit_files_tool.rs

  1mod edit_action;
  2
  3use anyhow::{anyhow, Context, Result};
  4use assistant_tool::Tool;
  5use collections::HashSet;
  6use edit_action::{EditAction, EditActionParser};
  7use futures::StreamExt;
  8use gpui::{App, Entity, Task};
  9use language_model::{
 10    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 11};
 12use project::{Project, ProjectPath};
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use std::fmt::Write;
 16use std::sync::Arc;
 17
 18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 19pub struct EditFilesToolInput {
 20    /// High-level edit instructions. These will be interpreted by a smaller model,
 21    /// so explain the edits you want that model to make and to which files need changing.
 22    /// The description should be concise and clear. We will show this description to the user
 23    /// as well.
 24    ///
 25    /// <example>
 26    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
 27    /// </example>
 28    ///
 29    /// <example>
 30    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
 31    /// </example>
 32    pub edit_instructions: String,
 33}
 34
 35pub struct EditFilesTool;
 36
 37impl Tool for EditFilesTool {
 38    fn name(&self) -> String {
 39        "edit-files".into()
 40    }
 41
 42    fn description(&self) -> String {
 43        include_str!("./edit_files_tool/description.md").into()
 44    }
 45
 46    fn input_schema(&self) -> serde_json::Value {
 47        let schema = schemars::schema_for!(EditFilesToolInput);
 48        serde_json::to_value(&schema).unwrap()
 49    }
 50
 51    fn run(
 52        self: Arc<Self>,
 53        input: serde_json::Value,
 54        messages: &[LanguageModelRequestMessage],
 55        project: Entity<Project>,
 56        cx: &mut App,
 57    ) -> Task<Result<String>> {
 58        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 59            Ok(input) => input,
 60            Err(err) => return Task::ready(Err(anyhow!(err))),
 61        };
 62
 63        let model_registry = LanguageModelRegistry::read_global(cx);
 64        let Some(model) = model_registry.editor_model() else {
 65            return Task::ready(Err(anyhow!("No editor model configured")));
 66        };
 67
 68        let mut messages = messages.to_vec();
 69        if let Some(last_message) = messages.last_mut() {
 70            // Strip out tool use from the last message because we're in the middle of executing a tool call.
 71            last_message
 72                .content
 73                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
 74        }
 75        messages.push(LanguageModelRequestMessage {
 76            role: Role::User,
 77            content: vec![
 78                include_str!("./edit_files_tool/edit_prompt.md").into(),
 79                input.edit_instructions.into(),
 80            ],
 81            cache: false,
 82        });
 83
 84        cx.spawn(|mut cx| async move {
 85            let request = LanguageModelRequest {
 86                messages,
 87                tools: vec![],
 88                stop: vec![],
 89                temperature: None,
 90            };
 91
 92            let mut parser = EditActionParser::new();
 93
 94            let stream = model.stream_completion_text(request, &cx);
 95            let mut chunks = stream.await?;
 96
 97            let mut changed_buffers = HashSet::default();
 98            let mut applied_edits = 0;
 99
100            while let Some(chunk) = chunks.stream.next().await {
101                for action in parser.parse_chunk(&chunk?) {
102                    let project_path = project.read_with(&cx, |project, cx| {
103                        let worktree_root_name = action
104                            .file_path()
105                            .components()
106                            .next()
107                            .context("Invalid path")?;
108                        let worktree = project
109                            .worktree_for_root_name(
110                                &worktree_root_name.as_os_str().to_string_lossy(),
111                                cx,
112                            )
113                            .context("Directory not found in project")?;
114                        anyhow::Ok(ProjectPath {
115                            worktree_id: worktree.read(cx).id(),
116                            path: Arc::from(
117                                action.file_path().strip_prefix(worktree_root_name).unwrap(),
118                            ),
119                        })
120                    })??;
121
122                    let buffer = project
123                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
124                        .await?;
125
126                    let diff = buffer
127                        .read_with(&cx, |buffer, cx| {
128                            let new_text = match action {
129                                EditAction::Replace { old, new, .. } => {
130                                    // TODO: Replace in background?
131                                    buffer.text().replace(&old, &new)
132                                }
133                                EditAction::Write { content, .. } => content,
134                            };
135
136                            buffer.diff(new_text, cx)
137                        })?
138                        .await;
139
140                    let _clock =
141                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
142
143                    changed_buffers.insert(buffer);
144
145                    applied_edits += 1;
146                }
147            }
148
149            let mut answer = match changed_buffers.len() {
150                0 => "No files were edited.".to_string(),
151                1 => "Successfully edited ".to_string(),
152                _ => "Successfully edited these files:\n\n".to_string(),
153            };
154
155            // Save each buffer once at the end
156            for buffer in changed_buffers {
157                project
158                    .update(&mut cx, |project, cx| {
159                        if let Some(file) = buffer.read(&cx).file() {
160                            let _ = write!(&mut answer, "{}\n\n", &file.path().display());
161                        }
162
163                        project.save_buffer(buffer, cx)
164                    })?
165                    .await?;
166            }
167
168            let errors = parser.errors();
169
170            if errors.is_empty() {
171                Ok(answer.trim_end().to_string())
172            } else {
173                let error_message = errors
174                    .iter()
175                    .map(|e| e.to_string())
176                    .collect::<Vec<_>>()
177                    .join("\n");
178
179                if applied_edits > 0 {
180                    Err(anyhow!(
181                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
182                        applied_edits,
183                        error_message
184                    ))
185                } else {
186                    Err(anyhow!(error_message))
187                }
188            }
189        })
190    }
191}