supermaven_completion_provider.rs

  1use crate::{Supermaven, SupermavenCompletionStateId};
  2use anyhow::Result;
  3use client::telemetry::Telemetry;
  4use editor::{Direction, InlineCompletionProvider};
  5use futures::StreamExt as _;
  6use gpui::{AppContext, EntityId, Model, ModelContext, Task};
  7use language::{language_settings::all_language_settings, Anchor, Buffer};
  8use std::{ops::Range, path::Path, sync::Arc, time::Duration};
  9use text::ToPoint;
 10
 11pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 12
 13pub struct SupermavenCompletionProvider {
 14    supermaven: Model<Supermaven>,
 15    buffer_id: Option<EntityId>,
 16    completion_id: Option<SupermavenCompletionStateId>,
 17    file_extension: Option<String>,
 18    pending_refresh: Task<Result<()>>,
 19    telemetry: Option<Arc<Telemetry>>,
 20}
 21
 22impl SupermavenCompletionProvider {
 23    pub fn new(supermaven: Model<Supermaven>) -> Self {
 24        Self {
 25            supermaven,
 26            buffer_id: None,
 27            completion_id: None,
 28            file_extension: None,
 29            pending_refresh: Task::ready(Ok(())),
 30            telemetry: None,
 31        }
 32    }
 33
 34    pub fn with_telemetry(mut self, telemetry: Arc<Telemetry>) -> Self {
 35        self.telemetry = Some(telemetry);
 36        self
 37    }
 38}
 39
 40impl InlineCompletionProvider for SupermavenCompletionProvider {
 41    fn name() -> &'static str {
 42        "supermaven"
 43    }
 44
 45    fn is_enabled(&self, buffer: &Model<Buffer>, cursor_position: Anchor, cx: &AppContext) -> bool {
 46        if !self.supermaven.read(cx).is_enabled() {
 47            return false;
 48        }
 49
 50        let buffer = buffer.read(cx);
 51        let file = buffer.file();
 52        let language = buffer.language_at(cursor_position);
 53        let settings = all_language_settings(file, cx);
 54        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
 55    }
 56
 57    fn refresh(
 58        &mut self,
 59        buffer_handle: Model<Buffer>,
 60        cursor_position: Anchor,
 61        debounce: bool,
 62        cx: &mut ModelContext<Self>,
 63    ) {
 64        let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
 65            supermaven.complete(&buffer_handle, cursor_position, cx)
 66        }) else {
 67            return;
 68        };
 69
 70        self.pending_refresh = cx.spawn(|this, mut cx| async move {
 71            if debounce {
 72                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
 73            }
 74
 75            while let Some(()) = completion.updates.next().await {
 76                this.update(&mut cx, |this, cx| {
 77                    this.completion_id = Some(completion.id);
 78                    this.buffer_id = Some(buffer_handle.entity_id());
 79                    this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
 80                        Some(
 81                            Path::new(file.file_name(cx))
 82                                .extension()?
 83                                .to_str()?
 84                                .to_string(),
 85                        )
 86                    });
 87                    cx.notify();
 88                })?;
 89            }
 90            Ok(())
 91        });
 92    }
 93
 94    fn cycle(
 95        &mut self,
 96        _buffer: Model<Buffer>,
 97        _cursor_position: Anchor,
 98        _direction: Direction,
 99        _cx: &mut ModelContext<Self>,
100    ) {
101    }
102
103    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
104        if self.completion_id.is_some() {
105            if let Some(telemetry) = self.telemetry.as_ref() {
106                telemetry.report_inline_completion_event(
107                    Self::name().to_string(),
108                    true,
109                    self.file_extension.clone(),
110                );
111            }
112        }
113        self.pending_refresh = Task::ready(Ok(()));
114        self.completion_id = None;
115    }
116
117    fn discard(
118        &mut self,
119        should_report_inline_completion_event: bool,
120        _cx: &mut ModelContext<Self>,
121    ) {
122        if should_report_inline_completion_event {
123            if self.completion_id.is_some() {
124                if let Some(telemetry) = self.telemetry.as_ref() {
125                    telemetry.report_inline_completion_event(
126                        Self::name().to_string(),
127                        false,
128                        self.file_extension.clone(),
129                    );
130                }
131            }
132        }
133
134        self.pending_refresh = Task::ready(Ok(()));
135        self.completion_id = None;
136    }
137
138    fn active_completion_text<'a>(
139        &'a self,
140        buffer: &Model<Buffer>,
141        cursor_position: Anchor,
142        cx: &'a AppContext,
143    ) -> Option<(&'a str, Option<Range<Anchor>>)> {
144        let completion_text = self
145            .supermaven
146            .read(cx)
147            .completion(buffer, cursor_position, cx)?;
148
149        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
150
151        let completion_text = completion_text.trim_end();
152
153        if !completion_text.trim().is_empty() {
154            let snapshot = buffer.read(cx).snapshot();
155            let mut point = cursor_position.to_point(&snapshot);
156            point.column = snapshot.line_len(point.row);
157            let range = cursor_position..snapshot.anchor_after(point);
158            Some((completion_text, Some(range)))
159        } else {
160            None
161        }
162    }
163}
164
165fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
166    if has_leading_newline(&text) {
167        text
168    } else if let Some(i) = text.find('\n') {
169        &text[..i]
170    } else {
171        text
172    }
173}
174
/// Returns `true` when a `'\n'` occurs in `text` before any non-whitespace
/// character, i.e. the text starts with a newline once leading whitespace is
/// ignored. Empty and whitespace-only text without a newline yields `false`.
fn has_leading_newline(text: &str) -> bool {
    // The first "decisive" character is either a newline or the first
    // non-whitespace character; the answer is whether that one is the newline.
    let first_decisive = text
        .chars()
        .find(|&ch| ch == '\n' || !ch.is_whitespace());
    first_decisive == Some('\n')
}