1use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role};
2use collections::HashMap;
3use editor::{Editor, ToOffset, ToPoint};
4use futures::{channel::mpsc, SinkExt, StreamExt};
5use gpui::{
6 actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View,
7 ViewContext, ViewHandle, WeakViewHandle,
8};
9use language::{Point, Rope};
10use menu::{Cancel, Confirm};
11use std::{cmp, env, sync::Arc};
12use util::TryFutureExt;
13use workspace::{Modal, Workspace};
14
// Declares the `assistant::Refactor` action; `init` below registers
// `RefactoringModal::deploy` as its handler.
actions!(assistant, [Refactor]);
16
17pub fn init(cx: &mut AppContext) {
18 cx.set_global(RefactoringAssistant::new());
19 cx.add_action(RefactoringModal::deploy);
20 cx.add_action(RefactoringModal::confirm);
21 cx.add_action(RefactoringModal::cancel);
22}
23
/// App-wide state of the refactoring assistant, installed as a global by `init`.
pub struct RefactoringAssistant {
    /// In-flight refactoring tasks keyed by the target editor view's id.
    /// Inserting a new task for an editor replaces (and drops) the previous
    /// one; nothing in this module removes entries once a task finishes.
    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
}
27
28impl RefactoringAssistant {
29 fn new() -> Self {
30 Self {
31 pending_edits_by_editor: Default::default(),
32 }
33 }
34
35 fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
36 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
37 let selection = editor.read(cx).selections.newest_anchor().clone();
38 let selected_text = snapshot
39 .text_for_range(selection.start..selection.end)
40 .collect::<Rope>();
41
42 let mut normalized_selected_text = selected_text.clone();
43 let mut base_indentation: Option<language::IndentSize> = None;
44 let selection_start = selection.start.to_point(&snapshot);
45 let selection_end = selection.end.to_point(&snapshot);
46 if selection_start.row < selection_end.row {
47 for row in selection_start.row..=selection_end.row {
48 if snapshot.is_line_blank(row) {
49 continue;
50 }
51
52 let line_indentation = snapshot.indent_size_for_line(row);
53 if let Some(base_indentation) = base_indentation.as_mut() {
54 if line_indentation.len < base_indentation.len {
55 *base_indentation = line_indentation;
56 }
57 } else {
58 base_indentation = Some(line_indentation);
59 }
60 }
61 }
62
63 if let Some(base_indentation) = base_indentation {
64 for row in selection_start.row..=selection_end.row {
65 let selection_row = row - selection_start.row;
66 let line_start =
67 normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
68 let indentation_len = if row == selection_start.row {
69 base_indentation.len.saturating_sub(selection_start.column)
70 } else {
71 let line_len = normalized_selected_text.line_len(selection_row);
72 cmp::min(line_len, base_indentation.len)
73 };
74 let indentation_end = cmp::min(
75 line_start + indentation_len as usize,
76 normalized_selected_text.len(),
77 );
78 normalized_selected_text.replace(line_start..indentation_end, "");
79 }
80 }
81
82 let language_name = snapshot
83 .language_at(selection.start)
84 .map(|language| language.name());
85 let language_name = language_name.as_deref().unwrap_or("");
86 let request = OpenAIRequest {
87 model: "gpt-4".into(),
88 messages: vec![
89 RequestMessage {
90 role: Role::User,
91 content: format!(
92 "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code."
93 ),
94 }],
95 stream: true,
96 };
97 let api_key = env::var("OPENAI_API_KEY").unwrap();
98 let response = stream_completion(api_key, cx.background().clone(), request);
99 let editor = editor.downgrade();
100 self.pending_edits_by_editor.insert(
101 editor.id(),
102 cx.spawn(|mut cx| {
103 async move {
104 let mut edit_start = selection.start.to_offset(&snapshot);
105
106 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
107 let diff = cx.background().spawn(async move {
108 let mut messages = response.await?.ready_chunks(4);
109 let mut diff = Diff::new(selected_text.to_string());
110
111 let indentation_len;
112 let indentation_text;
113 if let Some(base_indentation) = base_indentation {
114 indentation_len = base_indentation.len;
115 indentation_text = match base_indentation.kind {
116 language::IndentKind::Space => " ",
117 language::IndentKind::Tab => "\t",
118 };
119 } else {
120 indentation_len = 0;
121 indentation_text = "";
122 };
123
124 let mut new_text =
125 indentation_text.repeat(
126 indentation_len.saturating_sub(selection_start.column) as usize,
127 );
128 while let Some(messages) = messages.next().await {
129 for message in messages {
130 let mut message = message?;
131 if let Some(choice) = message.choices.pop() {
132 if let Some(text) = choice.delta.content {
133 let mut lines = text.split('\n');
134 if let Some(first_line) = lines.next() {
135 new_text.push_str(&first_line);
136 }
137
138 for line in lines {
139 new_text.push('\n');
140 new_text.push_str(
141 &indentation_text.repeat(indentation_len as usize),
142 );
143 new_text.push_str(line);
144 }
145 }
146 }
147 }
148
149 let hunks = diff.push_new(&new_text);
150 hunks_tx.send(hunks).await?;
151 new_text.clear();
152 }
153 hunks_tx.send(diff.finish()).await?;
154
155 anyhow::Ok(())
156 });
157
158 let mut first_transaction = None;
159 while let Some(hunks) = hunks_rx.next().await {
160 editor.update(&mut cx, |editor, cx| {
161 let mut highlights = Vec::new();
162
163 editor.buffer().update(cx, |buffer, cx| {
164 // Avoid grouping assistant edits with user edits.
165 buffer.finalize_last_transaction(cx);
166
167 buffer.start_transaction(cx);
168 buffer.edit(
169 hunks.into_iter().filter_map(|hunk| match hunk {
170 crate::diff::Hunk::Insert { text } => {
171 let edit_start = snapshot.anchor_after(edit_start);
172 Some((edit_start..edit_start, text))
173 }
174 crate::diff::Hunk::Remove { len } => {
175 let edit_end = edit_start + len;
176 let edit_range = snapshot.anchor_after(edit_start)
177 ..snapshot.anchor_before(edit_end);
178 edit_start = edit_end;
179 Some((edit_range, String::new()))
180 }
181 crate::diff::Hunk::Keep { len } => {
182 let edit_end = edit_start + len;
183 let edit_range = snapshot.anchor_after(edit_start)
184 ..snapshot.anchor_before(edit_end);
185 edit_start += len;
186 highlights.push(edit_range);
187 None
188 }
189 }),
190 None,
191 cx,
192 );
193 if let Some(transaction) = buffer.end_transaction(cx) {
194 if let Some(first_transaction) = first_transaction {
195 // Group all assistant edits into the first transaction.
196 buffer.merge_transaction_into(
197 transaction,
198 first_transaction,
199 cx,
200 );
201 } else {
202 first_transaction = Some(transaction);
203 buffer.finalize_last_transaction(cx);
204 }
205 }
206 });
207
208 editor.highlight_text::<Self>(
209 highlights,
210 gpui::fonts::HighlightStyle {
211 fade_out: Some(0.6),
212 ..Default::default()
213 },
214 cx,
215 );
216 })?;
217 }
218
219 diff.await?;
220 editor.update(&mut cx, |editor, cx| {
221 editor.clear_text_highlights::<Self>(cx);
222 })?;
223
224 anyhow::Ok(())
225 }
226 .log_err()
227 }),
228 );
229 }
230}
231
/// Events emitted by [`RefactoringModal`].
enum Event {
    /// Requests that the modal be closed. It is emitted on cancel, on confirm,
    /// and on a left/right click outside the modal. `dismiss_on_event` returns
    /// `true` for it.
    Dismissed,
}
235
/// Modal that collects a refactoring prompt for an editor.
struct RefactoringModal {
    /// Editor whose newest selection will be refactored. It is held as a weak
    /// handle and upgraded on confirm.
    active_editor: WeakViewHandle<Editor>,
    /// Auto-height editor in which the user types the prompt.
    prompt_editor: ViewHandle<Editor>,
    /// Whether the modal has focus; updated by `focus_in` / `focus_out`.
    has_focus: bool,
}
241
// The modal's only event is `Event::Dismissed`, which closes it.
impl Entity for RefactoringModal {
    type Event = Event;
}
245
246impl View for RefactoringModal {
247 fn ui_name() -> &'static str {
248 "RefactoringModal"
249 }
250
251 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
252 let theme = theme::current(cx);
253
254 ChildView::new(&self.prompt_editor, cx)
255 .constrained()
256 .with_width(theme.assistant.modal.width)
257 .contained()
258 .with_style(theme.assistant.modal.container)
259 .mouse::<Self>(0)
260 .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed))
261 .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed))
262 .aligned()
263 .right()
264 .into_any()
265 }
266
267 fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
268 self.has_focus = true;
269 cx.focus(&self.prompt_editor);
270 }
271
272 fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
273 self.has_focus = false;
274 }
275}
276
277impl Modal for RefactoringModal {
278 fn has_focus(&self) -> bool {
279 self.has_focus
280 }
281
282 fn dismiss_on_event(event: &Self::Event) -> bool {
283 matches!(event, Self::Event::Dismissed)
284 }
285}
286
287impl RefactoringModal {
288 fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
289 if let Some(active_editor) = workspace
290 .active_item(cx)
291 .and_then(|item| Some(item.act_as::<Editor>(cx)?.downgrade()))
292 {
293 workspace.toggle_modal(cx, |_, cx| {
294 let prompt_editor = cx.add_view(|cx| {
295 let mut editor = Editor::auto_height(
296 theme::current(cx).assistant.modal.editor_max_lines,
297 Some(Arc::new(|theme| theme.assistant.modal.editor.clone())),
298 cx,
299 );
300 editor
301 .set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
302 editor
303 });
304 cx.add_view(|_| RefactoringModal {
305 active_editor,
306 prompt_editor,
307 has_focus: false,
308 })
309 });
310 }
311 }
312
313 fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
314 cx.emit(Event::Dismissed);
315 }
316
317 fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
318 if let Some(editor) = self.active_editor.upgrade(cx) {
319 let prompt = self.prompt_editor.read(cx).text(cx);
320 cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
321 assistant.refactor(&editor, &prompt, cx);
322 });
323 cx.emit(Event::Dismissed);
324 }
325 }
326}