edit_files_tool.rs

  1mod edit_action;
  2
  3use anyhow::{anyhow, Context, Result};
  4use assistant_tool::Tool;
  5use collections::HashSet;
  6use edit_action::{EditAction, EditActionParser};
  7use futures::StreamExt;
  8use gpui::{App, Entity, Task};
  9use language_model::{
 10    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 11};
 12use project::{Project, ProjectPath};
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use std::sync::Arc;
 16
 17#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 18pub struct EditFilesToolInput {
 19    /// High-level edit instructions. These will be interpreted by a smaller model,
 20    /// so explain the edits you want that model to make and to which files need changing.
 21    /// The description should be concise and clear. We will show this description to the user
 22    /// as well.
 23    ///
 24    /// <example>
 25    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
 26    /// </example>
 27    ///
 28    /// <example>
 29    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
 30    /// </example>
 31    pub edit_instructions: String,
 32}
 33
 34pub struct EditFilesTool;
 35
 36impl Tool for EditFilesTool {
 37    fn name(&self) -> String {
 38        "edit-files".into()
 39    }
 40
 41    fn description(&self) -> String {
 42        include_str!("./edit_files_tool/description.md").into()
 43    }
 44
 45    fn input_schema(&self) -> serde_json::Value {
 46        let schema = schemars::schema_for!(EditFilesToolInput);
 47        serde_json::to_value(&schema).unwrap()
 48    }
 49
 50    fn run(
 51        self: Arc<Self>,
 52        input: serde_json::Value,
 53        messages: &[LanguageModelRequestMessage],
 54        project: Entity<Project>,
 55        cx: &mut App,
 56    ) -> Task<Result<String>> {
 57        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 58            Ok(input) => input,
 59            Err(err) => return Task::ready(Err(anyhow!(err))),
 60        };
 61
 62        let model_registry = LanguageModelRegistry::read_global(cx);
 63        let Some(model) = model_registry.editor_model() else {
 64            return Task::ready(Err(anyhow!("No editor model configured")));
 65        };
 66
 67        let mut messages = messages.to_vec();
 68        if let Some(last_message) = messages.last_mut() {
 69            // Strip out tool use from the last message because we're in the middle of executing a tool call.
 70            last_message
 71                .content
 72                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
 73        }
 74        messages.push(LanguageModelRequestMessage {
 75            role: Role::User,
 76            content: vec![
 77                include_str!("./edit_files_tool/edit_prompt.md").into(),
 78                input.edit_instructions.into(),
 79            ],
 80            cache: false,
 81        });
 82
 83        cx.spawn(|mut cx| async move {
 84            let request = LanguageModelRequest {
 85                messages,
 86                tools: vec![],
 87                stop: vec![],
 88                temperature: None,
 89            };
 90
 91            let mut parser = EditActionParser::new();
 92
 93            let stream = model.stream_completion_text(request, &cx);
 94            let mut chunks = stream.await?;
 95
 96            let mut changed_buffers = HashSet::default();
 97            let mut applied_edits = 0;
 98
 99            while let Some(chunk) = chunks.stream.next().await {
100                for action in parser.parse_chunk(&chunk?) {
101                    let project_path = project.read_with(&cx, |project, cx| {
102                        let worktree_root_name = action
103                            .file_path()
104                            .components()
105                            .next()
106                            .context("Invalid path")?;
107                        let worktree = project
108                            .worktree_for_root_name(
109                                &worktree_root_name.as_os_str().to_string_lossy(),
110                                cx,
111                            )
112                            .context("Directory not found in project")?;
113                        anyhow::Ok(ProjectPath {
114                            worktree_id: worktree.read(cx).id(),
115                            path: Arc::from(
116                                action.file_path().strip_prefix(worktree_root_name).unwrap(),
117                            ),
118                        })
119                    })??;
120
121                    let buffer = project
122                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
123                        .await?;
124
125                    let diff = buffer
126                        .read_with(&cx, |buffer, cx| {
127                            let new_text = match action {
128                                EditAction::Replace { old, new, .. } => {
129                                    // TODO: Replace in background?
130                                    buffer.text().replace(&old, &new)
131                                }
132                                EditAction::Write { content, .. } => content,
133                            };
134
135                            buffer.diff(new_text, cx)
136                        })?
137                        .await;
138
139                    let _clock =
140                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
141
142                    changed_buffers.insert(buffer);
143
144                    applied_edits += 1;
145                }
146            }
147
148            // Save each buffer once at the end
149            for buffer in changed_buffers {
150                project
151                    .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
152                    .await?;
153            }
154
155            let errors = parser.errors();
156
157            if errors.is_empty() {
158                Ok("Successfully applied all edits".into())
159            } else {
160                let error_message = errors
161                    .iter()
162                    .map(|e| e.to_string())
163                    .collect::<Vec<_>>()
164                    .join("\n");
165
166                if applied_edits > 0 {
167                    Err(anyhow!(
168                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
169                        applied_edits,
170                        error_message
171                    ))
172                } else {
173                    Err(anyhow!(error_message))
174                }
175            }
176        })
177    }
178}