1use crate::{Supermaven, SupermavenCompletionStateId};
2use anyhow::Result;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
5};
6use futures::StreamExt as _;
7use gpui::{App, Context, Entity, EntityId, Task};
8use language::{Anchor, Buffer, BufferSnapshot};
9use std::{
10 ops::{AddAssign, Range},
11 path::Path,
12 sync::Arc,
13 time::Duration,
14};
15use text::{ToOffset, ToPoint};
16use ui::prelude::*;
17use unicode_segmentation::UnicodeSegmentation;
18
/// How long a typing-triggered (`debounce == true`) refresh waits before it starts
/// consuming completion updates from Supermaven (see `refresh`).
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
20
21pub struct SupermavenEditPredictionDelegate {
22 supermaven: Entity<Supermaven>,
23 buffer_id: Option<EntityId>,
24 completion_id: Option<SupermavenCompletionStateId>,
25 completion_text: Option<String>,
26 file_extension: Option<String>,
27 pending_refresh: Option<Task<Result<()>>>,
28 completion_position: Option<language::Anchor>,
29}
30
31impl SupermavenEditPredictionDelegate {
32 pub fn new(supermaven: Entity<Supermaven>) -> Self {
33 Self {
34 supermaven,
35 buffer_id: None,
36 completion_id: None,
37 completion_text: None,
38 file_extension: None,
39 pending_refresh: None,
40 completion_position: None,
41 }
42 }
43}
44
45// Computes the edit prediction from the difference between the completion text.
46// This is defined by greedily matching the buffer text against the completion text.
47// Inlays are inserted for parts of the completion text that are not present in the buffer text.
48// For example, given the completion text "axbyc" and the buffer text "xy", the rendered output in the editor would be "[a]x[b]y[c]".
49// The parts in brackets are the inlays.
50fn completion_from_diff(
51 snapshot: BufferSnapshot,
52 completion_text: &str,
53 position: Anchor,
54 delete_range: Range<Anchor>,
55) -> EditPrediction {
56 let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
57
58 let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
59
60 let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
61 let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
62
63 let mut offset = position.to_offset(&snapshot);
64
65 let mut i = 0;
66 let mut j = 0;
67 while i < completion_graphemes.len() && j < buffer_graphemes.len() {
68 // find the next instance of the buffer text in the completion text.
69 let k = completion_graphemes[i..]
70 .iter()
71 .position(|c| *c == buffer_graphemes[j]);
72 match k {
73 Some(k) => {
74 if k != 0 {
75 let offset = snapshot.anchor_after(offset);
76 // the range from the current position to item is an inlay.
77 let edit = (
78 offset..offset,
79 completion_graphemes[i..i + k].join("").into(),
80 );
81 edits.push(edit);
82 }
83 i += k + 1;
84 j += 1;
85 offset.add_assign(buffer_graphemes[j - 1].len());
86 }
87 None => {
88 // there are no more matching completions, so drop the remaining
89 // completion text as an inlay.
90 break;
91 }
92 }
93 }
94
95 if j == buffer_graphemes.len() && i < completion_graphemes.len() {
96 let offset = snapshot.anchor_after(offset);
97 // there is leftover completion text, so drop it as an inlay.
98 let edit_range = offset..offset;
99 let edit_text = completion_graphemes[i..].join("");
100 edits.push((edit_range, edit_text.into()));
101 }
102
103 EditPrediction::Local {
104 id: None,
105 edits,
106 cursor_position: None,
107 edit_preview: None,
108 }
109}
110
111impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
112 fn name() -> &'static str {
113 "supermaven"
114 }
115
116 fn display_name() -> &'static str {
117 "Supermaven"
118 }
119
120 fn show_predictions_in_menu() -> bool {
121 true
122 }
123
124 fn show_tab_accept_marker() -> bool {
125 true
126 }
127
128 fn supports_jump_to_edit() -> bool {
129 false
130 }
131
132 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
133 EditPredictionIconSet::new(IconName::Supermaven)
134 .with_disabled(IconName::SupermavenDisabled)
135 .with_error(IconName::SupermavenError)
136 }
137
138 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
139 self.supermaven.read(cx).is_enabled()
140 }
141
142 fn is_refreshing(&self, _cx: &App) -> bool {
143 self.pending_refresh.is_some() && self.completion_id.is_none()
144 }
145
146 fn refresh(
147 &mut self,
148 buffer_handle: Entity<Buffer>,
149 cursor_position: Anchor,
150 debounce: bool,
151 cx: &mut Context<Self>,
152 ) {
153 // Only make new completion requests when debounce is true (i.e., when text is typed)
154 // When debounce is false (i.e., cursor movement), we should not make new requests
155 if !debounce {
156 return;
157 }
158
159 reset_completion_cache(self, cx);
160
161 let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
162 supermaven.complete(&buffer_handle, cursor_position, cx)
163 }) else {
164 return;
165 };
166
167 self.pending_refresh = Some(cx.spawn(async move |this, cx| {
168 if debounce {
169 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
170 }
171
172 while let Some(()) = completion.updates.next().await {
173 this.update(cx, |this, cx| {
174 // Get the completion text and cache it
175 if let Some(text) =
176 this.supermaven
177 .read(cx)
178 .completion(&buffer_handle, cursor_position, cx)
179 {
180 this.completion_text = Some(text.to_string());
181
182 this.completion_position = Some(cursor_position);
183 }
184
185 this.completion_id = Some(completion.id);
186 this.buffer_id = Some(buffer_handle.entity_id());
187 this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
188 Some(
189 Path::new(file.file_name(cx))
190 .extension()?
191 .to_str()?
192 .to_string(),
193 )
194 });
195 cx.notify();
196 })?;
197 }
198 Ok(())
199 }));
200 }
201
202 fn accept(&mut self, _cx: &mut Context<Self>) {
203 reset_completion_cache(self, _cx);
204 }
205
206 fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
207 reset_completion_cache(self, _cx);
208 }
209
210 fn suggest(
211 &mut self,
212 buffer: &Entity<Buffer>,
213 cursor_position: Anchor,
214 cx: &mut Context<Self>,
215 ) -> Option<EditPrediction> {
216 if self.buffer_id != Some(buffer.entity_id()) {
217 return None;
218 }
219
220 if self.completion_id.is_none() {
221 return None;
222 }
223
224 let completion_text = if let Some(cached_text) = &self.completion_text {
225 cached_text.as_str()
226 } else {
227 let text = self
228 .supermaven
229 .read(cx)
230 .completion(buffer, cursor_position, cx)?;
231 self.completion_text = Some(text.to_string());
232 text
233 };
234
235 // Check if the cursor is still at the same position as the completion request
236 // If we don't have a completion position stored, don't show the completion
237 if let Some(completion_position) = self.completion_position {
238 if cursor_position != completion_position {
239 return None;
240 }
241 } else {
242 return None;
243 }
244
245 let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
246
247 let completion_text = completion_text.trim_end();
248
249 if !completion_text.trim().is_empty() {
250 let snapshot = buffer.read(cx).snapshot();
251
252 // Calculate the range from cursor to end of line correctly
253 let cursor_point = cursor_position.to_point(&snapshot);
254 let end_of_line = snapshot.anchor_after(language::Point::new(
255 cursor_point.row,
256 snapshot.line_len(cursor_point.row),
257 ));
258 let delete_range = cursor_position..end_of_line;
259
260 Some(completion_from_diff(
261 snapshot,
262 completion_text,
263 cursor_position,
264 delete_range,
265 ))
266 } else {
267 None
268 }
269 }
270}
271
272fn reset_completion_cache(
273 provider: &mut SupermavenEditPredictionDelegate,
274 _cx: &mut Context<SupermavenEditPredictionDelegate>,
275) {
276 provider.pending_refresh = None;
277 provider.completion_id = None;
278 provider.completion_text = None;
279 provider.completion_position = None;
280 provider.buffer_id = None;
281}
282
/// Cuts `text` down to its first line, unless `text` starts with a blank
/// (whitespace-only) line. In that case the whole text is returned unchanged.
fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
    // The first character that is either a newline or not whitespace decides
    // whether the text opens with a blank line.
    let first_significant = text.chars().find(|&c| c == '\n' || !c.is_whitespace());
    if first_significant == Some('\n') {
        return text;
    }
    text.split_once('\n')
        .map_or(text, |(first_line, _rest)| first_line)
}
292
/// Reports whether a newline appears in `text` before any non-whitespace
/// character, i.e. whether `text` opens with a blank line.
fn has_leading_newline(text: &str) -> bool {
    text.trim_start_matches(|c: char| c != '\n' && c.is_whitespace())
        .starts_with('\n')
}