1mod edit_action;
2pub mod log;
3
4use anyhow::{anyhow, Context, Result};
5use assistant_tool::Tool;
6use collections::HashSet;
7use edit_action::{EditAction, EditActionParser};
8use futures::StreamExt;
9use gpui::{App, Entity, Task};
10use language_model::{
11 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
12};
13use log::{EditToolLog, EditToolRequestId};
14use project::{Project, ProjectPath};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use std::fmt::Write;
18use std::sync::Arc;
19use util::ResultExt;
20
// Input payload for the `edit-files` tool: `Tool::run` deserializes its JSON `input`
// into this struct, and `Tool::input_schema` generates the tool's schema from it.
//
// NOTE(review): because this type derives `JsonSchema`, the `///` field docs below are
// presumably emitted as schema descriptions (i.e. they are visible to the calling
// model) — treat them as prompt text and confirm before rewording.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller model,
    /// so explain the edits you want that model to make and to which files need changing.
    /// The description should be concise and clear. We will show this description to the user
    /// as well.
    ///
    /// <example>
    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
    /// </example>
    ///
    /// <example>
    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
    /// </example>
    pub edit_instructions: String,
}
37
38pub struct EditFilesTool;
39
40impl Tool for EditFilesTool {
41 fn name(&self) -> String {
42 "edit-files".into()
43 }
44
45 fn description(&self) -> String {
46 include_str!("./edit_files_tool/description.md").into()
47 }
48
49 fn input_schema(&self) -> serde_json::Value {
50 let schema = schemars::schema_for!(EditFilesToolInput);
51 serde_json::to_value(&schema).unwrap()
52 }
53
54 fn run(
55 self: Arc<Self>,
56 input: serde_json::Value,
57 messages: &[LanguageModelRequestMessage],
58 project: Entity<Project>,
59 cx: &mut App,
60 ) -> Task<Result<String>> {
61 let input = match serde_json::from_value::<EditFilesToolInput>(input) {
62 Ok(input) => input,
63 Err(err) => return Task::ready(Err(anyhow!(err))),
64 };
65
66 match EditToolLog::try_global(cx) {
67 Some(log) => {
68 let req_id = log.update(cx, |log, cx| {
69 log.new_request(input.edit_instructions.clone(), cx)
70 });
71
72 let task =
73 EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx);
74
75 cx.spawn(|mut cx| async move {
76 let result = task.await;
77
78 let str_result = match &result {
79 Ok(out) => Ok(out.clone()),
80 Err(err) => Err(err.to_string()),
81 };
82
83 log.update(&mut cx, |log, cx| {
84 log.set_tool_output(req_id, str_result, cx)
85 })
86 .log_err();
87
88 result
89 })
90 }
91
92 None => EditFilesTool::run(input, messages, project, None, cx),
93 }
94 }
95}
96
97impl EditFilesTool {
98 fn run(
99 input: EditFilesToolInput,
100 messages: &[LanguageModelRequestMessage],
101 project: Entity<Project>,
102 log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
103 cx: &mut App,
104 ) -> Task<Result<String>> {
105 let model_registry = LanguageModelRegistry::read_global(cx);
106 let Some(model) = model_registry.editor_model() else {
107 return Task::ready(Err(anyhow!("No editor model configured")));
108 };
109
110 let mut messages = messages.to_vec();
111 if let Some(last_message) = messages.last_mut() {
112 // Strip out tool use from the last message because we're in the middle of executing a tool call.
113 last_message
114 .content
115 .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
116 }
117 messages.push(LanguageModelRequestMessage {
118 role: Role::User,
119 content: vec![
120 include_str!("./edit_files_tool/edit_prompt.md").into(),
121 input.edit_instructions.into(),
122 ],
123 cache: false,
124 });
125
126 cx.spawn(|mut cx| async move {
127 let request = LanguageModelRequest {
128 messages,
129 tools: vec![],
130 stop: vec![],
131 temperature: None,
132 };
133
134 let mut parser = EditActionParser::new();
135
136 let stream = model.stream_completion_text(request, &cx);
137 let mut chunks = stream.await?;
138
139 let mut changed_buffers = HashSet::default();
140 let mut applied_edits = 0;
141
142 let log = log.clone();
143
144 while let Some(chunk) = chunks.stream.next().await {
145 let chunk = chunk?;
146
147 let new_actions = parser.parse_chunk(&chunk);
148
149 if let Some((ref log, req_id)) = log {
150 log.update(&mut cx, |log, cx| {
151 log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
152 })
153 .log_err();
154 }
155
156 for action in new_actions {
157 let project_path = project.read_with(&cx, |project, cx| {
158 let worktree_root_name = action
159 .file_path()
160 .components()
161 .next()
162 .context("Invalid path")?;
163 let worktree = project
164 .worktree_for_root_name(
165 &worktree_root_name.as_os_str().to_string_lossy(),
166 cx,
167 )
168 .context("Directory not found in project")?;
169 anyhow::Ok(ProjectPath {
170 worktree_id: worktree.read(cx).id(),
171 path: Arc::from(
172 action.file_path().strip_prefix(worktree_root_name).unwrap(),
173 ),
174 })
175 })??;
176
177 let buffer = project
178 .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
179 .await?;
180
181 let diff = buffer
182 .read_with(&cx, |buffer, cx| {
183 let new_text = match action {
184 EditAction::Replace { old, new, .. } => {
185 // TODO: Replace in background?
186 buffer.text().replace(&old, &new)
187 }
188 EditAction::Write { content, .. } => content,
189 };
190
191 buffer.diff(new_text, cx)
192 })?
193 .await;
194
195 let _clock =
196 buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
197
198 changed_buffers.insert(buffer);
199
200 applied_edits += 1;
201 }
202 }
203
204 let mut answer = match changed_buffers.len() {
205 0 => "No files were edited.".to_string(),
206 1 => "Successfully edited ".to_string(),
207 _ => "Successfully edited these files:\n\n".to_string(),
208 };
209
210 // Save each buffer once at the end
211 for buffer in changed_buffers {
212 project
213 .update(&mut cx, |project, cx| {
214 if let Some(file) = buffer.read(&cx).file() {
215 let _ = writeln!(&mut answer, "{}", &file.path().display());
216 }
217
218 project.save_buffer(buffer, cx)
219 })?
220 .await?;
221 }
222
223 let errors = parser.errors();
224
225 if errors.is_empty() {
226 Ok(answer.trim_end().to_string())
227 } else {
228 let error_message = errors
229 .iter()
230 .map(|e| e.to_string())
231 .collect::<Vec<_>>()
232 .join("\n");
233
234 if applied_edits > 0 {
235 Err(anyhow!(
236 "Applied {} edit(s), but some blocks failed to parse:\n{}",
237 applied_edits,
238 error_message
239 ))
240 } else {
241 Err(anyhow!(error_message))
242 }
243 }
244 })
245 }
246}