1mod edit_action;
2
3use anyhow::{anyhow, Context, Result};
4use assistant_tool::Tool;
5use collections::HashSet;
6use edit_action::{EditAction, EditActionParser};
7use futures::StreamExt;
8use gpui::{App, Entity, Task};
9use language_model::{
10 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
11};
12use project::{Project, ProjectPath};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
// Input payload for `EditFilesTool`; `run` deserializes it from the tool call's JSON value.
//
// NOTE(review): plain `//` comments are used here on purpose. With the `JsonSchema` derive,
// `///` doc comments are emitted as schema descriptions, so the field documentation below is
// part of what `input_schema` returns — reword it with care.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFilesToolInput {
    /// High-level edit instructions. These will be interpreted by a smaller model,
    /// so explain the edits you want that model to make and to which files need changing.
    /// The description should be concise and clear. We will show this description to the user
    /// as well.
    ///
    /// <example>
    /// If you want to rename a function you can say "Rename the function 'foo' to 'bar'".
    /// </example>
    ///
    /// <example>
    /// If you want to add a new function you can say "Add a new method to the `User` struct that prints the age".
    /// </example>
    pub edit_instructions: String,
}
33
/// Tool named `"edit-files"`: forwards edit instructions to the registry's editor model and
/// applies the edit actions parsed from its streamed reply to buffers of the project.
pub struct EditFilesTool;
35
36impl Tool for EditFilesTool {
37 fn name(&self) -> String {
38 "edit-files".into()
39 }
40
41 fn description(&self) -> String {
42 include_str!("./edit_files_tool/description.md").into()
43 }
44
45 fn input_schema(&self) -> serde_json::Value {
46 let schema = schemars::schema_for!(EditFilesToolInput);
47 serde_json::to_value(&schema).unwrap()
48 }
49
50 fn run(
51 self: Arc<Self>,
52 input: serde_json::Value,
53 messages: &[LanguageModelRequestMessage],
54 project: Entity<Project>,
55 cx: &mut App,
56 ) -> Task<Result<String>> {
57 let input = match serde_json::from_value::<EditFilesToolInput>(input) {
58 Ok(input) => input,
59 Err(err) => return Task::ready(Err(anyhow!(err))),
60 };
61
62 let model_registry = LanguageModelRegistry::read_global(cx);
63 let Some(model) = model_registry.editor_model() else {
64 return Task::ready(Err(anyhow!("No editor model configured")));
65 };
66
67 let mut messages = messages.to_vec();
68 if let Some(last_message) = messages.last_mut() {
69 // Strip out tool use from the last message because we're in the middle of executing a tool call.
70 last_message
71 .content
72 .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
73 }
74 messages.push(LanguageModelRequestMessage {
75 role: Role::User,
76 content: vec![
77 include_str!("./edit_files_tool/edit_prompt.md").into(),
78 input.edit_instructions.into(),
79 ],
80 cache: false,
81 });
82
83 cx.spawn(|mut cx| async move {
84 let request = LanguageModelRequest {
85 messages,
86 tools: vec![],
87 stop: vec![],
88 temperature: None,
89 };
90
91 let mut parser = EditActionParser::new();
92
93 let stream = model.stream_completion_text(request, &cx);
94 let mut chunks = stream.await?;
95
96 let mut changed_buffers = HashSet::default();
97 let mut applied_edits = 0;
98
99 while let Some(chunk) = chunks.stream.next().await {
100 for action in parser.parse_chunk(&chunk?) {
101 let project_path = project.read_with(&cx, |project, cx| {
102 let worktree_root_name = action
103 .file_path()
104 .components()
105 .next()
106 .context("Invalid path")?;
107 let worktree = project
108 .worktree_for_root_name(
109 &worktree_root_name.as_os_str().to_string_lossy(),
110 cx,
111 )
112 .context("Directory not found in project")?;
113 anyhow::Ok(ProjectPath {
114 worktree_id: worktree.read(cx).id(),
115 path: Arc::from(
116 action.file_path().strip_prefix(worktree_root_name).unwrap(),
117 ),
118 })
119 })??;
120
121 let buffer = project
122 .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
123 .await?;
124
125 let diff = buffer
126 .read_with(&cx, |buffer, cx| {
127 let new_text = match action {
128 EditAction::Replace { old, new, .. } => {
129 // TODO: Replace in background?
130 buffer.text().replace(&old, &new)
131 }
132 EditAction::Write { content, .. } => content,
133 };
134
135 buffer.diff(new_text, cx)
136 })?
137 .await;
138
139 let _clock =
140 buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
141
142 changed_buffers.insert(buffer);
143
144 applied_edits += 1;
145 }
146 }
147
148 // Save each buffer once at the end
149 for buffer in changed_buffers {
150 project
151 .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
152 .await?;
153 }
154
155 let errors = parser.errors();
156
157 if errors.is_empty() {
158 Ok("Successfully applied all edits".into())
159 } else {
160 let error_message = errors
161 .iter()
162 .map(|e| e.to_string())
163 .collect::<Vec<_>>()
164 .join("\n");
165
166 if applied_edits > 0 {
167 Err(anyhow!(
168 "Applied {} edit(s), but some blocks failed to parse:\n{}",
169 applied_edits,
170 error_message
171 ))
172 } else {
173 Err(anyhow!(error_message))
174 }
175 }
176 })
177 }
178}