supermaven_edit_prediction_delegate.rs

  1use crate::{Supermaven, SupermavenCompletionStateId};
  2use anyhow::Result;
  3use edit_prediction_types::{
  4    EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
  5};
  6use futures::StreamExt as _;
  7use gpui::{App, Context, Entity, EntityId, Task};
  8use language::{Anchor, Buffer, BufferSnapshot};
  9use std::{
 10    ops::{AddAssign, Range},
 11    path::Path,
 12    sync::Arc,
 13    time::Duration,
 14};
 15use text::{ToOffset, ToPoint};
 16use ui::prelude::*;
 17use unicode_segmentation::UnicodeSegmentation;
 18
 19pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 20
 21pub struct SupermavenEditPredictionDelegate {
 22    supermaven: Entity<Supermaven>,
 23    buffer_id: Option<EntityId>,
 24    completion_id: Option<SupermavenCompletionStateId>,
 25    completion_text: Option<String>,
 26    file_extension: Option<String>,
 27    pending_refresh: Option<Task<Result<()>>>,
 28    completion_position: Option<language::Anchor>,
 29}
 30
 31impl SupermavenEditPredictionDelegate {
 32    pub fn new(supermaven: Entity<Supermaven>) -> Self {
 33        Self {
 34            supermaven,
 35            buffer_id: None,
 36            completion_id: None,
 37            completion_text: None,
 38            file_extension: None,
 39            pending_refresh: None,
 40            completion_position: None,
 41        }
 42    }
 43}
 44
 45// Computes the edit prediction from the difference between the completion text.
 46// This is defined by greedily matching the buffer text against the completion text.
 47// Inlays are inserted for parts of the completion text that are not present in the buffer text.
 48// For example, given the completion text "axbyc" and the buffer text "xy", the rendered output in the editor would be "[a]x[b]y[c]".
 49// The parts in brackets are the inlays.
 50fn completion_from_diff(
 51    snapshot: BufferSnapshot,
 52    completion_text: &str,
 53    position: Anchor,
 54    delete_range: Range<Anchor>,
 55) -> EditPrediction {
 56    let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
 57
 58    let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
 59
 60    let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
 61    let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
 62
 63    let mut offset = position.to_offset(&snapshot);
 64
 65    let mut i = 0;
 66    let mut j = 0;
 67    while i < completion_graphemes.len() && j < buffer_graphemes.len() {
 68        // find the next instance of the buffer text in the completion text.
 69        let k = completion_graphemes[i..]
 70            .iter()
 71            .position(|c| *c == buffer_graphemes[j]);
 72        match k {
 73            Some(k) => {
 74                if k != 0 {
 75                    let offset = snapshot.anchor_after(offset);
 76                    // the range from the current position to item is an inlay.
 77                    let edit = (
 78                        offset..offset,
 79                        completion_graphemes[i..i + k].join("").into(),
 80                    );
 81                    edits.push(edit);
 82                }
 83                i += k + 1;
 84                j += 1;
 85                offset.add_assign(buffer_graphemes[j - 1].len());
 86            }
 87            None => {
 88                // there are no more matching completions, so drop the remaining
 89                // completion text as an inlay.
 90                break;
 91            }
 92        }
 93    }
 94
 95    if j == buffer_graphemes.len() && i < completion_graphemes.len() {
 96        let offset = snapshot.anchor_after(offset);
 97        // there is leftover completion text, so drop it as an inlay.
 98        let edit_range = offset..offset;
 99        let edit_text = completion_graphemes[i..].join("");
100        edits.push((edit_range, edit_text.into()));
101    }
102
103    EditPrediction::Local {
104        id: None,
105        edits,
106        cursor_position: None,
107        edit_preview: None,
108    }
109}
110
111impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
112    fn name() -> &'static str {
113        "supermaven"
114    }
115
116    fn display_name() -> &'static str {
117        "Supermaven"
118    }
119
120    fn show_predictions_in_menu() -> bool {
121        true
122    }
123
124    fn show_tab_accept_marker() -> bool {
125        true
126    }
127
128    fn supports_jump_to_edit() -> bool {
129        false
130    }
131
132    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
133        EditPredictionIconSet::new(IconName::Supermaven)
134            .with_disabled(IconName::SupermavenDisabled)
135            .with_error(IconName::SupermavenError)
136    }
137
138    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
139        self.supermaven.read(cx).is_enabled()
140    }
141
142    fn is_refreshing(&self, _cx: &App) -> bool {
143        self.pending_refresh.is_some() && self.completion_id.is_none()
144    }
145
146    fn refresh(
147        &mut self,
148        buffer_handle: Entity<Buffer>,
149        cursor_position: Anchor,
150        debounce: bool,
151        cx: &mut Context<Self>,
152    ) {
153        // Only make new completion requests when debounce is true (i.e., when text is typed)
154        // When debounce is false (i.e., cursor movement), we should not make new requests
155        if !debounce {
156            return;
157        }
158
159        reset_completion_cache(self, cx);
160
161        let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
162            supermaven.complete(&buffer_handle, cursor_position, cx)
163        }) else {
164            return;
165        };
166
167        self.pending_refresh = Some(cx.spawn(async move |this, cx| {
168            if debounce {
169                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
170            }
171
172            while let Some(()) = completion.updates.next().await {
173                this.update(cx, |this, cx| {
174                    // Get the completion text and cache it
175                    if let Some(text) =
176                        this.supermaven
177                            .read(cx)
178                            .completion(&buffer_handle, cursor_position, cx)
179                    {
180                        this.completion_text = Some(text.to_string());
181
182                        this.completion_position = Some(cursor_position);
183                    }
184
185                    this.completion_id = Some(completion.id);
186                    this.buffer_id = Some(buffer_handle.entity_id());
187                    this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
188                        Some(
189                            Path::new(file.file_name(cx))
190                                .extension()?
191                                .to_str()?
192                                .to_string(),
193                        )
194                    });
195                    cx.notify();
196                })?;
197            }
198            Ok(())
199        }));
200    }
201
202    fn accept(&mut self, _cx: &mut Context<Self>) {
203        reset_completion_cache(self, _cx);
204    }
205
206    fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
207        reset_completion_cache(self, _cx);
208    }
209
210    fn suggest(
211        &mut self,
212        buffer: &Entity<Buffer>,
213        cursor_position: Anchor,
214        cx: &mut Context<Self>,
215    ) -> Option<EditPrediction> {
216        if self.buffer_id != Some(buffer.entity_id()) {
217            return None;
218        }
219
220        if self.completion_id.is_none() {
221            return None;
222        }
223
224        let completion_text = if let Some(cached_text) = &self.completion_text {
225            cached_text.as_str()
226        } else {
227            let text = self
228                .supermaven
229                .read(cx)
230                .completion(buffer, cursor_position, cx)?;
231            self.completion_text = Some(text.to_string());
232            text
233        };
234
235        // Check if the cursor is still at the same position as the completion request
236        // If we don't have a completion position stored, don't show the completion
237        if let Some(completion_position) = self.completion_position {
238            if cursor_position != completion_position {
239                return None;
240            }
241        } else {
242            return None;
243        }
244
245        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
246
247        let completion_text = completion_text.trim_end();
248
249        if !completion_text.trim().is_empty() {
250            let snapshot = buffer.read(cx).snapshot();
251
252            // Calculate the range from cursor to end of line correctly
253            let cursor_point = cursor_position.to_point(&snapshot);
254            let end_of_line = snapshot.anchor_after(language::Point::new(
255                cursor_point.row,
256                snapshot.line_len(cursor_point.row),
257            ));
258            let delete_range = cursor_position..end_of_line;
259
260            Some(completion_from_diff(
261                snapshot,
262                completion_text,
263                cursor_position,
264                delete_range,
265            ))
266        } else {
267            None
268        }
269    }
270}
271
272fn reset_completion_cache(
273    provider: &mut SupermavenEditPredictionDelegate,
274    _cx: &mut Context<SupermavenEditPredictionDelegate>,
275) {
276    provider.pending_refresh = None;
277    provider.completion_id = None;
278    provider.completion_text = None;
279    provider.completion_position = None;
280    provider.buffer_id = None;
281}
282
/// Cuts `text` just before its first `'\n'`, unless [`has_leading_newline`] holds, in which
/// case `text` is returned whole. Text containing no `'\n'` is returned unchanged.
fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
    if has_leading_newline(text) {
        return text;
    }
    text.split_once('\n')
        .map_or(text, |(first_line, _rest)| first_line)
}

/// Reports whether a `'\n'` appears in `text` before any non-whitespace character.
///
/// Only whitespace other than `'\n'` may precede the newline. Empty text, and text whose
/// first non-whitespace character comes before any newline, yield `false`.
fn has_leading_newline(text: &str) -> bool {
    // `'\n'` is itself whitespace, so it has to be tested for explicitly alongside the
    // non-whitespace check.
    matches!(
        text.chars().find(|&ch| ch == '\n' || !ch.is_whitespace()),
        Some('\n')
    )
}