1mod edit_action;
2
3use collections::HashSet;
4use std::{path::Path, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use assistant_tool::Tool;
8use edit_action::{EditAction, EditActionParser};
9use futures::StreamExt;
10use gpui::{App, Entity, Task};
11use language_model::{
12 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
13};
14use project::{Project, ProjectPath, WorktreeId};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
19pub struct EditFilesToolInput {
20 /// The ID of the worktree in which the files reside.
21 pub worktree_id: usize,
22 /// Instruct how to modify the files.
23 pub edit_instructions: String,
24}
25
26pub struct EditFilesTool;
27
28impl Tool for EditFilesTool {
29 fn name(&self) -> String {
30 "edit-files".into()
31 }
32
33 fn description(&self) -> String {
34 include_str!("./edit_files_tool/description.md").into()
35 }
36
37 fn input_schema(&self) -> serde_json::Value {
38 let schema = schemars::schema_for!(EditFilesToolInput);
39 serde_json::to_value(&schema).unwrap()
40 }
41
42 fn run(
43 self: Arc<Self>,
44 input: serde_json::Value,
45 messages: &[LanguageModelRequestMessage],
46 project: Entity<Project>,
47 cx: &mut App,
48 ) -> Task<Result<String>> {
49 let input = match serde_json::from_value::<EditFilesToolInput>(input) {
50 Ok(input) => input,
51 Err(err) => return Task::ready(Err(anyhow!(err))),
52 };
53
54 let model_registry = LanguageModelRegistry::read_global(cx);
55 let Some(model) = model_registry.editor_model() else {
56 return Task::ready(Err(anyhow!("No editor model configured")));
57 };
58
59 let mut messages = messages.to_vec();
60 if let Some(last_message) = messages.last_mut() {
61 // Strip out tool use from the last message because we're in the middle of executing a tool call.
62 last_message
63 .content
64 .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
65 }
66 messages.push(LanguageModelRequestMessage {
67 role: Role::User,
68 content: vec![
69 include_str!("./edit_files_tool/edit_prompt.md").into(),
70 input.edit_instructions.into(),
71 ],
72 cache: false,
73 });
74
75 cx.spawn(|mut cx| async move {
76 let request = LanguageModelRequest {
77 messages,
78 tools: vec![],
79 stop: vec![],
80 temperature: None,
81 };
82
83 let mut parser = EditActionParser::new();
84
85 let stream = model.stream_completion_text(request, &cx);
86 let mut chunks = stream.await?;
87
88 let mut changed_buffers = HashSet::default();
89 let mut applied_edits = 0;
90
91 while let Some(chunk) = chunks.stream.next().await {
92 for action in parser.parse_chunk(&chunk?) {
93 let project_path = ProjectPath {
94 worktree_id: WorktreeId::from_usize(input.worktree_id),
95 path: Path::new(action.file_path()).into(),
96 };
97
98 let buffer = project
99 .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
100 .await?;
101
102 let diff = buffer
103 .read_with(&cx, |buffer, cx| {
104 let new_text = match action {
105 EditAction::Replace { old, new, .. } => {
106 // TODO: Replace in background?
107 buffer.text().replace(&old, &new)
108 }
109 EditAction::Write { content, .. } => content,
110 };
111
112 buffer.diff(new_text, cx)
113 })?
114 .await;
115
116 let _clock =
117 buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
118
119 changed_buffers.insert(buffer);
120
121 applied_edits += 1;
122 }
123 }
124
125 // Save each buffer once at the end
126 for buffer in changed_buffers {
127 project
128 .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
129 .await?;
130 }
131
132 let errors = parser.errors();
133
134 if errors.is_empty() {
135 Ok("Successfully applied all edits".into())
136 } else {
137 let error_message = errors
138 .iter()
139 .map(|e| e.to_string())
140 .collect::<Vec<_>>()
141 .join("\n");
142
143 if applied_edits > 0 {
144 Err(anyhow!(
145 "Applied {} edit(s), but some blocks failed to parse:\n{}",
146 applied_edits,
147 error_message
148 ))
149 } else {
150 Err(anyhow!(error_message))
151 }
152 }
153 })
154 }
155}