1use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
2use collections::HashMap;
3use editor::{Editor, ToOffset};
4use futures::{channel::mpsc, SinkExt, StreamExt};
5use gpui::{
6 actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
7 WeakViewHandle,
8};
9use menu::Confirm;
10use similar::{Change, ChangeTag, TextDiff};
11use std::{env, iter, ops::Range, sync::Arc};
12use util::TryFutureExt;
13use workspace::{Modal, Workspace};
14
15actions!(assistant, [Refactor]);
16
17pub fn init(cx: &mut AppContext) {
18 cx.set_global(RefactoringAssistant::new());
19 cx.add_action(RefactoringModal::deploy);
20 cx.add_action(RefactoringModal::confirm);
21}
22
23pub struct RefactoringAssistant {
24 pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
25}
26
27impl RefactoringAssistant {
28 fn new() -> Self {
29 Self {
30 pending_edits_by_editor: Default::default(),
31 }
32 }
33
34 fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
35 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
36 let selection = editor.read(cx).selections.newest_anchor().clone();
37 let selected_text = snapshot
38 .text_for_range(selection.start..selection.end)
39 .collect::<String>();
40 let language_name = snapshot
41 .language_at(selection.start)
42 .map(|language| language.name());
43 let language_name = language_name.as_deref().unwrap_or("");
44 let request = OpenAIRequest {
45 model: "gpt-4".into(),
46 messages: vec![
47 RequestMessage {
48 role: Role::User,
49 content: format!(
50 "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line."
51 ),
52 }],
53 stream: true,
54 };
55 let api_key = env::var("OPENAI_API_KEY").unwrap();
56 let response = stream_completion(api_key, cx.background().clone(), request);
57 let editor = editor.downgrade();
58 self.pending_edits_by_editor.insert(
59 editor.id(),
60 cx.spawn(|mut cx| {
61 async move {
62 let mut edit_start = selection.start.to_offset(&snapshot);
63
64 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
65 let diff = cx.background().spawn(async move {
66 let mut messages = response.await?.ready_chunks(4);
67 let mut diff = crate::diff::Diff::new(selected_text);
68
69 while let Some(messages) = messages.next().await {
70 let mut new_text = String::new();
71 for message in messages {
72 let mut message = message?;
73 if let Some(choice) = message.choices.pop() {
74 if let Some(text) = choice.delta.content {
75 new_text.push_str(&text);
76 }
77 }
78 }
79
80 let hunks = diff.push_new(&new_text);
81 hunks_tx.send((hunks, new_text)).await?;
82 }
83
84 hunks_tx.send((diff.finish(), String::new())).await?;
85
86 anyhow::Ok(())
87 });
88
89 while let Some((hunks, new_text)) = hunks_rx.next().await {
90 editor.update(&mut cx, |editor, cx| {
91 editor.buffer().update(cx, |buffer, cx| {
92 buffer.start_transaction(cx);
93 for hunk in hunks {
94 match hunk {
95 crate::diff::Hunk::Insert { text } => {
96 let edit_start = snapshot.anchor_after(edit_start);
97 buffer.edit([(edit_start..edit_start, text)], None, cx);
98 }
99 crate::diff::Hunk::Remove { len } => {
100 let edit_end = edit_start + len;
101 let edit_range = snapshot.anchor_after(edit_start)
102 ..snapshot.anchor_before(edit_end);
103 buffer.edit([(edit_range, "")], None, cx);
104 edit_start = edit_end;
105 }
106 crate::diff::Hunk::Keep { len } => {
107 edit_start += len;
108 }
109 }
110 }
111 buffer.end_transaction(cx);
112 })
113 })?;
114 }
115
116 diff.await?;
117 anyhow::Ok(())
118 }
119 .log_err()
120 }),
121 );
122 }
123}
124
125struct RefactoringModal {
126 editor: WeakViewHandle<Editor>,
127 prompt_editor: ViewHandle<Editor>,
128 has_focus: bool,
129}
130
131impl Entity for RefactoringModal {
132 type Event = ();
133}
134
135impl View for RefactoringModal {
136 fn ui_name() -> &'static str {
137 "RefactoringModal"
138 }
139
140 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
141 ChildView::new(&self.prompt_editor, cx).into_any()
142 }
143
144 fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
145 self.has_focus = true;
146 }
147
148 fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
149 self.has_focus = false;
150 }
151}
152
153impl Modal for RefactoringModal {
154 fn has_focus(&self) -> bool {
155 self.has_focus
156 }
157
158 fn dismiss_on_event(event: &Self::Event) -> bool {
159 // TODO
160 false
161 }
162}
163
164impl RefactoringModal {
165 fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
166 if let Some(editor) = workspace
167 .active_item(cx)
168 .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
169 {
170 workspace.toggle_modal(cx, |_, cx| {
171 let prompt_editor = cx.add_view(|cx| {
172 let mut editor = Editor::auto_height(
173 4,
174 Some(Arc::new(|theme| theme.search.editor.input.clone())),
175 cx,
176 );
177 editor.set_text("Replace with match statement.", cx);
178 editor
179 });
180 cx.add_view(|_| RefactoringModal {
181 editor,
182 prompt_editor,
183 has_focus: false,
184 })
185 });
186 }
187 }
188
189 fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
190 if let Some(editor) = self.editor.upgrade(cx) {
191 let prompt = self.prompt_editor.read(cx).text(cx);
192 cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
193 assistant.refactor(&editor, &prompt, cx);
194 });
195 }
196 }
197}