edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3
  4use anyhow::{anyhow, Context, Result};
  5use assistant_tool::Tool;
  6use collections::HashSet;
  7use edit_action::{EditAction, EditActionParser};
  8use futures::StreamExt;
  9use gpui::{App, Entity, Task};
 10use language_model::{
 11    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 12};
 13use log::{EditToolLog, EditToolRequestId};
 14use project::{Project, ProjectPath};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::sync::Arc;
 19use util::ResultExt;
 20
// Input accepted by the `edit-files` tool. `EditFilesTool::run` deserializes
// the raw JSON tool input into this struct, and `EditFilesTool::input_schema`
// publishes the schema generated by `#[derive(JsonSchema)]`.
//
// IMPORTANT: the `///` doc comments on the field below are not developer
// docs. `schemars` emits them as the `description` of the field in the JSON
// schema, so they are prompt text read by the model. Put developer notes in
// `//` comments like this one. A `///` comment on the struct itself would also
// leak into the schema.
//
// `edit_instructions` is recorded in the `EditToolLog` (when one is installed)
// and appended as message content to the request sent to the editor model.
//
// NOTE(review): nothing in this file inspects `edit_instructions` for code
// blocks. The "The tool will reject any instructions that contain code blocks"
// sentence below looks like a deterrent aimed at the model rather than an
// enforced check. Confirm before relying on it.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing.
    ///
    /// The description should be concise and clear. We will show this
    /// description to the user as well.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    pub edit_instructions: String,
}
 60
/// Assistant tool (`edit-files`) that edits project files by forwarding
/// natural-language instructions to the configured editor model
/// (`LanguageModelRegistry::editor_model`) and applying the edit actions
/// parsed from its response.
///
/// This is a stateless unit struct: all per-call state lives in the task
/// returned by `run`.
pub struct EditFilesTool;
 62
 63impl Tool for EditFilesTool {
 64    fn name(&self) -> String {
 65        "edit-files".into()
 66    }
 67
 68    fn description(&self) -> String {
 69        include_str!("./edit_files_tool/description.md").into()
 70    }
 71
 72    fn input_schema(&self) -> serde_json::Value {
 73        let schema = schemars::schema_for!(EditFilesToolInput);
 74        serde_json::to_value(&schema).unwrap()
 75    }
 76
 77    fn run(
 78        self: Arc<Self>,
 79        input: serde_json::Value,
 80        messages: &[LanguageModelRequestMessage],
 81        project: Entity<Project>,
 82        cx: &mut App,
 83    ) -> Task<Result<String>> {
 84        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 85            Ok(input) => input,
 86            Err(err) => return Task::ready(Err(anyhow!(err))),
 87        };
 88
 89        match EditToolLog::try_global(cx) {
 90            Some(log) => {
 91                let req_id = log.update(cx, |log, cx| {
 92                    log.new_request(input.edit_instructions.clone(), cx)
 93                });
 94
 95                let task =
 96                    EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx);
 97
 98                cx.spawn(|mut cx| async move {
 99                    let result = task.await;
100
101                    let str_result = match &result {
102                        Ok(out) => Ok(out.clone()),
103                        Err(err) => Err(err.to_string()),
104                    };
105
106                    log.update(&mut cx, |log, cx| {
107                        log.set_tool_output(req_id, str_result, cx)
108                    })
109                    .log_err();
110
111                    result
112                })
113            }
114
115            None => EditFilesTool::run(input, messages, project, None, cx),
116        }
117    }
118}
119
120impl EditFilesTool {
121    fn run(
122        input: EditFilesToolInput,
123        messages: &[LanguageModelRequestMessage],
124        project: Entity<Project>,
125        log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
126        cx: &mut App,
127    ) -> Task<Result<String>> {
128        let model_registry = LanguageModelRegistry::read_global(cx);
129        let Some(model) = model_registry.editor_model() else {
130            return Task::ready(Err(anyhow!("No editor model configured")));
131        };
132
133        let mut messages = messages.to_vec();
134        if let Some(last_message) = messages.last_mut() {
135            // Strip out tool use from the last message because we're in the middle of executing a tool call.
136            last_message
137                .content
138                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
139        }
140        messages.push(LanguageModelRequestMessage {
141            role: Role::User,
142            content: vec![
143                include_str!("./edit_files_tool/edit_prompt.md").into(),
144                input.edit_instructions.into(),
145            ],
146            cache: false,
147        });
148
149        cx.spawn(|mut cx| async move {
150            let request = LanguageModelRequest {
151                messages,
152                tools: vec![],
153                stop: vec![],
154                temperature: Some(0.0),
155            };
156
157            let mut parser = EditActionParser::new();
158
159            let stream = model.stream_completion_text(request, &cx);
160            let mut chunks = stream.await?;
161
162            let mut changed_buffers = HashSet::default();
163            let mut applied_edits = 0;
164
165            let log = log.clone();
166
167            while let Some(chunk) = chunks.stream.next().await {
168                let chunk = chunk?;
169
170                let new_actions = parser.parse_chunk(&chunk);
171
172                if let Some((ref log, req_id)) = log {
173                    log.update(&mut cx, |log, cx| {
174                        log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
175                    })
176                    .log_err();
177                }
178
179                for action in new_actions {
180                    let project_path = project.read_with(&cx, |project, cx| {
181                        let worktree_root_name = action
182                            .file_path()
183                            .components()
184                            .next()
185                            .context("Invalid path")?;
186                        let worktree = project
187                            .worktree_for_root_name(
188                                &worktree_root_name.as_os_str().to_string_lossy(),
189                                cx,
190                            )
191                            .context("Directory not found in project")?;
192                        anyhow::Ok(ProjectPath {
193                            worktree_id: worktree.read(cx).id(),
194                            path: Arc::from(
195                                action.file_path().strip_prefix(worktree_root_name).unwrap(),
196                            ),
197                        })
198                    })??;
199
200                    let buffer = project
201                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
202                        .await?;
203
204                    let diff = buffer
205                        .read_with(&cx, |buffer, cx| {
206                            let new_text = match action {
207                                EditAction::Replace {
208                                    file_path,
209                                    old,
210                                    new,
211                                } => {
212                                    // TODO: Replace in background?
213                                    let text = buffer.text();
214                                    if text.contains(&old) {
215                                        text.replace(&old, &new)
216                                    } else {
217                                        return Err(anyhow!(
218                                            "Could not find search text in {}",
219                                            file_path.display()
220                                        ));
221                                    }
222                                }
223                                EditAction::Write { content, .. } => content,
224                            };
225
226                            anyhow::Ok(buffer.diff(new_text, cx))
227                        })??
228                        .await;
229
230                    let _clock =
231                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
232
233                    changed_buffers.insert(buffer);
234
235                    applied_edits += 1;
236                }
237            }
238
239            let mut answer = match changed_buffers.len() {
240                0 => "No files were edited.".to_string(),
241                1 => "Successfully edited ".to_string(),
242                _ => "Successfully edited these files:\n\n".to_string(),
243            };
244
245            // Save each buffer once at the end
246            for buffer in changed_buffers {
247                project
248                    .update(&mut cx, |project, cx| {
249                        if let Some(file) = buffer.read(&cx).file() {
250                            let _ = writeln!(&mut answer, "{}", &file.full_path(cx).display());
251                        }
252
253                        project.save_buffer(buffer, cx)
254                    })?
255                    .await?;
256            }
257
258            let errors = parser.errors();
259
260            if errors.is_empty() {
261                Ok(answer.trim_end().to_string())
262            } else {
263                let error_message = errors
264                    .iter()
265                    .map(|e| e.to_string())
266                    .collect::<Vec<_>>()
267                    .join("\n");
268
269                if applied_edits > 0 {
270                    Err(anyhow!(
271                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
272                        applied_edits,
273                        error_message
274                    ))
275                } else {
276                    Err(anyhow!(error_message))
277                }
278            }
279        })
280    }
281}