1mod edit_action;
2pub mod log;
3
4use anyhow::{anyhow, Context, Result};
5use assistant_tool::{ActionLog, Tool};
6use collections::HashSet;
7use edit_action::{EditAction, EditActionParser};
8use futures::StreamExt;
9use gpui::{App, AsyncApp, Entity, Task};
10use language_model::{
11 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
12};
13use log::{EditToolLog, EditToolRequestId};
14use project::{search::SearchQuery, Project};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use std::fmt::Write;
18use std::sync::Arc;
19use util::paths::PathMatcher;
20use util::ResultExt;
21
// Arguments accepted by the `edit-files` tool, deserialized from the JSON value
// passed to `EditFilesTool::run`.
//
// NOTE: the `///` doc comments on the field below are part of the tool's
// behavior, not just documentation. `JsonSchema` (schemars) turns doc comments
// into schema `description`s, and `EditFilesTool::input_schema` hands that
// schema to the assistant model. Rewording them changes what the model is told.
// Plain `//` comments such as this one are not included in the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller
    /// model, so explain the changes you want that model to make and which
    /// file paths need changing.
    ///
    /// The description should be concise and clear. We will show this
    /// description to the user as well.
    ///
    /// WARNING: When specifying which file paths need changing, you MUST
    /// start each path with one of the project's root directories.
    ///
    /// WARNING: NEVER include code blocks or snippets in edit instructions.
    /// Only provide natural language descriptions of the changes needed! The tool will
    /// reject any instructions that contain code blocks or snippets.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - root-1
    /// - root-2
    ///
    /// <example>
    /// If you want to introduce a new quit function to kill the process, your
    /// instructions should be: "Add a new `quit` function to
    /// `root-1/src/main.rs` to kill the process".
    ///
    /// Notice how the file path starts with root-1. Without that, the path
    /// would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// If you want to change documentation to always start with a capital
    /// letter, your instructions should be: "In `root-2/db.js`,
    /// `root-2/inMemory.js` and `root-2/sql.js`, change all the documentation
    /// to start with a capital letter".
    ///
    /// Notice how we never specify code snippets in the instructions!
    /// </example>
    // This string is copied verbatim into the editor-model request
    // (`EditToolRequest::new`) and, when a global `EditToolLog` exists,
    // recorded there as the request's description.
    pub edit_instructions: String,
}
61
62pub struct EditFilesTool;
63
64impl Tool for EditFilesTool {
65 fn name(&self) -> String {
66 "edit-files".into()
67 }
68
69 fn description(&self) -> String {
70 include_str!("./edit_files_tool/description.md").into()
71 }
72
73 fn input_schema(&self) -> serde_json::Value {
74 let schema = schemars::schema_for!(EditFilesToolInput);
75 serde_json::to_value(&schema).unwrap()
76 }
77
78 fn run(
79 self: Arc<Self>,
80 input: serde_json::Value,
81 messages: &[LanguageModelRequestMessage],
82 project: Entity<Project>,
83 action_log: Entity<ActionLog>,
84 cx: &mut App,
85 ) -> Task<Result<String>> {
86 let input = match serde_json::from_value::<EditFilesToolInput>(input) {
87 Ok(input) => input,
88 Err(err) => return Task::ready(Err(anyhow!(err))),
89 };
90
91 match EditToolLog::try_global(cx) {
92 Some(log) => {
93 let req_id = log.update(cx, |log, cx| {
94 log.new_request(input.edit_instructions.clone(), cx)
95 });
96
97 let task = EditToolRequest::new(
98 input,
99 messages,
100 project,
101 action_log,
102 Some((log.clone(), req_id)),
103 cx,
104 );
105
106 cx.spawn(|mut cx| async move {
107 let result = task.await;
108
109 let str_result = match &result {
110 Ok(out) => Ok(out.clone()),
111 Err(err) => Err(err.to_string()),
112 };
113
114 log.update(&mut cx, |log, cx| {
115 log.set_tool_output(req_id, str_result, cx)
116 })
117 .log_err();
118
119 result
120 })
121 }
122
123 None => EditToolRequest::new(input, messages, project, action_log, None, cx),
124 }
125 }
126}
127
/// State for one in-flight `edit-files` invocation.
///
/// It is built once the editor model's completion stream has started. Each
/// streamed chunk is fed to `parser`, and the parsed actions are applied to
/// project buffers as they arrive.
struct EditToolRequest {
    // Incremental parser turning the editor model's streamed text into edit actions.
    parser: EditActionParser,
    // Buffers that had at least one diff applied; saved once in `finalize`.
    changed_buffers: HashSet<Entity<language::Buffer>>,
    // Replace actions whose search text was not found, reported back in `finalize`.
    bad_searches: Vec<BadSearch>,
    project: Entity<Project>,
    // Notified about `changed_buffers` at the end of the request.
    action_log: Entity<ActionLog>,
    // Optional debug log plus this request's id in it; receives each response chunk.
    tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
136
/// Outcome of turning one edit action into buffer edits.
#[derive(Debug)]
enum DiffResult {
    // A replace action whose search text could not be located in the buffer.
    BadSearch(BadSearch),
    // Edits ready to be applied with `Buffer::apply_diff`.
    Diff(language::Diff),
}
142
/// A replace action that failed because its search text matched nothing.
#[derive(Debug)]
struct BadSearch {
    // Path of the targeted file, as displayed text.
    file_path: String,
    // The text that was searched for; shown to the model in `finalize`.
    search: String,
}
148
149impl EditToolRequest {
150 fn new(
151 input: EditFilesToolInput,
152 messages: &[LanguageModelRequestMessage],
153 project: Entity<Project>,
154 action_log: Entity<ActionLog>,
155 tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
156 cx: &mut App,
157 ) -> Task<Result<String>> {
158 let model_registry = LanguageModelRegistry::read_global(cx);
159 let Some(model) = model_registry.editor_model() else {
160 return Task::ready(Err(anyhow!("No editor model configured")));
161 };
162
163 let mut messages = messages.to_vec();
164 // Remove the last tool use (this run) to prevent an invalid request
165 'outer: for message in messages.iter_mut().rev() {
166 for (index, content) in message.content.iter().enumerate().rev() {
167 match content {
168 MessageContent::ToolUse(_) => {
169 message.content.remove(index);
170 break 'outer;
171 }
172 MessageContent::ToolResult(_) => {
173 // If we find any tool results before a tool use, the request is already valid
174 break 'outer;
175 }
176 MessageContent::Text(_) | MessageContent::Image(_) => {}
177 }
178 }
179 }
180
181 messages.push(LanguageModelRequestMessage {
182 role: Role::User,
183 content: vec![
184 include_str!("./edit_files_tool/edit_prompt.md").into(),
185 input.edit_instructions.into(),
186 ],
187 cache: false,
188 });
189
190 cx.spawn(|mut cx| async move {
191 let llm_request = LanguageModelRequest {
192 messages,
193 tools: vec![],
194 stop: vec![],
195 temperature: Some(0.0),
196 };
197
198 let stream = model.stream_completion_text(llm_request, &cx);
199 let mut chunks = stream.await?;
200
201 let mut request = Self {
202 parser: EditActionParser::new(),
203 changed_buffers: HashSet::default(),
204 bad_searches: Vec::new(),
205 action_log,
206 project,
207 tool_log,
208 };
209
210 while let Some(chunk) = chunks.stream.next().await {
211 request.process_response_chunk(&chunk?, &mut cx).await?;
212 }
213
214 request.finalize(&mut cx).await
215 })
216 }
217
218 async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
219 let new_actions = self.parser.parse_chunk(chunk);
220
221 if let Some((ref log, req_id)) = self.tool_log {
222 log.update(cx, |log, cx| {
223 log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
224 })
225 .log_err();
226 }
227
228 for action in new_actions {
229 self.apply_action(action, cx).await?;
230 }
231
232 Ok(())
233 }
234
235 async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
236 let project_path = self.project.read_with(cx, |project, cx| {
237 project
238 .find_project_path(action.file_path(), cx)
239 .context("Path not found in project")
240 })??;
241
242 let buffer = self
243 .project
244 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
245 .await?;
246
247 let result = match action {
248 EditAction::Replace {
249 old,
250 new,
251 file_path,
252 } => {
253 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
254
255 cx.background_executor()
256 .spawn(Self::replace_diff(old, new, file_path, snapshot))
257 .await
258 }
259 EditAction::Write { content, .. } => Ok(DiffResult::Diff(
260 buffer
261 .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
262 .await,
263 )),
264 }?;
265
266 match result {
267 DiffResult::BadSearch(invalid_replace) => {
268 self.bad_searches.push(invalid_replace);
269 }
270 DiffResult::Diff(diff) => {
271 let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
272
273 self.changed_buffers.insert(buffer);
274 }
275 }
276
277 Ok(())
278 }
279
280 async fn replace_diff(
281 old: String,
282 new: String,
283 file_path: std::path::PathBuf,
284 snapshot: language::BufferSnapshot,
285 ) -> Result<DiffResult> {
286 let query = SearchQuery::text(
287 old.clone(),
288 false,
289 true,
290 true,
291 PathMatcher::new(&[])?,
292 PathMatcher::new(&[])?,
293 None,
294 )?;
295
296 let matches = query.search(&snapshot, None).await;
297
298 if matches.is_empty() {
299 return Ok(DiffResult::BadSearch(BadSearch {
300 search: new.clone(),
301 file_path: file_path.display().to_string(),
302 }));
303 }
304
305 let edit_range = matches[0].clone();
306 let diff = language::text_diff(&old, &new);
307
308 let edits = diff
309 .into_iter()
310 .map(|(old_range, text)| {
311 let start = edit_range.start + old_range.start;
312 let end = edit_range.start + old_range.end;
313 (start..end, text)
314 })
315 .collect::<Vec<_>>();
316
317 let diff = language::Diff {
318 base_version: snapshot.version().clone(),
319 line_ending: snapshot.line_ending(),
320 edits,
321 };
322
323 anyhow::Ok(DiffResult::Diff(diff))
324 }
325
326 async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
327 let mut answer = match self.changed_buffers.len() {
328 0 => "No files were edited.".to_string(),
329 1 => "Successfully edited ".to_string(),
330 _ => "Successfully edited these files:\n\n".to_string(),
331 };
332
333 // Save each buffer once at the end
334 for buffer in &self.changed_buffers {
335 let (path, save_task) = self.project.update(cx, |project, cx| {
336 let path = buffer
337 .read(cx)
338 .file()
339 .map(|file| file.path().display().to_string());
340
341 let task = project.save_buffer(buffer.clone(), cx);
342
343 (path, task)
344 })?;
345
346 save_task.await?;
347
348 if let Some(path) = path {
349 writeln!(&mut answer, "{}", path)?;
350 }
351 }
352
353 self.action_log
354 .update(cx, |log, cx| {
355 log.notify_buffers_changed(self.changed_buffers, cx)
356 })
357 .log_err();
358
359 let errors = self.parser.errors();
360
361 if errors.is_empty() && self.bad_searches.is_empty() {
362 let answer = answer.trim_end().to_string();
363 Ok(answer)
364 } else {
365 if !self.bad_searches.is_empty() {
366 writeln!(
367 &mut answer,
368 "\nThese searches failed because they didn't match any strings:"
369 )?;
370
371 for replace in self.bad_searches {
372 writeln!(
373 &mut answer,
374 "- '{}' does not appear in `{}`",
375 replace.search.replace("\r", "\\r").replace("\n", "\\n"),
376 replace.file_path
377 )?;
378 }
379
380 writeln!(&mut answer, "Make sure to use exact searches.")?;
381 }
382
383 if !errors.is_empty() {
384 writeln!(
385 &mut answer,
386 "\nThese SEARCH/REPLACE blocks failed to parse:"
387 )?;
388
389 for error in errors {
390 writeln!(&mut answer, "- {}", error)?;
391 }
392 }
393
394 writeln!(
395 &mut answer,
396 "\nYou can fix errors by running the tool again. You can include instructions,\
397 but errors are part of the conversation so you don't need to repeat them."
398 )?;
399
400 Err(anyhow!(answer.trim_end().to_string()))
401 }
402 }
403}