Supermaven enhanced (#11521)

Kyle Kelley , max , and jacob created

Fixes #11422 by accepting just the start of the line.

Release Notes:

- N/A

---------

Co-authored-by: max <max@zed.dev>
Co-authored-by: jacob <jacob@supermaven.com>

Change summary

crates/copilot/src/copilot_completion_provider.rs       |  8 
crates/editor/src/editor.rs                             |  1 
crates/editor/src/inline_completion_provider.rs         |  2 
crates/supermaven/src/supermaven.rs                     | 75 +++++++++-
crates/supermaven/src/supermaven_completion_provider.rs | 62 ++++----
5 files changed, 100 insertions(+), 48 deletions(-)

Detailed changes

crates/copilot/src/copilot_completion_provider.rs 🔗

@@ -215,12 +215,12 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
         }
     }
 
-    fn active_completion_text(
-        &self,
+    fn active_completion_text<'a>(
+        &'a self,
         buffer: &Model<Buffer>,
         cursor_position: language::Anchor,
-        cx: &AppContext,
-    ) -> Option<&str> {
+        cx: &'a AppContext,
+    ) -> Option<&'a str> {
         let buffer_id = buffer.entity_id();
         let buffer = buffer.read(cx);
         let completion = self.active_completion()?;

crates/editor/src/editor.rs 🔗

@@ -4356,6 +4356,7 @@ impl Editor {
                 text: completion.text.to_string().into(),
             });
             self.insert_with_autoindent_mode(&completion.text.to_string(), None, cx);
+            self.refresh_inline_completion(true, cx);
             cx.notify();
             true
         } else {

crates/editor/src/inline_completion_provider.rs 🔗

@@ -30,7 +30,7 @@ pub trait InlineCompletionProvider: 'static + Sized {
         buffer: &Model<Buffer>,
         cursor_position: language::Anchor,
         cx: &'a AppContext,
-    ) -> Option<&str>;
+    ) -> Option<&'a str>;
 }
 
 pub trait InlineCompletionProviderHandle {

crates/supermaven/src/supermaven.rs 🔗

@@ -10,7 +10,9 @@ use collections::BTreeMap;
 
 use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
-use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset};
+use language::{
+    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
+};
 use messages::*;
 use postage::watch;
 use serde::{Deserialize, Serialize};
@@ -19,7 +21,7 @@ use smol::{
     io::AsyncWriteExt,
     process::{Child, ChildStdin, ChildStdout, Command},
 };
-use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc};
+use std::{path::PathBuf, process::Stdio, sync::Arc};
 use ui::prelude::*;
 use util::ResultExt;
 
@@ -128,9 +130,9 @@ impl Supermaven {
                 state_id,
                 SupermavenCompletionState {
                     buffer_id,
-                    range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer),
-                    completion: Vec::new(),
+                    prefix_anchor: cursor_position,
                     text: String::new(),
+                    dedent: String::new(),
                     updates_tx,
                 },
             );
@@ -158,16 +160,64 @@ impl Supermaven {
 
     pub fn completion(
         &self,
-        id: SupermavenCompletionStateId,
-    ) -> Option<&SupermavenCompletionState> {
+        buffer: &Model<Buffer>,
+        cursor_position: Anchor,
+        cx: &AppContext,
+    ) -> Option<&str> {
         if let Self::Spawned(agent) = self {
-            agent.states.get(&id)
+            find_relevant_completion(
+                &agent.states,
+                buffer.entity_id(),
+                &buffer.read(cx).snapshot(),
+                cursor_position,
+            )
         } else {
             None
         }
     }
 }
 
+fn find_relevant_completion<'a>(
+    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
+    buffer_id: EntityId,
+    buffer: &BufferSnapshot,
+    cursor_position: Anchor,
+) -> Option<&'a str> {
+    let mut best_completion: Option<&str> = None;
+    'completions: for state in states.values() {
+        if state.buffer_id != buffer_id {
+            continue;
+        }
+        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
+            continue;
+        };
+
+        let current_cursor_offset = cursor_position.to_offset(buffer);
+        let original_cursor_offset = state.prefix_anchor.to_offset(buffer);
+        if current_cursor_offset < original_cursor_offset {
+            continue;
+        }
+
+        let text_inserted_since_completion_request =
+            buffer.text_for_range(original_cursor_offset..current_cursor_offset);
+        let mut trimmed_completion = state_completion;
+        for chunk in text_inserted_since_completion_request {
+            if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
+                trimmed_completion = suffix;
+            } else {
+                continue 'completions;
+            }
+        }
+
+        if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
+            continue;
+        }
+
+        best_completion = Some(trimmed_completion);
+    }
+    best_completion
+}
+
 pub struct SupermavenAgent {
     _process: Child,
     next_state_id: SupermavenCompletionStateId,
@@ -311,11 +361,12 @@ impl SupermavenAgent {
                 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
                 if let Some(state) = self.states.get_mut(&state_id) {
                     for item in &response.items {
-                        if let ResponseItem::Text { text } = item {
-                            state.text.push_str(text);
+                        match item {
+                            ResponseItem::Text { text } => state.text.push_str(text),
+                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
+                            _ => {}
                         }
                     }
-                    state.completion.extend(response.items);
                     *state.updates_tx.borrow_mut() = ();
                 }
             }
@@ -333,9 +384,9 @@ pub struct SupermavenCompletionStateId(usize);
 #[allow(dead_code)]
 pub struct SupermavenCompletionState {
     buffer_id: EntityId,
-    range: Range<Anchor>,
-    completion: Vec<ResponseItem>,
+    prefix_anchor: Anchor,
     text: String,
+    dedent: String,
     updates_tx: watch::Sender<()>,
 }
 

crates/supermaven/src/supermaven_completion_provider.rs 🔗

@@ -3,9 +3,7 @@ use anyhow::Result;
 use editor::{Direction, InlineCompletionProvider};
 use futures::StreamExt as _;
 use gpui::{AppContext, Model, ModelContext, Task};
-use language::{
-    language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset,
-};
+use language::{language_settings::all_language_settings, Anchor, Buffer};
 use std::time::Duration;
 
 pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
@@ -92,29 +90,16 @@ impl InlineCompletionProvider for SupermavenCompletionProvider {
         cursor_position: Anchor,
         cx: &'a AppContext,
     ) -> Option<&'a str> {
-        let completion_id = self.completion_id?;
-        let buffer = buffer.read(cx);
-        let cursor_offset = cursor_position.to_offset(buffer);
-        let completion = self.supermaven.read(cx).completion(completion_id)?;
-
-        let mut completion_range = completion.range.to_offset(buffer);
-
-        let prefix_len = common_prefix(
-            buffer.chars_for_range(completion_range.clone()),
-            completion.text.chars(),
-        );
-        completion_range.start += prefix_len;
-        let suffix_len = common_prefix(
-            buffer.reversed_chars_for_range(completion_range.clone()),
-            completion.text[prefix_len..].chars().rev(),
-        );
-        completion_range.end = completion_range.end.saturating_sub(suffix_len);
-
-        let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len];
-        if completion_range.is_empty()
-            && completion_range.start == cursor_offset
-            && !completion_text.trim().is_empty()
-        {
+        let completion_text = self
+            .supermaven
+            .read(cx)
+            .completion(buffer, cursor_position, cx)?;
+
+        let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
+
+        let completion_text = completion_text.trim_end();
+
+        if !completion_text.trim().is_empty() {
             Some(completion_text)
         } else {
             None
@@ -122,9 +107,24 @@ impl InlineCompletionProvider for SupermavenCompletionProvider {
     }
 }
 
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
-    a.zip(b)
-        .take_while(|(a, b)| a == b)
-        .map(|(a, _)| a.len_utf8())
-        .sum()
+fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
+    if has_leading_newline(&text) {
+        text
+    } else if let Some(i) = text.find('\n') {
+        &text[..i]
+    } else {
+        text
+    }
+}
+
+fn has_leading_newline(text: &str) -> bool {
+    for c in text.chars() {
+        if c == '\n' {
+            return true;
+        }
+        if !c.is_whitespace() {
+            return false;
+        }
+    }
+    false
 }