1use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
2use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
3use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
4use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
5use gpui::{
6 actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
7 WeakViewHandle,
8};
9use menu::Confirm;
10use serde::Deserialize;
11use similar::ChangeTag;
12use std::{env, iter, ops::Range, sync::Arc};
13use util::TryFutureExt;
14use workspace::{Modal, Workspace};
15
16actions!(assistant, [Refactor]);
17
18pub fn init(cx: &mut AppContext) {
19 cx.set_global(RefactoringAssistant::new());
20 cx.add_action(RefactoringModal::deploy);
21 cx.add_action(RefactoringModal::confirm);
22}
23
24pub struct RefactoringAssistant {
25 pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
26}
27
28impl RefactoringAssistant {
29 fn new() -> Self {
30 Self {
31 pending_edits_by_editor: Default::default(),
32 }
33 }
34
35 fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
36 let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
37 let selection = editor.read(cx).selections.newest_anchor().clone();
38 let selected_text = buffer
39 .text_for_range(selection.start..selection.end)
40 .collect::<String>();
41 let language_name = buffer
42 .language_at(selection.start)
43 .map(|language| language.name());
44 let language_name = language_name.as_deref().unwrap_or("");
45 let request = OpenAIRequest {
46 model: "gpt-4".into(),
47 messages: vec![
48 RequestMessage {
49 role: Role::User,
50 content: format!(
51 "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
52 ),
53 }],
54 stream: true,
55 };
56 let api_key = env::var("OPENAI_API_KEY").unwrap();
57 let response = stream_completion(api_key, cx.background().clone(), request);
58 let editor = editor.downgrade();
59 self.pending_edits_by_editor.insert(
60 editor.id(),
61 cx.spawn(|mut cx| {
62 async move {
63 let selection_start = selection.start.to_offset(&buffer);
64
65 // Find unique words in the selected text to use as diff boundaries.
66 let mut duplicate_words = HashSet::default();
67 let mut unique_old_words = HashMap::default();
68 for (range, word) in words(&selected_text) {
69 if !duplicate_words.contains(word) {
70 if unique_old_words.insert(word, range.end).is_some() {
71 unique_old_words.remove(word);
72 duplicate_words.insert(word);
73 }
74 }
75 }
76
77 let mut new_text = String::new();
78 let mut messages = response.await?;
79 let mut new_word_search_start_ix = 0;
80 let mut last_old_word_end_ix = 0;
81
82 'outer: loop {
83 let start = new_word_search_start_ix;
84 let mut words = words(&new_text[start..]);
85 while let Some((range, new_word)) = words.next() {
86 // We found a word in the new text that was unique in the old text. We can use
87 // it as a diff boundary, and start applying edits.
88 if let Some(old_word_end_ix) = unique_old_words.remove(new_word) {
89 if old_word_end_ix > last_old_word_end_ix {
90 drop(words);
91
92 let remainder = new_text.split_off(start + range.end);
93 let edits = diff(
94 selection_start + last_old_word_end_ix,
95 &selected_text[last_old_word_end_ix..old_word_end_ix],
96 &new_text,
97 &buffer,
98 );
99 editor.update(&mut cx, |editor, cx| {
100 editor
101 .buffer()
102 .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
103 })?;
104
105 new_text = remainder;
106 new_word_search_start_ix = 0;
107 last_old_word_end_ix = old_word_end_ix;
108 continue 'outer;
109 }
110 }
111
112 new_word_search_start_ix = start + range.end;
113 }
114 drop(words);
115
116 // Buffer incoming text, stopping if the stream was exhausted.
117 if let Some(message) = messages.next().await {
118 let mut message = message?;
119 if let Some(choice) = message.choices.pop() {
120 if let Some(text) = choice.delta.content {
121 new_text.push_str(&text);
122 }
123 }
124 } else {
125 break;
126 }
127 }
128
129 let edits = diff(
130 selection_start + last_old_word_end_ix,
131 &selected_text[last_old_word_end_ix..],
132 &new_text,
133 &buffer,
134 );
135 editor.update(&mut cx, |editor, cx| {
136 editor
137 .buffer()
138 .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
139 })?;
140
141 anyhow::Ok(())
142 }
143 .log_err()
144 }),
145 );
146 }
147}
148
149struct RefactoringModal {
150 editor: WeakViewHandle<Editor>,
151 prompt_editor: ViewHandle<Editor>,
152 has_focus: bool,
153}
154
155impl Entity for RefactoringModal {
156 type Event = ();
157}
158
159impl View for RefactoringModal {
160 fn ui_name() -> &'static str {
161 "RefactoringModal"
162 }
163
164 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
165 ChildView::new(&self.prompt_editor, cx).into_any()
166 }
167
168 fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
169 self.has_focus = true;
170 }
171
172 fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
173 self.has_focus = false;
174 }
175}
176
177impl Modal for RefactoringModal {
178 fn has_focus(&self) -> bool {
179 self.has_focus
180 }
181
182 fn dismiss_on_event(event: &Self::Event) -> bool {
183 // TODO
184 false
185 }
186}
187
188impl RefactoringModal {
189 fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
190 if let Some(editor) = workspace
191 .active_item(cx)
192 .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
193 {
194 workspace.toggle_modal(cx, |_, cx| {
195 let prompt_editor = cx.add_view(|cx| {
196 Editor::auto_height(
197 4,
198 Some(Arc::new(|theme| theme.search.editor.input.clone())),
199 cx,
200 )
201 });
202 cx.add_view(|_| RefactoringModal {
203 editor,
204 prompt_editor,
205 has_focus: false,
206 })
207 });
208 }
209 }
210
211 fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
212 if let Some(editor) = self.editor.upgrade(cx) {
213 let prompt = self.prompt_editor.read(cx).text(cx);
214 cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
215 assistant.refactor(&editor, &prompt, cx);
216 });
217 }
218 }
219}
/// Splits `text` into runs of alphanumeric characters, yielding each run's byte range and slice.
///
/// A run is only yielded once a non-alphanumeric character terminates it. A word that reaches
/// the very end of `text` is therefore never produced. When `text` is a prefix of a streamed
/// reply, that trailing run may still be incomplete, so skipping it avoids matching a partial
/// word.
fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
    let mut chars = text.char_indices();
    let mut current_start: Option<usize> = None;
    iter::from_fn(move || {
        for (ix, ch) in chars.by_ref() {
            match (current_start, ch.is_alphanumeric()) {
                // A separator closes the word in progress.
                (Some(start), false) => {
                    current_start = None;
                    return Some((start..ix, &text[start..ix]));
                }
                // First character of a new word.
                (None, true) => current_start = Some(ix),
                // Still inside a word, or still between words.
                _ => {}
            }
        }
        None
    })
}
240
241fn diff<'a>(
242 start_ix: usize,
243 old_text: &'a str,
244 new_text: &'a str,
245 old_buffer_snapshot: &MultiBufferSnapshot,
246) -> Vec<(Range<Anchor>, &'a str)> {
247 let mut edit_start = start_ix;
248 let mut edits = Vec::new();
249 let diff = similar::TextDiff::from_words(old_text, &new_text);
250 for change in diff.iter_all_changes() {
251 let value = change.value();
252 let edit_end = edit_start + value.len();
253 match change.tag() {
254 ChangeTag::Equal => {
255 edit_start = edit_end;
256 }
257 ChangeTag::Delete => {
258 edits.push((
259 old_buffer_snapshot.anchor_after(edit_start)
260 ..old_buffer_snapshot.anchor_before(edit_end),
261 "",
262 ));
263 edit_start = edit_end;
264 }
265 ChangeTag::Insert => {
266 edits.push((
267 old_buffer_snapshot.anchor_after(edit_start)
268 ..old_buffer_snapshot.anchor_after(edit_start),
269 value,
270 ));
271 }
272 }
273 }
274 edits
275}