edit_files_tool.rs

  1mod edit_action;
  2
  3use collections::HashSet;
  4use std::{path::Path, sync::Arc};
  5
  6use anyhow::{anyhow, Result};
  7use assistant_tool::Tool;
  8use edit_action::{EditAction, EditActionParser};
  9use futures::StreamExt;
 10use gpui::{App, Entity, Task};
 11use language_model::{
 12    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 13};
 14use project::{Project, ProjectPath, WorktreeId};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17
 18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 19pub struct EditFilesToolInput {
 20    /// The ID of the worktree in which the files reside.
 21    pub worktree_id: usize,
 22    /// Instruct how to modify the files.
 23    pub edit_instructions: String,
 24}
 25
 26pub struct EditFilesTool;
 27
 28impl Tool for EditFilesTool {
 29    fn name(&self) -> String {
 30        "edit-files".into()
 31    }
 32
 33    fn description(&self) -> String {
 34        include_str!("./edit_files_tool/description.md").into()
 35    }
 36
 37    fn input_schema(&self) -> serde_json::Value {
 38        let schema = schemars::schema_for!(EditFilesToolInput);
 39        serde_json::to_value(&schema).unwrap()
 40    }
 41
 42    fn run(
 43        self: Arc<Self>,
 44        input: serde_json::Value,
 45        messages: &[LanguageModelRequestMessage],
 46        project: Entity<Project>,
 47        cx: &mut App,
 48    ) -> Task<Result<String>> {
 49        let input = match serde_json::from_value::<EditFilesToolInput>(input) {
 50            Ok(input) => input,
 51            Err(err) => return Task::ready(Err(anyhow!(err))),
 52        };
 53
 54        let model_registry = LanguageModelRegistry::read_global(cx);
 55        let Some(model) = model_registry.editor_model() else {
 56            return Task::ready(Err(anyhow!("No editor model configured")));
 57        };
 58
 59        let mut messages = messages.to_vec();
 60        if let Some(last_message) = messages.last_mut() {
 61            // Strip out tool use from the last message because we're in the middle of executing a tool call.
 62            last_message
 63                .content
 64                .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
 65        }
 66        messages.push(LanguageModelRequestMessage {
 67            role: Role::User,
 68            content: vec![
 69                include_str!("./edit_files_tool/edit_prompt.md").into(),
 70                input.edit_instructions.into(),
 71            ],
 72            cache: false,
 73        });
 74
 75        cx.spawn(|mut cx| async move {
 76            let request = LanguageModelRequest {
 77                messages,
 78                tools: vec![],
 79                stop: vec![],
 80                temperature: None,
 81            };
 82
 83            let mut parser = EditActionParser::new();
 84
 85            let stream = model.stream_completion_text(request, &cx);
 86            let mut chunks = stream.await?;
 87
 88            let mut changed_buffers = HashSet::default();
 89            let mut applied_edits = 0;
 90
 91            while let Some(chunk) = chunks.stream.next().await {
 92                for action in parser.parse_chunk(&chunk?) {
 93                    let project_path = ProjectPath {
 94                        worktree_id: WorktreeId::from_usize(input.worktree_id),
 95                        path: Path::new(action.file_path()).into(),
 96                    };
 97
 98                    let buffer = project
 99                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
100                        .await?;
101
102                    let diff = buffer
103                        .read_with(&cx, |buffer, cx| {
104                            let new_text = match action {
105                                EditAction::Replace { old, new, .. } => {
106                                    // TODO: Replace in background?
107                                    buffer.text().replace(&old, &new)
108                                }
109                                EditAction::Write { content, .. } => content,
110                            };
111
112                            buffer.diff(new_text, cx)
113                        })?
114                        .await;
115
116                    let _clock =
117                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
118
119                    changed_buffers.insert(buffer);
120
121                    applied_edits += 1;
122                }
123            }
124
125            // Save each buffer once at the end
126            for buffer in changed_buffers {
127                project
128                    .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
129                    .await?;
130            }
131
132            let errors = parser.errors();
133
134            if errors.is_empty() {
135                Ok("Successfully applied all edits".into())
136            } else {
137                let error_message = errors
138                    .iter()
139                    .map(|e| e.to_string())
140                    .collect::<Vec<_>>()
141                    .join("\n");
142
143                if applied_edits > 0 {
144                    Err(anyhow!(
145                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
146                        applied_edits,
147                        error_message
148                    ))
149                } else {
150                    Err(anyhow!(error_message))
151                }
152            }
153        })
154    }
155}