edit_files_tool.rs

  1mod edit_action;
  2pub mod log;
  3
  4use anyhow::{anyhow, Context, Result};
  5use assistant_tool::Tool;
  6use collections::HashSet;
  7use edit_action::{EditAction, EditActionParser};
  8use futures::StreamExt;
  9use gpui::{App, Entity, Task};
 10use language_model::{
 11    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 12};
 13use log::{EditToolLog, EditToolRequestId};
 14use project::{Project, ProjectPath};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::sync::Arc;
 19use util::ResultExt;
 20
 21#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 22pub struct EditFilesToolInput {
 23    /// High-level edit instructions. These will be interpreted by a smaller model,
 24    /// so explain the edits you want that model to make and to which files need changing.
 25    /// The description should be concise and clear. We will show this description to the user
 26    /// as well.
 27    ///
 28    /// <example>
 29    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
 30    /// </example>
 31    ///
 32    /// <example>
 33    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
 34    /// </example>
 35    pub edit_instructions: String,
 36}
 37
/// The `edit-files` tool: forwards high-level edit instructions to the configured editor
/// model and applies the edits it produces to buffers in the project.
pub struct EditFilesTool;
 39
 40impl Tool for EditFilesTool {
 41    fn name(&self) -> String {
 42        "edit-files".into()
 43    }
 44
 45    fn description(&self) -> String {
 46        include_str!("./edit_files_tool/description.md").into()
 47    }
 48
 49    fn input_schema(&self) -> serde_json::Value {
 50        let schema = schemars::schema_for!(EditFilesToolInput);
 51        serde_json::to_value(&schema).unwrap()
 52    }
 53
 54    fn run(
 55        self: Arc<Self>,
 56        input: serde_json::Value,
 57        messages: &[LanguageModelRequestMessage],
 58        project: Entity<Project>,
 59        cx: &mut App,
 60    ) -> Task<Result<String>> {
 61        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 62            Ok(input) => input,
 63            Err(err) => return Task::ready(Err(anyhow!(err))),
 64        };
 65
 66        match EditToolLog::try_global(cx) {
 67            Some(log) => {
 68                let req_id = log.update(cx, |log, cx| {
 69                    log.new_request(input.edit_instructions.clone(), cx)
 70                });
 71
 72                let task =
 73                    EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx);
 74
 75                cx.spawn(|mut cx| async move {
 76                    let result = task.await;
 77
 78                    let str_result = match &result {
 79                        Ok(out) => Ok(out.clone()),
 80                        Err(err) => Err(err.to_string()),
 81                    };
 82
 83                    log.update(&mut cx, |log, cx| {
 84                        log.set_tool_output(req_id, str_result, cx)
 85                    })
 86                    .log_err();
 87
 88                    result
 89                })
 90            }
 91
 92            None => EditFilesTool::run(input, messages, project, None, cx),
 93        }
 94    }
 95}
 96
 97impl EditFilesTool {
 98    fn run(
 99        input: EditFilesToolInput,
100        messages: &[LanguageModelRequestMessage],
101        project: Entity<Project>,
102        log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
103        cx: &mut App,
104    ) -> Task<Result<String>> {
105        let model_registry = LanguageModelRegistry::read_global(cx);
106        let Some(model) = model_registry.editor_model() else {
107            return Task::ready(Err(anyhow!("No editor model configured")));
108        };
109
110        let mut messages = messages.to_vec();
111        if let Some(last_message) = messages.last_mut() {
112            // Strip out tool use from the last message because we're in the middle of executing a tool call.
113            last_message
114                .content
115                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
116        }
117        messages.push(LanguageModelRequestMessage {
118            role: Role::User,
119            content: vec![
120                include_str!("./edit_files_tool/edit_prompt.md").into(),
121                input.edit_instructions.into(),
122            ],
123            cache: false,
124        });
125
126        cx.spawn(|mut cx| async move {
127            let request = LanguageModelRequest {
128                messages,
129                tools: vec![],
130                stop: vec![],
131                temperature: None,
132            };
133
134            let mut parser = EditActionParser::new();
135
136            let stream = model.stream_completion_text(request, &cx);
137            let mut chunks = stream.await?;
138
139            let mut changed_buffers = HashSet::default();
140            let mut applied_edits = 0;
141
142            let log = log.clone();
143
144            while let Some(chunk) = chunks.stream.next().await {
145                let chunk = chunk?;
146
147                let new_actions = parser.parse_chunk(&chunk);
148
149                if let Some((ref log, req_id)) = log {
150                    log.update(&mut cx, |log, cx| {
151                        log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
152                    })
153                    .log_err();
154                }
155
156                for action in new_actions {
157                    let project_path = project.read_with(&cx, |project, cx| {
158                        let worktree_root_name = action
159                            .file_path()
160                            .components()
161                            .next()
162                            .context("Invalid path")?;
163                        let worktree = project
164                            .worktree_for_root_name(
165                                &worktree_root_name.as_os_str().to_string_lossy(),
166                                cx,
167                            )
168                            .context("Directory not found in project")?;
169                        anyhow::Ok(ProjectPath {
170                            worktree_id: worktree.read(cx).id(),
171                            path: Arc::from(
172                                action.file_path().strip_prefix(worktree_root_name).unwrap(),
173                            ),
174                        })
175                    })??;
176
177                    let buffer = project
178                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
179                        .await?;
180
181                    let diff = buffer
182                        .read_with(&cx, |buffer, cx| {
183                            let new_text = match action {
184                                EditAction::Replace { old, new, .. } => {
185                                    // TODO: Replace in background?
186                                    buffer.text().replace(&old, &new)
187                                }
188                                EditAction::Write { content, .. } => content,
189                            };
190
191                            buffer.diff(new_text, cx)
192                        })?
193                        .await;
194
195                    let _clock =
196                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
197
198                    changed_buffers.insert(buffer);
199
200                    applied_edits += 1;
201                }
202            }
203
204            let mut answer = match changed_buffers.len() {
205                0 => "No files were edited.".to_string(),
206                1 => "Successfully edited ".to_string(),
207                _ => "Successfully edited these files:\n\n".to_string(),
208            };
209
210            // Save each buffer once at the end
211            for buffer in changed_buffers {
212                project
213                    .update(&mut cx, |project, cx| {
214                        if let Some(file) = buffer.read(&cx).file() {
215                            let _ = writeln!(&mut answer, "{}", &file.path().display());
216                        }
217
218                        project.save_buffer(buffer, cx)
219                    })?
220                    .await?;
221            }
222
223            let errors = parser.errors();
224
225            if errors.is_empty() {
226                Ok(answer.trim_end().to_string())
227            } else {
228                let error_message = errors
229                    .iter()
230                    .map(|e| e.to_string())
231                    .collect::<Vec<_>>()
232                    .join("\n");
233
234                if applied_edits > 0 {
235                    Err(anyhow!(
236                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
237                        applied_edits,
238                        error_message
239                    ))
240                } else {
241                    Err(anyhow!(error_message))
242                }
243            }
244        })
245    }
246}