1use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
2use collections::HashMap;
3use editor::{Editor, ToOffset};
4use futures::{channel::mpsc, SinkExt, StreamExt};
5use gpui::{
6 actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
7 WeakViewHandle,
8};
9use menu::Confirm;
10use similar::{Change, ChangeTag, TextDiff};
11use std::{env, iter, ops::Range, sync::Arc};
12use util::TryFutureExt;
13use workspace::{Modal, Workspace};
14
15actions!(assistant, [Refactor]);
16
17pub fn init(cx: &mut AppContext) {
18 cx.set_global(RefactoringAssistant::new());
19 cx.add_action(RefactoringModal::deploy);
20 cx.add_action(RefactoringModal::confirm);
21}
22
23pub struct RefactoringAssistant {
24 pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
25}
26
27impl RefactoringAssistant {
28 fn new() -> Self {
29 Self {
30 pending_edits_by_editor: Default::default(),
31 }
32 }
33
34 fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
35 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
36 let selection = editor.read(cx).selections.newest_anchor().clone();
37 let selected_text = snapshot
38 .text_for_range(selection.start..selection.end)
39 .collect::<String>();
40 let language_name = snapshot
41 .language_at(selection.start)
42 .map(|language| language.name());
43 let language_name = language_name.as_deref().unwrap_or("");
44 let request = OpenAIRequest {
45 model: "gpt-4".into(),
46 messages: vec![
47 RequestMessage {
48 role: Role::User,
49 content: format!(
50 "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line."
51 ),
52 }],
53 stream: true,
54 };
55 let api_key = env::var("OPENAI_API_KEY").unwrap();
56 let response = stream_completion(api_key, cx.background().clone(), request);
57 let editor = editor.downgrade();
58 self.pending_edits_by_editor.insert(
59 editor.id(),
60 cx.spawn(|mut cx| {
61 async move {
62 let mut edit_start = selection.start.to_offset(&snapshot);
63
64 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
65 let diff = cx.background().spawn(async move {
66 let mut messages = response.await?.ready_chunks(4);
67 let mut diff = crate::diff::Diff::new(selected_text);
68
69 while let Some(messages) = messages.next().await {
70 let mut new_text = String::new();
71 for message in messages {
72 let mut message = message?;
73 if let Some(choice) = message.choices.pop() {
74 if let Some(text) = choice.delta.content {
75 new_text.push_str(&text);
76 }
77 }
78 }
79
80 let hunks = diff.push_new(&new_text);
81 hunks_tx.send(hunks).await?;
82 }
83 hunks_tx.send(diff.finish()).await?;
84
85 anyhow::Ok(())
86 });
87
88 while let Some(hunks) = hunks_rx.next().await {
89 editor.update(&mut cx, |editor, cx| {
90 editor.buffer().update(cx, |buffer, cx| {
91 buffer.start_transaction(cx);
92 for hunk in hunks {
93 match hunk {
94 crate::diff::Hunk::Insert { text } => {
95 let edit_start = snapshot.anchor_after(edit_start);
96 buffer.edit([(edit_start..edit_start, text)], None, cx);
97 }
98 crate::diff::Hunk::Remove { len } => {
99 let edit_end = edit_start + len;
100 let edit_range = snapshot.anchor_after(edit_start)
101 ..snapshot.anchor_before(edit_end);
102 buffer.edit([(edit_range, "")], None, cx);
103 edit_start = edit_end;
104 }
105 crate::diff::Hunk::Keep { len } => {
106 edit_start += len;
107 }
108 }
109 }
110 buffer.end_transaction(cx);
111 })
112 })?;
113 }
114
115 diff.await?;
116 anyhow::Ok(())
117 }
118 .log_err()
119 }),
120 );
121 }
122}
123
124struct RefactoringModal {
125 editor: WeakViewHandle<Editor>,
126 prompt_editor: ViewHandle<Editor>,
127 has_focus: bool,
128}
129
130impl Entity for RefactoringModal {
131 type Event = ();
132}
133
134impl View for RefactoringModal {
135 fn ui_name() -> &'static str {
136 "RefactoringModal"
137 }
138
139 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
140 ChildView::new(&self.prompt_editor, cx).into_any()
141 }
142
143 fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
144 self.has_focus = true;
145 }
146
147 fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
148 self.has_focus = false;
149 }
150}
151
152impl Modal for RefactoringModal {
153 fn has_focus(&self) -> bool {
154 self.has_focus
155 }
156
157 fn dismiss_on_event(event: &Self::Event) -> bool {
158 // TODO
159 false
160 }
161}
162
163impl RefactoringModal {
164 fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
165 if let Some(editor) = workspace
166 .active_item(cx)
167 .and_then(|item| Some(item.downcast::<Editor>()?.downgrade()))
168 {
169 workspace.toggle_modal(cx, |_, cx| {
170 let prompt_editor = cx.add_view(|cx| {
171 let mut editor = Editor::auto_height(
172 4,
173 Some(Arc::new(|theme| theme.search.editor.input.clone())),
174 cx,
175 );
176 editor.set_text("Replace with match statement.", cx);
177 editor
178 });
179 cx.add_view(|_| RefactoringModal {
180 editor,
181 prompt_editor,
182 has_focus: false,
183 })
184 });
185 }
186 }
187
188 fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
189 if let Some(editor) = self.editor.upgrade(cx) {
190 let prompt = self.prompt_editor.read(cx).text(cx);
191 cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
192 assistant.refactor(&editor, &prompt, cx);
193 });
194 }
195 }
196}