refactor.rs

  1use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
  2use collections::HashMap;
  3use editor::{Editor, ToOffset};
  4use futures::{channel::mpsc, SinkExt, StreamExt};
  5use gpui::{
  6    actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
  7    WeakViewHandle,
  8};
  9use menu::Confirm;
 10use similar::{Change, ChangeTag, TextDiff};
 11use std::{env, iter, ops::Range, sync::Arc};
 12use util::TryFutureExt;
 13use workspace::{Modal, Workspace};
 14
 15actions!(assistant, [Refactor]);
 16
 17pub fn init(cx: &mut AppContext) {
 18    cx.set_global(RefactoringAssistant::new());
 19    cx.add_action(RefactoringModal::deploy);
 20    cx.add_action(RefactoringModal::confirm);
 21}
 22
 23pub struct RefactoringAssistant {
 24    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
 25}
 26
 27impl RefactoringAssistant {
 28    fn new() -> Self {
 29        Self {
 30            pending_edits_by_editor: Default::default(),
 31        }
 32    }
 33
 34    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
 35        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 36        let selection = editor.read(cx).selections.newest_anchor().clone();
 37        let selected_text = snapshot
 38            .text_for_range(selection.start..selection.end)
 39            .collect::<String>();
 40        let language_name = snapshot
 41            .language_at(selection.start)
 42            .map(|language| language.name());
 43        let language_name = language_name.as_deref().unwrap_or("");
 44        let request = OpenAIRequest {
 45            model: "gpt-4".into(),
 46            messages: vec![
 47                RequestMessage {
 48                role: Role::User,
 49                content: format!(
 50                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line."
 51                ),
 52            }],
 53            stream: true,
 54        };
 55        let api_key = env::var("OPENAI_API_KEY").unwrap();
 56        let response = stream_completion(api_key, cx.background().clone(), request);
 57        let editor = editor.downgrade();
 58        self.pending_edits_by_editor.insert(
 59            editor.id(),
 60            cx.spawn(|mut cx| {
 61                async move {
 62                    let mut edit_start = selection.start.to_offset(&snapshot);
 63
 64                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 65                    let diff = cx.background().spawn(async move {
 66                        let mut messages = response.await?.ready_chunks(4);
 67                        let mut diff = crate::diff::Diff::new(selected_text);
 68
 69                        while let Some(messages) = messages.next().await {
 70                            let mut new_text = String::new();
 71                            for message in messages {
 72                                let mut message = message?;
 73                                if let Some(choice) = message.choices.pop() {
 74                                    if let Some(text) = choice.delta.content {
 75                                        new_text.push_str(&text);
 76                                    }
 77                                }
 78                            }
 79
 80                            let hunks = diff.push_new(&new_text);
 81                            hunks_tx.send((hunks, new_text)).await?;
 82                        }
 83
 84                        hunks_tx.send((diff.finish(), String::new())).await?;
 85
 86                        anyhow::Ok(())
 87                    });
 88
 89                    while let Some((hunks, new_text)) = hunks_rx.next().await {
 90                        editor.update(&mut cx, |editor, cx| {
 91                            editor.buffer().update(cx, |buffer, cx| {
 92                                buffer.start_transaction(cx);
 93                                for hunk in hunks {
 94                                    match hunk {
 95                                        crate::diff::Hunk::Insert { text } => {
 96                                            let edit_start = snapshot.anchor_after(edit_start);
 97                                            buffer.edit([(edit_start..edit_start, text)], None, cx);
 98                                        }
 99                                        crate::diff::Hunk::Remove { len } => {
100                                            let edit_end = edit_start + len;
101                                            let edit_range = snapshot.anchor_after(edit_start)
102                                                ..snapshot.anchor_before(edit_end);
103                                            buffer.edit([(edit_range, "")], None, cx);
104                                            edit_start = edit_end;
105                                        }
106                                        crate::diff::Hunk::Keep { len } => {
107                                            edit_start += len;
108                                        }
109                                    }
110                                }
111                                buffer.end_transaction(cx);
112                            })
113                        })?;
114                    }
115
116                    diff.await?;
117                    anyhow::Ok(())
118                }
119                .log_err()
120            }),
121        );
122    }
123}
124
125struct RefactoringModal {
126    editor: WeakViewHandle<Editor>,
127    prompt_editor: ViewHandle<Editor>,
128    has_focus: bool,
129}
130
131impl Entity for RefactoringModal {
132    type Event = ();
133}
134
135impl View for RefactoringModal {
136    fn ui_name() -> &'static str {
137        "RefactoringModal"
138    }
139
140    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
141        ChildView::new(&self.prompt_editor, cx).into_any()
142    }
143
144    fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
145        self.has_focus = true;
146    }
147
148    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
149        self.has_focus = false;
150    }
151}
152
153impl Modal for RefactoringModal {
154    fn has_focus(&self) -> bool {
155        self.has_focus
156    }
157
158    fn dismiss_on_event(event: &Self::Event) -> bool {
159        // TODO
160        false
161    }
162}
163
164impl RefactoringModal {
165    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
166        if let Some(editor) = workspace
167            .active_item(cx)
168            .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
169        {
170            workspace.toggle_modal(cx, |_, cx| {
171                let prompt_editor = cx.add_view(|cx| {
172                    let mut editor = Editor::auto_height(
173                        4,
174                        Some(Arc::new(|theme| theme.search.editor.input.clone())),
175                        cx,
176                    );
177                    editor.set_text("Replace with match statement.", cx);
178                    editor
179                });
180                cx.add_view(|_| RefactoringModal {
181                    editor,
182                    prompt_editor,
183                    has_focus: false,
184                })
185            });
186        }
187    }
188
189    fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
190        if let Some(editor) = self.editor.upgrade(cx) {
191            let prompt = self.prompt_editor.read(cx).text(cx);
192            cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
193                assistant.refactor(&editor, &prompt, cx);
194            });
195        }
196    }
197}