refactor.rs

  1use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role};
  2use collections::HashMap;
  3use editor::{Editor, ToOffset, ToPoint};
  4use futures::{channel::mpsc, SinkExt, StreamExt};
  5use gpui::{
  6    actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View,
  7    ViewContext, ViewHandle, WeakViewHandle,
  8};
  9use language::{Point, Rope};
 10use menu::{Cancel, Confirm};
 11use std::{cmp, env, sync::Arc};
 12use util::TryFutureExt;
 13use workspace::{Modal, Workspace};
 14
 15actions!(assistant, [Refactor]);
 16
 17pub fn init(cx: &mut AppContext) {
 18    cx.set_global(RefactoringAssistant::new());
 19    cx.add_action(RefactoringModal::deploy);
 20    cx.add_action(RefactoringModal::confirm);
 21    cx.add_action(RefactoringModal::cancel);
 22}
 23
 24pub struct RefactoringAssistant {
 25    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
 26}
 27
 28impl RefactoringAssistant {
 29    fn new() -> Self {
 30        Self {
 31            pending_edits_by_editor: Default::default(),
 32        }
 33    }
 34
 35    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
 36        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 37        let selection = editor.read(cx).selections.newest_anchor().clone();
 38        let selected_text = snapshot
 39            .text_for_range(selection.start..selection.end)
 40            .collect::<Rope>();
 41
 42        let mut normalized_selected_text = selected_text.clone();
 43        let mut base_indentation: Option<language::IndentSize> = None;
 44        let selection_start = selection.start.to_point(&snapshot);
 45        let selection_end = selection.end.to_point(&snapshot);
 46        if selection_start.row < selection_end.row {
 47            for row in selection_start.row..=selection_end.row {
 48                if snapshot.is_line_blank(row) {
 49                    continue;
 50                }
 51
 52                let line_indentation = snapshot.indent_size_for_line(row);
 53                if let Some(base_indentation) = base_indentation.as_mut() {
 54                    if line_indentation.len < base_indentation.len {
 55                        *base_indentation = line_indentation;
 56                    }
 57                } else {
 58                    base_indentation = Some(line_indentation);
 59                }
 60            }
 61        }
 62
 63        if let Some(base_indentation) = base_indentation {
 64            for row in selection_start.row..=selection_end.row {
 65                let selection_row = row - selection_start.row;
 66                let line_start =
 67                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
 68                let indentation_len = if row == selection_start.row {
 69                    base_indentation.len.saturating_sub(selection_start.column)
 70                } else {
 71                    let line_len = normalized_selected_text.line_len(selection_row);
 72                    cmp::min(line_len, base_indentation.len)
 73                };
 74                let indentation_end = cmp::min(
 75                    line_start + indentation_len as usize,
 76                    normalized_selected_text.len(),
 77                );
 78                normalized_selected_text.replace(line_start..indentation_end, "");
 79            }
 80        }
 81
 82        let language_name = snapshot
 83            .language_at(selection.start)
 84            .map(|language| language.name());
 85        let language_name = language_name.as_deref().unwrap_or("");
 86        let request = OpenAIRequest {
 87            model: "gpt-4".into(),
 88            messages: vec![
 89                RequestMessage {
 90                role: Role::User,
 91                content: format!(
 92                    "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code."
 93                ),
 94            }],
 95            stream: true,
 96        };
 97        let api_key = env::var("OPENAI_API_KEY").unwrap();
 98        let response = stream_completion(api_key, cx.background().clone(), request);
 99        let editor = editor.downgrade();
100        self.pending_edits_by_editor.insert(
101            editor.id(),
102            cx.spawn(|mut cx| {
103                async move {
104                    let mut edit_start = selection.start.to_offset(&snapshot);
105
106                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
107                    let diff = cx.background().spawn(async move {
108                        let mut messages = response.await?.ready_chunks(4);
109                        let mut diff = Diff::new(selected_text.to_string());
110
111                        let indentation_len;
112                        let indentation_text;
113                        if let Some(base_indentation) = base_indentation {
114                            indentation_len = base_indentation.len;
115                            indentation_text = match base_indentation.kind {
116                                language::IndentKind::Space => " ",
117                                language::IndentKind::Tab => "\t",
118                            };
119                        } else {
120                            indentation_len = 0;
121                            indentation_text = "";
122                        };
123
124                        let mut new_text =
125                            indentation_text.repeat(
126                                indentation_len.saturating_sub(selection_start.column) as usize,
127                            );
128                        while let Some(messages) = messages.next().await {
129                            for message in messages {
130                                let mut message = message?;
131                                if let Some(choice) = message.choices.pop() {
132                                    if let Some(text) = choice.delta.content {
133                                        let mut lines = text.split('\n');
134                                        if let Some(first_line) = lines.next() {
135                                            new_text.push_str(&first_line);
136                                        }
137
138                                        for line in lines {
139                                            new_text.push('\n');
140                                            new_text.push_str(
141                                                &indentation_text.repeat(indentation_len as usize),
142                                            );
143                                            new_text.push_str(line);
144                                        }
145                                    }
146                                }
147                            }
148
149                            let hunks = diff.push_new(&new_text);
150                            hunks_tx.send(hunks).await?;
151                            new_text.clear();
152                        }
153                        hunks_tx.send(diff.finish()).await?;
154
155                        anyhow::Ok(())
156                    });
157
158                    let mut first_transaction = None;
159                    while let Some(hunks) = hunks_rx.next().await {
160                        editor.update(&mut cx, |editor, cx| {
161                            let mut highlights = Vec::new();
162
163                            editor.buffer().update(cx, |buffer, cx| {
164                                // Avoid grouping assistant edits with user edits.
165                                buffer.finalize_last_transaction(cx);
166
167                                buffer.start_transaction(cx);
168                                buffer.edit(
169                                    hunks.into_iter().filter_map(|hunk| match hunk {
170                                        crate::diff::Hunk::Insert { text } => {
171                                            let edit_start = snapshot.anchor_after(edit_start);
172                                            Some((edit_start..edit_start, text))
173                                        }
174                                        crate::diff::Hunk::Remove { len } => {
175                                            let edit_end = edit_start + len;
176                                            let edit_range = snapshot.anchor_after(edit_start)
177                                                ..snapshot.anchor_before(edit_end);
178                                            edit_start = edit_end;
179                                            Some((edit_range, String::new()))
180                                        }
181                                        crate::diff::Hunk::Keep { len } => {
182                                            let edit_end = edit_start + len;
183                                            let edit_range = snapshot.anchor_after(edit_start)
184                                                ..snapshot.anchor_before(edit_end);
185                                            edit_start += len;
186                                            highlights.push(edit_range);
187                                            None
188                                        }
189                                    }),
190                                    None,
191                                    cx,
192                                );
193                                if let Some(transaction) = buffer.end_transaction(cx) {
194                                    if let Some(first_transaction) = first_transaction {
195                                        // Group all assistant edits into the first transaction.
196                                        buffer.merge_transaction_into(
197                                            transaction,
198                                            first_transaction,
199                                            cx,
200                                        );
201                                    } else {
202                                        first_transaction = Some(transaction);
203                                        buffer.finalize_last_transaction(cx);
204                                    }
205                                }
206                            });
207
208                            editor.highlight_text::<Self>(
209                                highlights,
210                                gpui::fonts::HighlightStyle {
211                                    fade_out: Some(0.6),
212                                    ..Default::default()
213                                },
214                                cx,
215                            );
216                        })?;
217                    }
218
219                    diff.await?;
220                    editor.update(&mut cx, |editor, cx| {
221                        editor.clear_text_highlights::<Self>(cx);
222                    })?;
223
224                    anyhow::Ok(())
225                }
226                .log_err()
227            }),
228        );
229    }
230}
231
/// Events emitted by [`RefactoringModal`].
enum Event {
    /// The modal should close: it was cancelled, confirmed, or clicked outside of.
    Dismissed,
}
235
236struct RefactoringModal {
237    active_editor: WeakViewHandle<Editor>,
238    prompt_editor: ViewHandle<Editor>,
239    has_focus: bool,
240}
241
242impl Entity for RefactoringModal {
243    type Event = Event;
244}
245
246impl View for RefactoringModal {
247    fn ui_name() -> &'static str {
248        "RefactoringModal"
249    }
250
251    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
252        let theme = theme::current(cx);
253
254        ChildView::new(&self.prompt_editor, cx)
255            .constrained()
256            .with_width(theme.assistant.modal.width)
257            .contained()
258            .with_style(theme.assistant.modal.container)
259            .mouse::<Self>(0)
260            .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed))
261            .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed))
262            .aligned()
263            .right()
264            .into_any()
265    }
266
267    fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
268        self.has_focus = true;
269        cx.focus(&self.prompt_editor);
270    }
271
272    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
273        self.has_focus = false;
274    }
275}
276
277impl Modal for RefactoringModal {
278    fn has_focus(&self) -> bool {
279        self.has_focus
280    }
281
282    fn dismiss_on_event(event: &Self::Event) -> bool {
283        matches!(event, Self::Event::Dismissed)
284    }
285}
286
287impl RefactoringModal {
288    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
289        if let Some(active_editor) = workspace
290            .active_item(cx)
291            .and_then(|item| Some(item.act_as::<Editor>(cx)?.downgrade()))
292        {
293            workspace.toggle_modal(cx, |_, cx| {
294                let prompt_editor = cx.add_view(|cx| {
295                    let mut editor = Editor::auto_height(
296                        theme::current(cx).assistant.modal.editor_max_lines,
297                        Some(Arc::new(|theme| theme.assistant.modal.editor.clone())),
298                        cx,
299                    );
300                    editor
301                        .set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
302                    editor
303                });
304                cx.add_view(|_| RefactoringModal {
305                    active_editor,
306                    prompt_editor,
307                    has_focus: false,
308                })
309            });
310        }
311    }
312
313    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
314        cx.emit(Event::Dismissed);
315    }
316
317    fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
318        if let Some(editor) = self.active_editor.upgrade(cx) {
319            let prompt = self.prompt_editor.read(cx).text(cx);
320            cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
321                assistant.refactor(&editor, &prompt, cx);
322            });
323            cx.emit(Event::Dismissed);
324        }
325    }
326}