supermaven_completion_provider.rs

  1use crate::{Supermaven, SupermavenCompletionStateId};
  2use anyhow::Result;
  3use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
  4use futures::StreamExt as _;
  5use gpui::{App, Context, Entity, EntityId, Task};
  6use language::{Anchor, Buffer, BufferSnapshot};
  7use std::{
  8    ops::{AddAssign, Range},
  9    path::Path,
 10    time::Duration,
 11};
 12use text::{ToOffset, ToPoint};
 13use unicode_segmentation::UnicodeSegmentation;
 14
 /// Delay that a debounced `refresh` waits after issuing a Supermaven request
/// before it starts consuming the streamed completion updates.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 16
 17pub struct SupermavenCompletionProvider {
 18    supermaven: Entity<Supermaven>,
 19    buffer_id: Option<EntityId>,
 20    completion_id: Option<SupermavenCompletionStateId>,
 21    completion_text: Option<String>,
 22    file_extension: Option<String>,
 23    pending_refresh: Option<Task<Result<()>>>,
 24    completion_position: Option<language::Anchor>,
 25}
 26
 27impl SupermavenCompletionProvider {
 28    pub const fn new(supermaven: Entity<Supermaven>) -> Self {
 29        Self {
 30            supermaven,
 31            buffer_id: None,
 32            completion_id: None,
 33            completion_text: None,
 34            file_extension: None,
 35            pending_refresh: None,
 36            completion_position: None,
 37        }
 38    }
 39}
 40
 41// Computes the edit prediction from the difference between the completion text.
 42// This is defined by greedily matching the buffer text against the completion text.
 43// Inlays are inserted for parts of the completion text that are not present in the buffer text.
 44// For example, given the completion text "axbyc" and the buffer text "xy", the rendered output in the editor would be "[a]x[b]y[c]".
 45// The parts in brackets are the inlays.
 46fn completion_from_diff(
 47    snapshot: BufferSnapshot,
 48    completion_text: &str,
 49    position: Anchor,
 50    delete_range: Range<Anchor>,
 51) -> EditPrediction {
 52    let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
 53
 54    let mut edits: Vec<(Range<language::Anchor>, String)> = Vec::new();
 55
 56    let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
 57    let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
 58
 59    let mut offset = position.to_offset(&snapshot);
 60
 61    let mut i = 0;
 62    let mut j = 0;
 63    while i < completion_graphemes.len() && j < buffer_graphemes.len() {
 64        // find the next instance of the buffer text in the completion text.
 65        let k = completion_graphemes[i..]
 66            .iter()
 67            .position(|c| *c == buffer_graphemes[j]);
 68        match k {
 69            Some(k) => {
 70                if k != 0 {
 71                    let offset = snapshot.anchor_after(offset);
 72                    // the range from the current position to item is an inlay.
 73                    let edit = (offset..offset, completion_graphemes[i..i + k].join(""));
 74                    edits.push(edit);
 75                }
 76                i += k + 1;
 77                j += 1;
 78                offset.add_assign(buffer_graphemes[j - 1].len());
 79            }
 80            None => {
 81                // there are no more matching completions, so drop the remaining
 82                // completion text as an inlay.
 83                break;
 84            }
 85        }
 86    }
 87
 88    if j == buffer_graphemes.len() && i < completion_graphemes.len() {
 89        let offset = snapshot.anchor_after(offset);
 90        // there is leftover completion text, so drop it as an inlay.
 91        let edit_range = offset..offset;
 92        let edit_text = completion_graphemes[i..].join("");
 93        edits.push((edit_range, edit_text));
 94    }
 95
 96    EditPrediction::Local {
 97        id: None,
 98        edits,
 99        edit_preview: None,
100    }
101}
102
103impl EditPredictionProvider for SupermavenCompletionProvider {
104    fn name() -> &'static str {
105        "supermaven"
106    }
107
108    fn display_name() -> &'static str {
109        "Supermaven"
110    }
111
112    fn show_completions_in_menu() -> bool {
113        true
114    }
115
116    fn show_tab_accept_marker() -> bool {
117        true
118    }
119
120    fn supports_jump_to_edit() -> bool {
121        false
122    }
123
124    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
125        self.supermaven.read(cx).is_enabled()
126    }
127
128    fn is_refreshing(&self) -> bool {
129        self.pending_refresh.is_some() && self.completion_id.is_none()
130    }
131
132    fn refresh(
133        &mut self,
134        buffer_handle: Entity<Buffer>,
135        cursor_position: Anchor,
136        debounce: bool,
137        cx: &mut Context<Self>,
138    ) {
139        // Only make new completion requests when debounce is true (i.e., when text is typed)
140        // When debounce is false (i.e., cursor movement), we should not make new requests
141        if !debounce {
142            return;
143        }
144
145        reset_completion_cache(self, cx);
146
147        let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
148            supermaven.complete(&buffer_handle, cursor_position, cx)
149        }) else {
150            return;
151        };
152
153        self.pending_refresh = Some(cx.spawn(async move |this, cx| {
154            if debounce {
155                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
156            }
157
158            while let Some(()) = completion.updates.next().await {
159                this.update(cx, |this, cx| {
160                    // Get the completion text and cache it
161                    if let Some(text) =
162                        this.supermaven
163                            .read(cx)
164                            .completion(&buffer_handle, cursor_position, cx)
165                    {
166                        this.completion_text = Some(text.to_string());
167
168                        this.completion_position = Some(cursor_position);
169                    }
170
171                    this.completion_id = Some(completion.id);
172                    this.buffer_id = Some(buffer_handle.entity_id());
173                    this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
174                        Some(
175                            Path::new(file.file_name(cx))
176                                .extension()?
177                                .to_str()?
178                                .to_string(),
179                        )
180                    });
181                    cx.notify();
182                })?;
183            }
184            Ok(())
185        }));
186    }
187
188    fn cycle(
189        &mut self,
190        _buffer: Entity<Buffer>,
191        _cursor_position: Anchor,
192        _direction: Direction,
193        _cx: &mut Context<Self>,
194    ) {
195    }
196
197    fn accept(&mut self, _cx: &mut Context<Self>) {
198        reset_completion_cache(self, _cx);
199    }
200
201    fn discard(&mut self, _cx: &mut Context<Self>) {
202        reset_completion_cache(self, _cx);
203    }
204
205    fn suggest(
206        &mut self,
207        buffer: &Entity<Buffer>,
208        cursor_position: Anchor,
209        cx: &mut Context<Self>,
210    ) -> Option<EditPrediction> {
211        if self.buffer_id != Some(buffer.entity_id()) {
212            return None;
213        }
214
215        if self.completion_id.is_none() {
216            return None;
217        }
218
219        let completion_text = if let Some(cached_text) = &self.completion_text {
220            cached_text.as_str()
221        } else {
222            let text = self
223                .supermaven
224                .read(cx)
225                .completion(buffer, cursor_position, cx)?;
226            self.completion_text = Some(text.to_string());
227            text
228        };
229
230        // Check if the cursor is still at the same position as the completion request
231        // If we don't have a completion position stored, don't show the completion
232        if let Some(completion_position) = self.completion_position {
233            if cursor_position != completion_position {
234                return None;
235            }
236        } else {
237            return None;
238        }
239
240        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
241
242        let completion_text = completion_text.trim_end();
243
244        if !completion_text.trim().is_empty() {
245            let snapshot = buffer.read(cx).snapshot();
246
247            // Calculate the range from cursor to end of line correctly
248            let cursor_point = cursor_position.to_point(&snapshot);
249            let end_of_line = snapshot.anchor_after(language::Point::new(
250                cursor_point.row,
251                snapshot.line_len(cursor_point.row),
252            ));
253            let delete_range = cursor_position..end_of_line;
254
255            Some(completion_from_diff(
256                snapshot,
257                completion_text,
258                cursor_position,
259                delete_range,
260            ))
261        } else {
262            None
263        }
264    }
265}
266
267fn reset_completion_cache(
268    provider: &mut SupermavenCompletionProvider,
269    _cx: &mut Context<SupermavenCompletionProvider>,
270) {
271    provider.pending_refresh = None;
272    provider.completion_id = None;
273    provider.completion_text = None;
274    provider.completion_position = None;
275    provider.buffer_id = None;
276}
277
/// Returns `text` up to (not including) its first `'\n'`.
///
/// If `text` opens with a whitespace run containing a newline (see
/// `has_leading_newline`), it is returned whole. Text without any `'\n'` is
/// returned unchanged.
fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
    if has_leading_newline(text) {
        return text;
    }
    match text.split_once('\n') {
        Some((first_line, _rest)) => first_line,
        None => text,
    }
}

/// Reports whether a `'\n'` occurs in `text` before its first non-whitespace
/// character. Empty text, or whitespace without any newline, yields `false`.
fn has_leading_newline(text: &str) -> bool {
    text.chars()
        .take_while(|c| c.is_whitespace())
        .any(|c| c == '\n')
}
297    false
298}