1use crate::{Supermaven, SupermavenCompletionStateId};
2use anyhow::Result;
3use edit_prediction_types::{EditPrediction, EditPredictionDelegate, EditPredictionIconSet};
4use futures::StreamExt as _;
5use gpui::{App, Context, Entity, EntityId, Task};
6use language::{Anchor, Buffer, BufferSnapshot};
7use std::{
8 ops::{AddAssign, Range},
9 path::Path,
10 sync::Arc,
11 time::Duration,
12};
13use text::{ToOffset, ToPoint};
14use ui::prelude::*;
15use unicode_segmentation::UnicodeSegmentation;
16
/// How long the refresh task waits before it starts consuming completion updates.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
18
19pub struct SupermavenEditPredictionDelegate {
20 supermaven: Entity<Supermaven>,
21 buffer_id: Option<EntityId>,
22 completion_id: Option<SupermavenCompletionStateId>,
23 completion_text: Option<String>,
24 file_extension: Option<String>,
25 pending_refresh: Option<Task<Result<()>>>,
26 completion_position: Option<language::Anchor>,
27}
28
29impl SupermavenEditPredictionDelegate {
30 pub fn new(supermaven: Entity<Supermaven>) -> Self {
31 Self {
32 supermaven,
33 buffer_id: None,
34 completion_id: None,
35 completion_text: None,
36 file_extension: None,
37 pending_refresh: None,
38 completion_position: None,
39 }
40 }
41}
42
43// Computes the edit prediction from the difference between the completion text.
44// This is defined by greedily matching the buffer text against the completion text.
45// Inlays are inserted for parts of the completion text that are not present in the buffer text.
46// For example, given the completion text "axbyc" and the buffer text "xy", the rendered output in the editor would be "[a]x[b]y[c]".
47// The parts in brackets are the inlays.
48fn completion_from_diff(
49 snapshot: BufferSnapshot,
50 completion_text: &str,
51 position: Anchor,
52 delete_range: Range<Anchor>,
53) -> EditPrediction {
54 let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
55
56 let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
57
58 let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
59 let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
60
61 let mut offset = position.to_offset(&snapshot);
62
63 let mut i = 0;
64 let mut j = 0;
65 while i < completion_graphemes.len() && j < buffer_graphemes.len() {
66 // find the next instance of the buffer text in the completion text.
67 let k = completion_graphemes[i..]
68 .iter()
69 .position(|c| *c == buffer_graphemes[j]);
70 match k {
71 Some(k) => {
72 if k != 0 {
73 let offset = snapshot.anchor_after(offset);
74 // the range from the current position to item is an inlay.
75 let edit = (
76 offset..offset,
77 completion_graphemes[i..i + k].join("").into(),
78 );
79 edits.push(edit);
80 }
81 i += k + 1;
82 j += 1;
83 offset.add_assign(buffer_graphemes[j - 1].len());
84 }
85 None => {
86 // there are no more matching completions, so drop the remaining
87 // completion text as an inlay.
88 break;
89 }
90 }
91 }
92
93 if j == buffer_graphemes.len() && i < completion_graphemes.len() {
94 let offset = snapshot.anchor_after(offset);
95 // there is leftover completion text, so drop it as an inlay.
96 let edit_range = offset..offset;
97 let edit_text = completion_graphemes[i..].join("");
98 edits.push((edit_range, edit_text.into()));
99 }
100
101 EditPrediction::Local {
102 id: None,
103 edits,
104 cursor_position: None,
105 edit_preview: None,
106 }
107}
108
109impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
110 fn name() -> &'static str {
111 "supermaven"
112 }
113
114 fn display_name() -> &'static str {
115 "Supermaven"
116 }
117
118 fn show_predictions_in_menu() -> bool {
119 true
120 }
121
122 fn show_tab_accept_marker() -> bool {
123 true
124 }
125
126 fn supports_jump_to_edit() -> bool {
127 false
128 }
129
130 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
131 EditPredictionIconSet::new(IconName::Supermaven)
132 .with_disabled(IconName::SupermavenDisabled)
133 .with_error(IconName::SupermavenError)
134 }
135
136 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
137 self.supermaven.read(cx).is_enabled()
138 }
139
140 fn is_refreshing(&self, _cx: &App) -> bool {
141 self.pending_refresh.is_some() && self.completion_id.is_none()
142 }
143
144 fn refresh(
145 &mut self,
146 buffer_handle: Entity<Buffer>,
147 cursor_position: Anchor,
148 debounce: bool,
149 cx: &mut Context<Self>,
150 ) {
151 // Only make new completion requests when debounce is true (i.e., when text is typed)
152 // When debounce is false (i.e., cursor movement), we should not make new requests
153 if !debounce {
154 return;
155 }
156
157 reset_completion_cache(self, cx);
158
159 let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
160 supermaven.complete(&buffer_handle, cursor_position, cx)
161 }) else {
162 return;
163 };
164
165 self.pending_refresh = Some(cx.spawn(async move |this, cx| {
166 if debounce {
167 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
168 }
169
170 while let Some(()) = completion.updates.next().await {
171 this.update(cx, |this, cx| {
172 // Get the completion text and cache it
173 if let Some(text) =
174 this.supermaven
175 .read(cx)
176 .completion(&buffer_handle, cursor_position, cx)
177 {
178 this.completion_text = Some(text.to_string());
179
180 this.completion_position = Some(cursor_position);
181 }
182
183 this.completion_id = Some(completion.id);
184 this.buffer_id = Some(buffer_handle.entity_id());
185 this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
186 Some(
187 Path::new(file.file_name(cx))
188 .extension()?
189 .to_str()?
190 .to_string(),
191 )
192 });
193 cx.notify();
194 })?;
195 }
196 Ok(())
197 }));
198 }
199
200 fn accept(&mut self, _cx: &mut Context<Self>) {
201 reset_completion_cache(self, _cx);
202 }
203
204 fn discard(&mut self, _cx: &mut Context<Self>) {
205 reset_completion_cache(self, _cx);
206 }
207
208 fn suggest(
209 &mut self,
210 buffer: &Entity<Buffer>,
211 cursor_position: Anchor,
212 cx: &mut Context<Self>,
213 ) -> Option<EditPrediction> {
214 if self.buffer_id != Some(buffer.entity_id()) {
215 return None;
216 }
217
218 if self.completion_id.is_none() {
219 return None;
220 }
221
222 let completion_text = if let Some(cached_text) = &self.completion_text {
223 cached_text.as_str()
224 } else {
225 let text = self
226 .supermaven
227 .read(cx)
228 .completion(buffer, cursor_position, cx)?;
229 self.completion_text = Some(text.to_string());
230 text
231 };
232
233 // Check if the cursor is still at the same position as the completion request
234 // If we don't have a completion position stored, don't show the completion
235 if let Some(completion_position) = self.completion_position {
236 if cursor_position != completion_position {
237 return None;
238 }
239 } else {
240 return None;
241 }
242
243 let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
244
245 let completion_text = completion_text.trim_end();
246
247 if !completion_text.trim().is_empty() {
248 let snapshot = buffer.read(cx).snapshot();
249
250 // Calculate the range from cursor to end of line correctly
251 let cursor_point = cursor_position.to_point(&snapshot);
252 let end_of_line = snapshot.anchor_after(language::Point::new(
253 cursor_point.row,
254 snapshot.line_len(cursor_point.row),
255 ));
256 let delete_range = cursor_position..end_of_line;
257
258 Some(completion_from_diff(
259 snapshot,
260 completion_text,
261 cursor_position,
262 delete_range,
263 ))
264 } else {
265 None
266 }
267 }
268}
269
270fn reset_completion_cache(
271 provider: &mut SupermavenEditPredictionDelegate,
272 _cx: &mut Context<SupermavenEditPredictionDelegate>,
273) {
274 provider.pending_refresh = None;
275 provider.completion_id = None;
276 provider.completion_text = None;
277 provider.completion_position = None;
278 provider.buffer_id = None;
279}
280
281fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
282 if has_leading_newline(text) {
283 text
284 } else if let Some(i) = text.find('\n') {
285 &text[..i]
286 } else {
287 text
288 }
289}
290
/// Returns `true` when a `'\n'` is reached before any non-whitespace character in `text`.
///
/// Empty or whitespace-only text without a newline yields `false`.
fn has_leading_newline(text: &str) -> bool {
    // The first character that is either a newline or not whitespace decides the answer.
    text.chars().find(|c| *c == '\n' || !c.is_whitespace()) == Some('\n')
}