refactor.rs

  1use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
  2use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
  3use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
  4use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
  5use gpui::{
  6    actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
  7    WeakViewHandle,
  8};
  9use menu::Confirm;
 10use serde::Deserialize;
 11use similar::ChangeTag;
 12use std::{env, iter, ops::Range, sync::Arc};
 13use util::TryFutureExt;
 14use workspace::{Modal, Workspace};
 15
 16actions!(assistant, [Refactor]);
 17
 18pub fn init(cx: &mut AppContext) {
 19    cx.set_global(RefactoringAssistant::new());
 20    cx.add_action(RefactoringModal::deploy);
 21    cx.add_action(RefactoringModal::confirm);
 22}
 23
 24pub struct RefactoringAssistant {
 25    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
 26}
 27
 28impl RefactoringAssistant {
 29    fn new() -> Self {
 30        Self {
 31            pending_edits_by_editor: Default::default(),
 32        }
 33    }
 34
 35    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
 36        let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
 37        let selection = editor.read(cx).selections.newest_anchor().clone();
 38        let selected_text = buffer
 39            .text_for_range(selection.start..selection.end)
 40            .collect::<String>();
 41        let language_name = buffer
 42            .language_at(selection.start)
 43            .map(|language| language.name());
 44        let language_name = language_name.as_deref().unwrap_or("");
 45        let request = OpenAIRequest {
 46            model: "gpt-4".into(),
 47            messages: vec![
 48                RequestMessage {
 49                role: Role::User,
 50                content: format!(
 51                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
 52                ),
 53            }],
 54            stream: true,
 55        };
 56        let api_key = env::var("OPENAI_API_KEY").unwrap();
 57        let response = stream_completion(api_key, cx.background().clone(), request);
 58        let editor = editor.downgrade();
 59        self.pending_edits_by_editor.insert(
 60            editor.id(),
 61            cx.spawn(|mut cx| {
 62                async move {
 63                    let selection_start = selection.start.to_offset(&buffer);
 64
 65                    // Find unique words in the selected text to use as diff boundaries.
 66                    let mut duplicate_words = HashSet::default();
 67                    let mut unique_old_words = HashMap::default();
 68                    for (range, word) in words(&selected_text) {
 69                        if !duplicate_words.contains(word) {
 70                            if unique_old_words.insert(word, range.end).is_some() {
 71                                unique_old_words.remove(word);
 72                                duplicate_words.insert(word);
 73                            }
 74                        }
 75                    }
 76
 77                    let mut new_text = String::new();
 78                    let mut messages = response.await?;
 79                    let mut new_word_search_start_ix = 0;
 80                    let mut last_old_word_end_ix = 0;
 81
 82                    'outer: loop {
 83                        let start = new_word_search_start_ix;
 84                        let mut words = words(&new_text[start..]);
 85                        while let Some((range, new_word)) = words.next() {
 86                            // We found a word in the new text that was unique in the old text. We can use
 87                            // it as a diff boundary, and start applying edits.
 88                            if let Some(old_word_end_ix) = unique_old_words.remove(new_word) {
 89                                if old_word_end_ix > last_old_word_end_ix {
 90                                    drop(words);
 91
 92                                    let remainder = new_text.split_off(start + range.end);
 93                                    let edits = diff(
 94                                        selection_start + last_old_word_end_ix,
 95                                        &selected_text[last_old_word_end_ix..old_word_end_ix],
 96                                        &new_text,
 97                                        &buffer,
 98                                    );
 99                                    editor.update(&mut cx, |editor, cx| {
100                                        editor
101                                            .buffer()
102                                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
103                                    })?;
104
105                                    new_text = remainder;
106                                    new_word_search_start_ix = 0;
107                                    last_old_word_end_ix = old_word_end_ix;
108                                    continue 'outer;
109                                }
110                            }
111
112                            new_word_search_start_ix = start + range.end;
113                        }
114                        drop(words);
115
116                        // Buffer incoming text, stopping if the stream was exhausted.
117                        if let Some(message) = messages.next().await {
118                            let mut message = message?;
119                            if let Some(choice) = message.choices.pop() {
120                                if let Some(text) = choice.delta.content {
121                                    new_text.push_str(&text);
122                                }
123                            }
124                        } else {
125                            break;
126                        }
127                    }
128
129                    let edits = diff(
130                        selection_start + last_old_word_end_ix,
131                        &selected_text[last_old_word_end_ix..],
132                        &new_text,
133                        &buffer,
134                    );
135                    editor.update(&mut cx, |editor, cx| {
136                        editor
137                            .buffer()
138                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
139                    })?;
140
141                    anyhow::Ok(())
142                }
143                .log_err()
144            }),
145        );
146    }
147}
148
149struct RefactoringModal {
150    editor: WeakViewHandle<Editor>,
151    prompt_editor: ViewHandle<Editor>,
152    has_focus: bool,
153}
154
155impl Entity for RefactoringModal {
156    type Event = ();
157}
158
159impl View for RefactoringModal {
160    fn ui_name() -> &'static str {
161        "RefactoringModal"
162    }
163
164    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
165        ChildView::new(&self.prompt_editor, cx).into_any()
166    }
167
168    fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
169        self.has_focus = true;
170    }
171
172    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
173        self.has_focus = false;
174    }
175}
176
177impl Modal for RefactoringModal {
178    fn has_focus(&self) -> bool {
179        self.has_focus
180    }
181
182    fn dismiss_on_event(event: &Self::Event) -> bool {
183        // TODO
184        false
185    }
186}
187
188impl RefactoringModal {
189    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
190        if let Some(editor) = workspace
191            .active_item(cx)
192            .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
193        {
194            workspace.toggle_modal(cx, |_, cx| {
195                let prompt_editor = cx.add_view(|cx| {
196                    Editor::auto_height(
197                        4,
198                        Some(Arc::new(|theme| theme.search.editor.input.clone())),
199                        cx,
200                    )
201                });
202                cx.add_view(|_| RefactoringModal {
203                    editor,
204                    prompt_editor,
205                    has_focus: false,
206                })
207            });
208        }
209    }
210
211    fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
212        if let Some(editor) = self.editor.upgrade(cx) {
213            let prompt = self.prompt_editor.read(cx).text(cx);
214            cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
215                assistant.refactor(&editor, &prompt, cx);
216            });
217        }
218    }
219}
/// Lazily yields each run of alphanumeric characters in `text` with its byte range.
///
/// A run is yielded only when a non-alphanumeric character terminates it. A word that
/// extends to the very end of `text` is therefore never yielded. The streaming caller scans
/// a growing buffer, and this appears to be what stops it from treating a partially
/// received word as complete.
fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
    let mut chars = text.char_indices();
    // Byte offset where the word currently being scanned began, if inside a word.
    let mut pending_start: Option<usize> = None;
    iter::from_fn(move || {
        for (ix, ch) in chars.by_ref() {
            match (pending_start, ch.is_alphanumeric()) {
                // A separator closes the open word: emit it.
                (Some(start_ix), false) => {
                    pending_start = None;
                    return Some((start_ix..ix, &text[start_ix..ix]));
                }
                // First word character after a separator: open a new word.
                (None, true) => pending_start = Some(ix),
                // Still inside a word, or still between words.
                _ => {}
            }
        }
        None
    })
}
240
241fn diff<'a>(
242    start_ix: usize,
243    old_text: &'a str,
244    new_text: &'a str,
245    old_buffer_snapshot: &MultiBufferSnapshot,
246) -> Vec<(Range<Anchor>, &'a str)> {
247    let mut edit_start = start_ix;
248    let mut edits = Vec::new();
249    let diff = similar::TextDiff::from_words(old_text, &new_text);
250    for change in diff.iter_all_changes() {
251        let value = change.value();
252        let edit_end = edit_start + value.len();
253        match change.tag() {
254            ChangeTag::Equal => {
255                edit_start = edit_end;
256            }
257            ChangeTag::Delete => {
258                edits.push((
259                    old_buffer_snapshot.anchor_after(edit_start)
260                        ..old_buffer_snapshot.anchor_before(edit_end),
261                    "",
262                ));
263                edit_start = edit_end;
264            }
265            ChangeTag::Insert => {
266                edits.push((
267                    old_buffer_snapshot.anchor_after(edit_start)
268                        ..old_buffer_snapshot.anchor_after(edit_start),
269                    value,
270                ));
271            }
272        }
273    }
274    edits
275}