supermaven_completion_provider.rs

  1use crate::{Supermaven, SupermavenCompletionStateId};
  2use anyhow::Result;
  3use client::telemetry::Telemetry;
  4use editor::{Direction, InlineCompletionProvider};
  5use futures::StreamExt as _;
  6use gpui::{AppContext, EntityId, Model, ModelContext, Task};
  7use language::{language_settings::all_language_settings, Anchor, Buffer};
  8use std::{path::Path, sync::Arc, time::Duration};
  9
 10pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 11
 12pub struct SupermavenCompletionProvider {
 13    supermaven: Model<Supermaven>,
 14    buffer_id: Option<EntityId>,
 15    completion_id: Option<SupermavenCompletionStateId>,
 16    file_extension: Option<String>,
 17    pending_refresh: Task<Result<()>>,
 18    telemetry: Option<Arc<Telemetry>>,
 19}
 20
 21impl SupermavenCompletionProvider {
 22    pub fn new(supermaven: Model<Supermaven>) -> Self {
 23        Self {
 24            supermaven,
 25            buffer_id: None,
 26            completion_id: None,
 27            file_extension: None,
 28            pending_refresh: Task::ready(Ok(())),
 29            telemetry: None,
 30        }
 31    }
 32
 33    pub fn with_telemetry(mut self, telemetry: Arc<Telemetry>) -> Self {
 34        self.telemetry = Some(telemetry);
 35        self
 36    }
 37}
 38
 39impl InlineCompletionProvider for SupermavenCompletionProvider {
 40    fn name() -> &'static str {
 41        "supermaven"
 42    }
 43
 44    fn is_enabled(&self, buffer: &Model<Buffer>, cursor_position: Anchor, cx: &AppContext) -> bool {
 45        if !self.supermaven.read(cx).is_enabled() {
 46            return false;
 47        }
 48
 49        let buffer = buffer.read(cx);
 50        let file = buffer.file();
 51        let language = buffer.language_at(cursor_position);
 52        let settings = all_language_settings(file, cx);
 53        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()))
 54    }
 55
 56    fn refresh(
 57        &mut self,
 58        buffer_handle: Model<Buffer>,
 59        cursor_position: Anchor,
 60        debounce: bool,
 61        cx: &mut ModelContext<Self>,
 62    ) {
 63        let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| {
 64            supermaven.complete(&buffer_handle, cursor_position, cx)
 65        }) else {
 66            return;
 67        };
 68
 69        self.pending_refresh = cx.spawn(|this, mut cx| async move {
 70            if debounce {
 71                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
 72            }
 73
 74            while let Some(()) = completion.updates.next().await {
 75                this.update(&mut cx, |this, cx| {
 76                    this.completion_id = Some(completion.id);
 77                    this.buffer_id = Some(buffer_handle.entity_id());
 78                    this.file_extension = buffer_handle.read(cx).file().and_then(|file| {
 79                        Some(
 80                            Path::new(file.file_name(cx))
 81                                .extension()?
 82                                .to_str()?
 83                                .to_string(),
 84                        )
 85                    });
 86                    cx.notify();
 87                })?;
 88            }
 89            Ok(())
 90        });
 91    }
 92
 93    fn cycle(
 94        &mut self,
 95        _buffer: Model<Buffer>,
 96        _cursor_position: Anchor,
 97        _direction: Direction,
 98        _cx: &mut ModelContext<Self>,
 99    ) {
100    }
101
102    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
103        if self.completion_id.is_some() {
104            if let Some(telemetry) = self.telemetry.as_ref() {
105                telemetry.report_inline_completion_event(
106                    Self::name().to_string(),
107                    true,
108                    self.file_extension.clone(),
109                );
110            }
111        }
112        self.pending_refresh = Task::ready(Ok(()));
113        self.completion_id = None;
114    }
115
116    fn discard(
117        &mut self,
118        should_report_inline_completion_event: bool,
119        _cx: &mut ModelContext<Self>,
120    ) {
121        if should_report_inline_completion_event {
122            if self.completion_id.is_some() {
123                if let Some(telemetry) = self.telemetry.as_ref() {
124                    telemetry.report_inline_completion_event(
125                        Self::name().to_string(),
126                        false,
127                        self.file_extension.clone(),
128                    );
129                }
130            }
131        }
132
133        self.pending_refresh = Task::ready(Ok(()));
134        self.completion_id = None;
135    }
136
137    fn active_completion_text<'a>(
138        &'a self,
139        buffer: &Model<Buffer>,
140        cursor_position: Anchor,
141        cx: &'a AppContext,
142    ) -> Option<&'a str> {
143        let completion_text = self
144            .supermaven
145            .read(cx)
146            .completion(buffer, cursor_position, cx)?;
147
148        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
149
150        let completion_text = completion_text.trim_end();
151
152        if !completion_text.trim().is_empty() {
153            Some(completion_text)
154        } else {
155            None
156        }
157    }
158}
159
160fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
161    if has_leading_newline(&text) {
162        text
163    } else if let Some(i) = text.find('\n') {
164        &text[..i]
165    } else {
166        text
167    }
168}
169
/// Reports whether a `'\n'` occurs in `text` before any non-whitespace
/// character, i.e. the text opens with optional whitespace and then a line break.
fn has_leading_newline(text: &str) -> bool {
    // The first character that is a newline or non-whitespace decides the result.
    // If no such character exists, the text is empty or is whitespace without '\n'.
    text.chars().find(|&ch| ch == '\n' || !ch.is_whitespace()) == Some('\n')
}