1use collections::HashMap;
2use editor::{Editor, ToOffset, ToPoint};
3use futures::{channel::mpsc, SinkExt, StreamExt};
4use gpui::{AppContext, Task, ViewHandle};
5use language::{Point, Rope};
6use std::{cmp, env, fmt::Write};
7use util::TryFutureExt;
8
9use crate::{
10 stream_completion,
11 streaming_diff::{Hunk, StreamingDiff},
12 OpenAIRequest, RequestMessage, Role,
13};
14
/// App-wide state for assistant-driven refactorings. It is installed lazily as a
/// gpui global the first time [`RefactoringAssistant::update`] is called.
pub struct RefactoringAssistant {
    /// The task streaming the current refactoring into each editor, keyed by the
    /// editor's view id. Starting a new refactoring in an editor replaces its entry.
    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
}
18
19impl RefactoringAssistant {
20 fn new() -> Self {
21 Self {
22 pending_edits_by_editor: Default::default(),
23 }
24 }
25
26 pub fn update<F, T>(cx: &mut AppContext, f: F) -> T
27 where
28 F: FnOnce(&mut Self, &mut AppContext) -> T,
29 {
30 if !cx.has_global::<Self>() {
31 cx.set_global(Self::new());
32 }
33
34 cx.update_global(f)
35 }
36
37 pub fn refactor(
38 &mut self,
39 editor: &ViewHandle<Editor>,
40 user_prompt: &str,
41 cx: &mut AppContext,
42 ) {
43 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
44 api_key
45 } else {
46 // TODO: ensure the API key is present by going through the assistant panel's flow.
47 return;
48 };
49
50 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
51 let selection = editor.read(cx).selections.newest_anchor().clone();
52 let selected_text = snapshot
53 .text_for_range(selection.start..selection.end)
54 .collect::<Rope>();
55
56 let mut normalized_selected_text = selected_text.clone();
57 let mut base_indentation: Option<language::IndentSize> = None;
58 let selection_start = selection.start.to_point(&snapshot);
59 let selection_end = selection.end.to_point(&snapshot);
60 if selection_start.row < selection_end.row {
61 for row in selection_start.row..=selection_end.row {
62 if snapshot.is_line_blank(row) {
63 continue;
64 }
65
66 let line_indentation = snapshot.indent_size_for_line(row);
67 if let Some(base_indentation) = base_indentation.as_mut() {
68 if line_indentation.len < base_indentation.len {
69 *base_indentation = line_indentation;
70 }
71 } else {
72 base_indentation = Some(line_indentation);
73 }
74 }
75 }
76
77 if let Some(base_indentation) = base_indentation {
78 for row in selection_start.row..=selection_end.row {
79 let selection_row = row - selection_start.row;
80 let line_start =
81 normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
82 let indentation_len = if row == selection_start.row {
83 base_indentation.len.saturating_sub(selection_start.column)
84 } else {
85 let line_len = normalized_selected_text.line_len(selection_row);
86 cmp::min(line_len, base_indentation.len)
87 };
88 let indentation_end = cmp::min(
89 line_start + indentation_len as usize,
90 normalized_selected_text.len(),
91 );
92 normalized_selected_text.replace(line_start..indentation_end, "");
93 }
94 }
95
96 let language_name = snapshot
97 .language_at(selection.start)
98 .map(|language| language.name());
99 let language_name = language_name.as_deref().unwrap_or("");
100
101 let mut prompt = String::new();
102 writeln!(prompt, "Given the following {language_name} snippet:").unwrap();
103 writeln!(prompt, "{normalized_selected_text}").unwrap();
104 writeln!(prompt, "{user_prompt}.").unwrap();
105 writeln!(prompt, "Never make remarks, reply only with the new code.").unwrap();
106 let request = OpenAIRequest {
107 model: "gpt-4".into(),
108 messages: vec![RequestMessage {
109 role: Role::User,
110 content: prompt,
111 }],
112 stream: true,
113 };
114 let response = stream_completion(api_key, cx.background().clone(), request);
115 let editor = editor.downgrade();
116 self.pending_edits_by_editor.insert(
117 editor.id(),
118 cx.spawn(|mut cx| {
119 async move {
120 let _clear_highlights = util::defer({
121 let mut cx = cx.clone();
122 let editor = editor.clone();
123 move || {
124 let _ = editor.update(&mut cx, |editor, cx| {
125 editor.clear_text_highlights::<Self>(cx);
126 });
127 }
128 });
129
130 let mut edit_start = selection.start.to_offset(&snapshot);
131
132 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
133 let diff = cx.background().spawn(async move {
134 let mut messages = response.await?.ready_chunks(4);
135 let mut diff = StreamingDiff::new(selected_text.to_string());
136
137 let indentation_len;
138 let indentation_text;
139 if let Some(base_indentation) = base_indentation {
140 indentation_len = base_indentation.len;
141 indentation_text = match base_indentation.kind {
142 language::IndentKind::Space => " ",
143 language::IndentKind::Tab => "\t",
144 };
145 } else {
146 indentation_len = 0;
147 indentation_text = "";
148 };
149
150 let mut new_text =
151 indentation_text.repeat(
152 indentation_len.saturating_sub(selection_start.column) as usize,
153 );
154 while let Some(messages) = messages.next().await {
155 for message in messages {
156 let mut message = message?;
157 if let Some(choice) = message.choices.pop() {
158 if let Some(text) = choice.delta.content {
159 let mut lines = text.split('\n');
160 if let Some(first_line) = lines.next() {
161 new_text.push_str(&first_line);
162 }
163
164 for line in lines {
165 new_text.push('\n');
166 new_text.push_str(
167 &indentation_text.repeat(indentation_len as usize),
168 );
169 new_text.push_str(line);
170 }
171 }
172 }
173 }
174
175 let hunks = diff.push_new(&new_text);
176 hunks_tx.send(hunks).await?;
177 new_text.clear();
178 }
179 hunks_tx.send(diff.finish()).await?;
180
181 anyhow::Ok(())
182 });
183
184 let mut first_transaction = None;
185 while let Some(hunks) = hunks_rx.next().await {
186 editor.update(&mut cx, |editor, cx| {
187 let mut highlights = Vec::new();
188
189 editor.buffer().update(cx, |buffer, cx| {
190 // Avoid grouping assistant edits with user edits.
191 buffer.finalize_last_transaction(cx);
192
193 buffer.start_transaction(cx);
194 buffer.edit(
195 hunks.into_iter().filter_map(|hunk| match hunk {
196 Hunk::Insert { text } => {
197 let edit_start = snapshot.anchor_after(edit_start);
198 Some((edit_start..edit_start, text))
199 }
200 Hunk::Remove { len } => {
201 let edit_end = edit_start + len;
202 let edit_range = snapshot.anchor_after(edit_start)
203 ..snapshot.anchor_before(edit_end);
204 edit_start = edit_end;
205 Some((edit_range, String::new()))
206 }
207 Hunk::Keep { len } => {
208 let edit_end = edit_start + len;
209 let edit_range = snapshot.anchor_after(edit_start)
210 ..snapshot.anchor_before(edit_end);
211 edit_start += len;
212 highlights.push(edit_range);
213 None
214 }
215 }),
216 None,
217 cx,
218 );
219 if let Some(transaction) = buffer.end_transaction(cx) {
220 if let Some(first_transaction) = first_transaction {
221 // Group all assistant edits into the first transaction.
222 buffer.merge_transactions(
223 transaction,
224 first_transaction,
225 cx,
226 );
227 } else {
228 first_transaction = Some(transaction);
229 buffer.finalize_last_transaction(cx);
230 }
231 }
232 });
233
234 editor.highlight_text::<Self>(
235 highlights,
236 gpui::fonts::HighlightStyle {
237 fade_out: Some(0.6),
238 ..Default::default()
239 },
240 cx,
241 );
242 })?;
243 }
244 diff.await?;
245
246 anyhow::Ok(())
247 }
248 .log_err()
249 }),
250 );
251 }
252}