@@ -10,7 +10,9 @@ use collections::BTreeMap;
use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
-use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset};
+use language::{
+ language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
+};
use messages::*;
use postage::watch;
use serde::{Deserialize, Serialize};
@@ -19,7 +21,7 @@ use smol::{
io::AsyncWriteExt,
process::{Child, ChildStdin, ChildStdout, Command},
};
-use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc};
+use std::{path::PathBuf, process::Stdio, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
@@ -128,9 +130,9 @@ impl Supermaven {
state_id,
SupermavenCompletionState {
buffer_id,
- range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer),
- completion: Vec::new(),
+ prefix_anchor: cursor_position,
text: String::new(),
+ dedent: String::new(),
updates_tx,
},
);
@@ -158,16 +160,64 @@ impl Supermaven {
pub fn completion(
&self,
- id: SupermavenCompletionStateId,
- ) -> Option<&SupermavenCompletionState> {
+ buffer: &Model<Buffer>,
+ cursor_position: Anchor,
+ cx: &AppContext,
+ ) -> Option<&str> {
if let Self::Spawned(agent) = self {
- agent.states.get(&id)
+ find_relevant_completion(
+ &agent.states,
+ buffer.entity_id(),
+ &buffer.read(cx).snapshot(),
+ cursor_position,
+ )
} else {
None
}
}
}
+fn find_relevant_completion<'a>(
+ states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
+ buffer_id: EntityId,
+ buffer: &BufferSnapshot,
+ cursor_position: Anchor,
+) -> Option<&'a str> {
+ let mut best_completion: Option<&str> = None;
+ 'completions: for state in states.values() {
+ if state.buffer_id != buffer_id {
+ continue;
+ }
+ let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
+ continue;
+ };
+
+ let current_cursor_offset = cursor_position.to_offset(buffer);
+ let original_cursor_offset = state.prefix_anchor.to_offset(buffer);
+ if current_cursor_offset < original_cursor_offset {
+ continue;
+ }
+
+ let text_inserted_since_completion_request =
+ buffer.text_for_range(original_cursor_offset..current_cursor_offset);
+ let mut trimmed_completion = state_completion;
+ for chunk in text_inserted_since_completion_request {
+ if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
+ trimmed_completion = suffix;
+ } else {
+ continue 'completions;
+ }
+ }
+
+ if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
+ continue;
+ }
+
+ best_completion = Some(trimmed_completion);
+ }
+ best_completion
+}
+
pub struct SupermavenAgent {
_process: Child,
next_state_id: SupermavenCompletionStateId,
@@ -311,11 +361,12 @@ impl SupermavenAgent {
let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
if let Some(state) = self.states.get_mut(&state_id) {
for item in &response.items {
- if let ResponseItem::Text { text } = item {
- state.text.push_str(text);
+ match item {
+ ResponseItem::Text { text } => state.text.push_str(text),
+ ResponseItem::Dedent { text } => state.dedent.push_str(text),
+ _ => {}
}
}
- state.completion.extend(response.items);
*state.updates_tx.borrow_mut() = ();
}
}
@@ -333,9 +384,9 @@ pub struct SupermavenCompletionStateId(usize);
#[allow(dead_code)]
pub struct SupermavenCompletionState {
buffer_id: EntityId,
- range: Range<Anchor>,
- completion: Vec<ResponseItem>,
+ prefix_anchor: Anchor,
text: String,
+ dedent: String,
updates_tx: watch::Sender<()>,
}
@@ -3,9 +3,7 @@ use anyhow::Result;
use editor::{Direction, InlineCompletionProvider};
use futures::StreamExt as _;
use gpui::{AppContext, Model, ModelContext, Task};
-use language::{
- language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset,
-};
+use language::{language_settings::all_language_settings, Anchor, Buffer};
use std::time::Duration;
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
@@ -92,29 +90,16 @@ impl InlineCompletionProvider for SupermavenCompletionProvider {
cursor_position: Anchor,
cx: &'a AppContext,
) -> Option<&'a str> {
- let completion_id = self.completion_id?;
- let buffer = buffer.read(cx);
- let cursor_offset = cursor_position.to_offset(buffer);
- let completion = self.supermaven.read(cx).completion(completion_id)?;
-
- let mut completion_range = completion.range.to_offset(buffer);
-
- let prefix_len = common_prefix(
- buffer.chars_for_range(completion_range.clone()),
- completion.text.chars(),
- );
- completion_range.start += prefix_len;
- let suffix_len = common_prefix(
- buffer.reversed_chars_for_range(completion_range.clone()),
- completion.text[prefix_len..].chars().rev(),
- );
- completion_range.end = completion_range.end.saturating_sub(suffix_len);
-
- let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len];
- if completion_range.is_empty()
- && completion_range.start == cursor_offset
- && !completion_text.trim().is_empty()
- {
+ let completion_text = self
+ .supermaven
+ .read(cx)
+ .completion(buffer, cursor_position, cx)?;
+
+ let completion_text = trim_to_end_of_line_unless_leading_newline(completion_text);
+
+ let completion_text = completion_text.trim_end();
+
+ if !completion_text.trim().is_empty() {
Some(completion_text)
} else {
None
@@ -122,9 +107,24 @@ impl InlineCompletionProvider for SupermavenCompletionProvider {
}
}
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
- a.zip(b)
- .take_while(|(a, b)| a == b)
- .map(|(a, _)| a.len_utf8())
- .sum()
+fn trim_to_end_of_line_unless_leading_newline(text: &str) -> &str {
+ if has_leading_newline(&text) {
+ text
+ } else if let Some(i) = text.find('\n') {
+ &text[..i]
+ } else {
+ text
+ }
+}
+
+fn has_leading_newline(text: &str) -> bool {
+ for c in text.chars() {
+ if c == '\n' {
+ return true;
+ }
+ if !c.is_whitespace() {
+ return false;
+ }
+ }
+ false
}