1mod edit_action;
2
3use anyhow::{anyhow, Context, Result};
4use assistant_tool::Tool;
5use collections::HashSet;
6use edit_action::{EditAction, EditActionParser};
7use futures::StreamExt;
8use gpui::{App, Entity, Task};
9use language_model::{
10 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
11};
12use project::{Project, ProjectPath};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use std::fmt::Write;
16use std::sync::Arc;
17
// Arguments accepted by `EditFilesTool`; `run` deserializes the tool call's JSON into
// this struct before doing anything else.
//
// NOTE: this header is intentionally a plain `//` comment. The struct derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated schema as
// `description` text, so the `///` text on the field below is effectively part of the
// prompt the calling model sees. Treat edits to it as behavior changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller model,
    /// so explain the edits you want that model to make and to which files need changing.
    /// The description should be concise and clear. We will show this description to the user
    /// as well.
    ///
    /// <example>
    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
    /// </example>
    ///
    /// <example>
    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
    /// </example>
    pub edit_instructions: String,
}
34
/// Assistant tool that turns high-level edit instructions into file edits.
///
/// The tool itself is stateless: it forwards the conversation plus the instructions to
/// the registry's editor model, parses the edit actions streamed back, and applies them
/// to buffers of the given project (see the `Tool` impl below).
pub struct EditFilesTool;
36
37impl Tool for EditFilesTool {
38 fn name(&self) -> String {
39 "edit-files".into()
40 }
41
42 fn description(&self) -> String {
43 include_str!("./edit_files_tool/description.md").into()
44 }
45
46 fn input_schema(&self) -> serde_json::Value {
47 let schema = schemars::schema_for!(EditFilesToolInput);
48 serde_json::to_value(&schema).unwrap()
49 }
50
51 fn run(
52 self: Arc<Self>,
53 input: serde_json::Value,
54 messages: &[LanguageModelRequestMessage],
55 project: Entity<Project>,
56 cx: &mut App,
57 ) -> Task<Result<String>> {
58 let input = match serde_json::from_value::<EditFilesToolInput>(input) {
59 Ok(input) => input,
60 Err(err) => return Task::ready(Err(anyhow!(err))),
61 };
62
63 let model_registry = LanguageModelRegistry::read_global(cx);
64 let Some(model) = model_registry.editor_model() else {
65 return Task::ready(Err(anyhow!("No editor model configured")));
66 };
67
68 let mut messages = messages.to_vec();
69 if let Some(last_message) = messages.last_mut() {
70 // Strip out tool use from the last message because we're in the middle of executing a tool call.
71 last_message
72 .content
73 .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
74 }
75 messages.push(LanguageModelRequestMessage {
76 role: Role::User,
77 content: vec![
78 include_str!("./edit_files_tool/edit_prompt.md").into(),
79 input.edit_instructions.into(),
80 ],
81 cache: false,
82 });
83
84 cx.spawn(|mut cx| async move {
85 let request = LanguageModelRequest {
86 messages,
87 tools: vec![],
88 stop: vec![],
89 temperature: None,
90 };
91
92 let mut parser = EditActionParser::new();
93
94 let stream = model.stream_completion_text(request, &cx);
95 let mut chunks = stream.await?;
96
97 let mut changed_buffers = HashSet::default();
98 let mut applied_edits = 0;
99
100 while let Some(chunk) = chunks.stream.next().await {
101 for action in parser.parse_chunk(&chunk?) {
102 let project_path = project.read_with(&cx, |project, cx| {
103 let worktree_root_name = action
104 .file_path()
105 .components()
106 .next()
107 .context("Invalid path")?;
108 let worktree = project
109 .worktree_for_root_name(
110 &worktree_root_name.as_os_str().to_string_lossy(),
111 cx,
112 )
113 .context("Directory not found in project")?;
114 anyhow::Ok(ProjectPath {
115 worktree_id: worktree.read(cx).id(),
116 path: Arc::from(
117 action.file_path().strip_prefix(worktree_root_name).unwrap(),
118 ),
119 })
120 })??;
121
122 let buffer = project
123 .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
124 .await?;
125
126 let diff = buffer
127 .read_with(&cx, |buffer, cx| {
128 let new_text = match action {
129 EditAction::Replace { old, new, .. } => {
130 // TODO: Replace in background?
131 buffer.text().replace(&old, &new)
132 }
133 EditAction::Write { content, .. } => content,
134 };
135
136 buffer.diff(new_text, cx)
137 })?
138 .await;
139
140 let _clock =
141 buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
142
143 changed_buffers.insert(buffer);
144
145 applied_edits += 1;
146 }
147 }
148
149 let mut answer = match changed_buffers.len() {
150 0 => "No files were edited.".to_string(),
151 1 => "Successfully edited ".to_string(),
152 _ => "Successfully edited these files:\n\n".to_string(),
153 };
154
155 // Save each buffer once at the end
156 for buffer in changed_buffers {
157 project
158 .update(&mut cx, |project, cx| {
159 if let Some(file) = buffer.read(&cx).file() {
160 let _ = write!(&mut answer, "{}\n\n", &file.path().display());
161 }
162
163 project.save_buffer(buffer, cx)
164 })?
165 .await?;
166 }
167
168 let errors = parser.errors();
169
170 if errors.is_empty() {
171 Ok(answer.trim_end().to_string())
172 } else {
173 let error_message = errors
174 .iter()
175 .map(|e| e.to_string())
176 .collect::<Vec<_>>()
177 .join("\n");
178
179 if applied_edits > 0 {
180 Err(anyhow!(
181 "Applied {} edit(s), but some blocks failed to parse:\n{}",
182 applied_edits,
183 error_message
184 ))
185 } else {
186 Err(anyhow!(error_message))
187 }
188 }
189 })
190 }
191}