supermaven_edit_prediction_delegate.rs

  1use crate::{Supermaven, SupermavenCompletionStateId};
  2use anyhow::Result;
  3use edit_prediction_types::{EditPrediction, EditPredictionDelegate, EditPredictionIconSet};
  4use futures::StreamExt as _;
  5use gpui::{App, Context, Entity, EntityId, Task};
  6use language::{Anchor, Buffer, BufferSnapshot};
  7use std::{
  8    ops::{AddAssign, Range},
  9    path::Path,
 10    sync::Arc,
 11    time::Duration,
 12};
 13use text::{ToOffset, ToPoint};
 14use ui::prelude::*;
 15use unicode_segmentation::UnicodeSegmentation;
 16
 17pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 18
/// Edit-prediction delegate backed by Supermaven completions.
///
/// Holds at most one cached completion at a time: the latest completion text received for
/// the most recent request started by `refresh`, together with the buffer and cursor
/// position that request was made for.
pub struct SupermavenEditPredictionDelegate {
    /// Handle used to start completion requests (`complete`) and read their text (`completion`).
    supermaven: Entity<Supermaven>,
    /// Entity id of the buffer the cached completion belongs to; `suggest` ignores other buffers.
    buffer_id: Option<EntityId>,
    /// Id of the completion whose updates are cached; stays `None` after a reset until the
    /// first update for the new request arrives.
    completion_id: Option<SupermavenCompletionStateId>,
    /// Cached completion text.
    completion_text: Option<String>,
    /// Extension of the buffer's file name, recorded on each completion update.
    file_extension: Option<String>,
    /// Task consuming completion updates for the current request; replaced by each
    /// `refresh` and cleared by `reset_completion_cache`.
    pending_refresh: Option<Task<Result<()>>>,
    /// Cursor position the cached completion was requested at; `suggest` only shows the
    /// completion while the cursor is still exactly here.
    completion_position: Option<language::Anchor>,
}
 28
 29impl SupermavenEditPredictionDelegate {
 30    pub fn new(supermaven: Entity<Supermaven>) -> Self {
 31        Self {
 32            supermaven,
 33            buffer_id: None,
 34            completion_id: None,
 35            completion_text: None,
 36            file_extension: None,
 37            pending_refresh: None,
 38            completion_position: None,
 39        }
 40    }
 41}
 42
 43// Computes the edit prediction from the difference between the completion text.
 44// This is defined by greedily matching the buffer text against the completion text.
 45// Inlays are inserted for parts of the completion text that are not present in the buffer text.
 46// For example, given the completion text "axbyc" and the buffer text "xy", the rendered output in the editor would be "[a]x[b]y[c]".
 47// The parts in brackets are the inlays.
 48fn completion_from_diff(
 49    snapshot: BufferSnapshot,
 50    completion_text: &str,
 51    position: Anchor,
 52    delete_range: Range<Anchor>,
 53) -> EditPrediction {
 54    let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
 55
 56    let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
 57
 58    let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
 59    let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
 60
 61    let mut offset = position.to_offset(&snapshot);
 62
 63    let mut i = 0;
 64    let mut j = 0;
 65    while i < completion_graphemes.len() && j < buffer_graphemes.len() {
 66        // find the next instance of the buffer text in the completion text.
 67        let k = completion_graphemes[i..]
 68            .iter()
 69            .position(|c| *c == buffer_graphemes[j]);
 70        match k {
 71            Some(k) => {
 72                if k != 0 {
 73                    let offset = snapshot.anchor_after(offset);
 74                    // the range from the current position to item is an inlay.
 75                    let edit = (
 76                        offset..offset,
 77                        completion_graphemes[i..i + k].join("").into(),
 78                    );
 79                    edits.push(edit);
 80                }
 81                i += k + 1;
 82                j += 1;
 83                offset.add_assign(buffer_graphemes[j - 1].len());
 84            }
 85            None => {
 86                // there are no more matching completions, so drop the remaining
 87                // completion text as an inlay.
 88                break;
 89            }
 90        }
 91    }
 92
 93    if j == buffer_graphemes.len() && i < completion_graphemes.len() {
 94        let offset = snapshot.anchor_after(offset);
 95        // there is leftover completion text, so drop it as an inlay.
 96        let edit_range = offset..offset;
 97        let edit_text = completion_graphemes[i..].join("");
 98        edits.push((edit_range, edit_text.into()));
 99    }
100
101    EditPrediction::Local {
102        id: None,
103        edits,
104        cursor_position: None,
105        edit_preview: None,
106    }
107}
108
109impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
110    fn name() -> &'static str {
111        "supermaven"
112    }
113
114    fn display_name() -> &'static str {
115        "Supermaven"
116    }
117
118    fn show_predictions_in_menu() -> bool {
119        true
120    }
121
122    fn show_tab_accept_marker() -> bool {
123        true
124    }
125
126    fn supports_jump_to_edit() -> bool {
127        false
128    }
129
130    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
131        EditPredictionIconSet::new(IconName::Supermaven)
132            .with_disabled(IconName::SupermavenDisabled)
133            .with_error(IconName::SupermavenError)
134    }
135
136    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
137        self.supermaven.read(cx).is_enabled()
138    }
139
140    fn is_refreshing(&self, _cx: &App) -> bool {
141        self.pending_refresh.is_some() && self.completion_id.is_none()
142    }
143
144    fn refresh(
145        &mut self,
146        buffer_handle: Entity<Buffer>,
147        cursor_position: Anchor,
148        debounce: bool,
149        cx: &mut Context<Self>,
150    ) {
151        // Only make new completion requests when debounce is true (i.e., when text is typed)
152        // When debounce is false (i.e., cursor movement), we should not make new requests
153        if !debounce {
154            return;
155        }
156
157        reset_completion_cache(self, cx);
158
159        let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
160            supermaven.complete(&buffer_handle, cursor_position, cx)
161        }) else {
162            return;
163        };
164
165        self.pending_refresh = Some(cx.spawn(async move |this, cx| {
166            if debounce {
167                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
168            }
169
170            while let Some(()) = completion.updates.next().await {
171                this.update(cx, |this, cx| {
172                    // Get the completion text and cache it
173                    if let Some(text) =
174                        this.supermaven
175                            .read(cx)
176                            .completion(&buffer_handle, cursor_position, cx)
177                    {
178                        this.completion_text = Some(text.to_string());
179
180                        this.completion_position = Some(cursor_position);
181                    }
182
183                    this.completion_id = Some(completion.id);
184                    this.buffer_id = Some(buffer_handle.entity_id());
185                    this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
186                        Some(
187                            Path::new(file.file_name(cx))
188                                .extension()?
189                                .to_str()?
190                                .to_string(),
191                        )
192                    });
193                    cx.notify();
194                })?;
195            }
196            Ok(())
197        }));
198    }
199
200    fn accept(&mut self, _cx: &mut Context<Self>) {
201        reset_completion_cache(self, _cx);
202    }
203
204    fn discard(&mut self, _cx: &mut Context<Self>) {
205        reset_completion_cache(self, _cx);
206    }
207
208    fn suggest(
209        &mut self,
210        buffer: &Entity<Buffer>,
211        cursor_position: Anchor,
212        cx: &mut Context<Self>,
213    ) -> Option<EditPrediction> {
214        if self.buffer_id != Some(buffer.entity_id()) {
215            return None;
216        }
217
218        if self.completion_id.is_none() {
219            return None;
220        }
221
222        let completion_text = if let Some(cached_text) = &self.completion_text {
223            cached_text.as_str()
224        } else {
225            let text = self
226                .supermaven
227                .read(cx)
228                .completion(buffer, cursor_position, cx)?;
229            self.completion_text = Some(text.to_string());
230            text
231        };
232
233        // Check if the cursor is still at the same position as the completion request
234        // If we don't have a completion position stored, don't show the completion
235        if let Some(completion_position) = self.completion_position {
236            if cursor_position != completion_position {
237                return None;
238            }
239        } else {
240            return None;
241        }
242
243        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
244
245        let completion_text = completion_text.trim_end();
246
247        if !completion_text.trim().is_empty() {
248            let snapshot = buffer.read(cx).snapshot();
249
250            // Calculate the range from cursor to end of line correctly
251            let cursor_point = cursor_position.to_point(&snapshot);
252            let end_of_line = snapshot.anchor_after(language::Point::new(
253                cursor_point.row,
254                snapshot.line_len(cursor_point.row),
255            ));
256            let delete_range = cursor_position..end_of_line;
257
258            Some(completion_from_diff(
259                snapshot,
260                completion_text,
261                cursor_position,
262                delete_range,
263            ))
264        } else {
265            None
266        }
267    }
268}
269
270fn reset_completion_cache(
271    provider: &mut SupermavenEditPredictionDelegate,
272    _cx: &mut Context<SupermavenEditPredictionDelegate>,
273) {
274    provider.pending_refresh = None;
275    provider.completion_id = None;
276    provider.completion_text = None;
277    provider.completion_position = None;
278    provider.buffer_id = None;
279}
280
/// Cuts `text` just before its first `'\n'`, unless the text begins with a line break.
///
/// "Begins with a line break" means a `'\n'` appears before any non-whitespace character.
/// In that case `text` is returned unchanged, as it also is when it contains no `'\n'`.
fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
    if has_leading_newline(text) {
        return text;
    }
    match text.split_once('\n') {
        Some((first_line, _rest)) => first_line,
        None => text,
    }
}

/// Reports whether the leading run of whitespace in `text` contains a `'\n'`.
///
/// Empty text, and text whose leading whitespace has no `'\n'`, yield `false`.
fn has_leading_newline(text: &str) -> bool {
    text.chars()
        .take_while(|ch| ch.is_whitespace())
        .any(|ch| ch == '\n')
}