supermaven.rs

  1mod messages;
  2mod supermaven_edit_prediction_delegate;
  3
  4pub use supermaven_edit_prediction_delegate::*;
  5
  6use anyhow::{Context as _, Result};
  7#[allow(unused_imports)]
  8use client::{Client, proto};
  9use collections::BTreeMap;
 10
 11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
 12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
 13use language::{
 14    Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
 15};
 16use messages::*;
 17use postage::watch;
 18use serde::{Deserialize, Serialize};
 19use settings::SettingsStore;
 20use smol::io::AsyncWriteExt;
 21use std::{path::PathBuf, sync::Arc};
 22use ui::prelude::*;
 23use util::ResultExt;
 24use util::command::Child;
 25use util::command::Stdio;
 26
// Declares the `supermaven::SignOut` action; its handler is registered in `init`.
actions!(
    supermaven,
    [
        /// Signs out of Supermaven.
        SignOut
    ]
);
 34
 35pub fn init(client: Arc<Client>, cx: &mut App) {
 36    let supermaven = cx.new(|_| Supermaven::Starting);
 37    Supermaven::set_global(supermaven.clone(), cx);
 38
 39    let mut provider = all_language_settings(None, cx).edit_predictions.provider;
 40    if provider == language::language_settings::EditPredictionProvider::Supermaven {
 41        supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 42    }
 43
 44    cx.observe_global::<SettingsStore>(move |cx| {
 45        let new_provider = all_language_settings(None, cx).edit_predictions.provider;
 46        if new_provider != provider {
 47            provider = new_provider;
 48            if provider == language::language_settings::EditPredictionProvider::Supermaven {
 49                supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 50            } else {
 51                supermaven.update(cx, |supermaven, _cx| supermaven.stop());
 52            }
 53        }
 54    })
 55    .detach();
 56
 57    cx.on_action(|_: &SignOut, cx| {
 58        if let Some(supermaven) = Supermaven::global(cx) {
 59            supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
 60        }
 61    });
 62}
 63
/// State of the Supermaven integration, held in the global entity created by `init`.
pub enum Supermaven {
    /// Initial state, and the state `stop` resets to; `start` only acts from here.
    Starting,
    /// NOTE(review): not constructed anywhere in this file; presumably meant to
    /// record a failed agent download — confirm against other users of the enum.
    FailedDownload { error: anyhow::Error },
    /// An agent process has been spawned and is owned by this variant.
    Spawned(SupermavenAgent),
    /// NOTE(review): not constructed anywhere in this file; confirm where it is set.
    Error { error: anyhow::Error },
}
 70
/// Account state as last reported by the Supermaven agent.
#[derive(Clone)]
pub enum AccountStatus {
    /// Nothing reported yet, or reset after a sign-out request was sent.
    Unknown,
    /// The agent sent an activation request carrying a URL for the user to visit.
    NeedsActivation { activate_url: String },
    /// No activation needed: activation succeeded, the agent sent an activation
    /// request without a URL, or it reported a service tier.
    Ready,
}
 77
/// Newtype that stores the [`Supermaven`] entity as a gpui global
/// (see `Supermaven::global` / `Supermaven::set_global`).
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
 82
 83impl Supermaven {
 84    pub fn global(cx: &App) -> Option<Entity<Self>> {
 85        cx.try_global::<SupermavenGlobal>()
 86            .map(|model| model.0.clone())
 87    }
 88
 89    pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
 90        cx.set_global(SupermavenGlobal(supermaven));
 91    }
 92
 93    pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
 94        if let Self::Starting = self {
 95            cx.spawn(async move |this, cx| {
 96                let binary_path =
 97                    supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
 98
 99                this.update(cx, |this, cx| {
100                    if let Self::Starting = this {
101                        *this =
102                            Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
103                    }
104                    anyhow::Ok(())
105                })
106            })
107            .detach_and_log_err(cx)
108        }
109    }
110
111    pub fn stop(&mut self) {
112        *self = Self::Starting;
113    }
114
115    pub fn is_enabled(&self) -> bool {
116        matches!(self, Self::Spawned { .. })
117    }
118
119    pub fn complete(
120        &mut self,
121        buffer: &Entity<Buffer>,
122        cursor_position: Anchor,
123        cx: &App,
124    ) -> Option<SupermavenCompletion> {
125        if let Self::Spawned(agent) = self {
126            let buffer_id = buffer.entity_id();
127            let buffer = buffer.read(cx);
128            let path = buffer
129                .file()
130                .and_then(|file| Some(file.as_local()?.abs_path(cx)))
131                .unwrap_or_else(|| PathBuf::from("untitled"))
132                .to_string_lossy()
133                .to_string();
134            let content = buffer.text();
135            let offset = cursor_position.to_offset(buffer);
136            let state_id = agent.next_state_id;
137            agent.next_state_id.0 += 1;
138
139            let (updates_tx, mut updates_rx) = watch::channel();
140            postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
141
142            agent.states.insert(
143                state_id,
144                SupermavenCompletionState {
145                    buffer_id,
146                    prefix_anchor: cursor_position,
147                    prefix_offset: offset,
148                    text: String::new(),
149                    dedent: String::new(),
150                    updates_tx,
151                },
152            );
153            // ensure the states map is max 1000 elements
154            if agent.states.len() > 1000 {
155                // state id is monotonic so it's sufficient to remove the first element
156                agent
157                    .states
158                    .remove(&agent.states.keys().next().unwrap().clone());
159            }
160
161            let _ = agent
162                .outgoing_tx
163                .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
164                    new_id: state_id.0.to_string(),
165                    updates: vec![
166                        StateUpdate::FileUpdate(FileUpdateMessage {
167                            path: path.clone(),
168                            content,
169                        }),
170                        StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
171                    ],
172                }));
173
174            Some(SupermavenCompletion {
175                id: state_id,
176                updates: updates_rx,
177            })
178        } else {
179            None
180        }
181    }
182
183    pub fn completion(
184        &self,
185        buffer: &Entity<Buffer>,
186        cursor_position: Anchor,
187        cx: &App,
188    ) -> Option<&str> {
189        if let Self::Spawned(agent) = self {
190            find_relevant_completion(
191                &agent.states,
192                buffer.entity_id(),
193                &buffer.read(cx).snapshot(),
194                cursor_position,
195            )
196        } else {
197            None
198        }
199    }
200
201    pub fn sign_out(&mut self) {
202        if let Self::Spawned(agent) = self {
203            agent
204                .outgoing_tx
205                .unbounded_send(OutboundMessage::Logout)
206                .ok();
207            // The account status will get set to RequiresActivation or Ready when the next
208            // message from the agent comes in. Until that happens, set the status to Unknown
209            // to disable the button.
210            agent.account_status = AccountStatus::Unknown;
211        }
212    }
213}
214
215fn find_relevant_completion<'a>(
216    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
217    buffer_id: EntityId,
218    buffer: &BufferSnapshot,
219    cursor_position: Anchor,
220) -> Option<&'a str> {
221    let mut best_completion: Option<&str> = None;
222    'completions: for state in states.values() {
223        if state.buffer_id != buffer_id {
224            continue;
225        }
226        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
227            continue;
228        };
229
230        let current_cursor_offset = cursor_position.to_offset(buffer);
231        if current_cursor_offset < state.prefix_offset {
232            continue;
233        }
234
235        let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
236        let text_inserted_since_completion_request: String = buffer
237            .text_for_range(original_cursor_offset..current_cursor_offset)
238            .collect();
239        let trimmed_completion =
240            match state_completion.strip_prefix(&text_inserted_since_completion_request) {
241                Some(suffix) => suffix,
242                None => continue 'completions,
243            };
244
245        if best_completion.is_some_and(|best| best.len() > trimmed_completion.len()) {
246            continue;
247        }
248
249        best_completion = Some(trimmed_completion);
250    }
251    best_completion
252}
253
/// A spawned Supermaven agent process plus the state needed to talk to it.
pub struct SupermavenAgent {
    // Handle to the child process (spawned with `kill_on_drop(true)` in `new`).
    _process: Child,
    // Id handed to the next completion request; incremented in `Supermaven::complete`.
    next_state_id: SupermavenCompletionStateId,
    // Completion states keyed by their monotonically increasing id.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Messages queued here are written to the agent's stdin by `_handle_outgoing_messages`.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Task draining `outgoing_tx`; held so it lives as long as the agent.
    _handle_outgoing_messages: Task<Result<()>>,
    // Task reading the agent's stdout; held so it lives as long as the agent.
    _handle_incoming_messages: Task<Result<()>>,
    /// Account status as last reported by the agent.
    pub account_status: AccountStatus,
    // Service tier as last reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // Not read in this file. NOTE(review): confirm whether it is still needed.
    #[allow(dead_code)]
    client: Arc<Client>,
}
266
267impl SupermavenAgent {
268    fn new(
269        binary_path: PathBuf,
270        client: Arc<Client>,
271        cx: &mut Context<Supermaven>,
272    ) -> Result<Self> {
273        let mut process = util::command::new_command(&binary_path)
274            .arg("stdio")
275            .stdin(Stdio::piped())
276            .stdout(Stdio::piped())
277            .stderr(Stdio::piped())
278            .kill_on_drop(true)
279            .spawn()
280            .context("failed to start the binary")?;
281
282        let stdin = process
283            .stdin
284            .take()
285            .context("failed to get stdin for process")?;
286        let stdout = process
287            .stdout
288            .take()
289            .context("failed to get stdout for process")?;
290
291        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
292
293        Ok(Self {
294            _process: process,
295            next_state_id: SupermavenCompletionStateId::default(),
296            states: BTreeMap::default(),
297            outgoing_tx,
298            _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
299                Self::handle_outgoing_messages(outgoing_rx, stdin).await
300            }),
301            _handle_incoming_messages: cx.spawn(async move |this, cx| {
302                Self::handle_incoming_messages(this, stdout, cx).await
303            }),
304            account_status: AccountStatus::Unknown,
305            service_tier: None,
306            client,
307        })
308    }
309
310    async fn handle_outgoing_messages<W: smol::io::AsyncWrite + Unpin>(
311        mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
312        mut stdin: W,
313    ) -> Result<()> {
314        while let Some(message) = outgoing.next().await {
315            let bytes = serde_json::to_vec(&message)?;
316            stdin.write_all(&bytes).await?;
317            stdin.write_all(&[b'\n']).await?;
318        }
319        Ok(())
320    }
321
322    async fn handle_incoming_messages<R: smol::io::AsyncRead + Unpin>(
323        this: WeakEntity<Supermaven>,
324        stdout: R,
325        cx: &mut AsyncApp,
326    ) -> Result<()> {
327        const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
328
329        let stdout = BufReader::new(stdout);
330        let mut lines = stdout.lines();
331        while let Some(line) = lines.next().await {
332            let Some(line) = line.context("failed to read line from stdout").log_err() else {
333                continue;
334            };
335            let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
336                continue;
337            };
338            let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
339                .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
340                .log_err()
341            else {
342                continue;
343            };
344
345            this.update(cx, |this, _cx| {
346                if let Supermaven::Spawned(this) = this {
347                    this.handle_message(message);
348                }
349                Task::ready(anyhow::Ok(()))
350            })?
351            .await?;
352        }
353
354        Ok(())
355    }
356
357    fn handle_message(&mut self, message: SupermavenMessage) {
358        match message {
359            SupermavenMessage::ActivationRequest(request) => {
360                self.account_status = match request.activate_url {
361                    Some(activate_url) => AccountStatus::NeedsActivation { activate_url },
362                    None => AccountStatus::Ready,
363                };
364            }
365            SupermavenMessage::ActivationSuccess => {
366                self.account_status = AccountStatus::Ready;
367            }
368            SupermavenMessage::ServiceTier { service_tier } => {
369                self.account_status = AccountStatus::Ready;
370                self.service_tier = Some(service_tier);
371            }
372            SupermavenMessage::Response(response) => {
373                let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
374                if let Some(state) = self.states.get_mut(&state_id) {
375                    for item in &response.items {
376                        match item {
377                            ResponseItem::Text { text } => state.text.push_str(text),
378                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
379                            _ => {}
380                        }
381                    }
382                    *state.updates_tx.borrow_mut() = ();
383                }
384            }
385            SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
386            _ => {
387                log::warn!("unhandled message: {:?}", message);
388            }
389        }
390    }
391}
392
/// Identifier of one completion request. Ids come from a counter incremented in
/// `Supermaven::complete`, so ordering by id matches request order.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
395
/// Bookkeeping for one completion request sent to the agent.
// `prefix_anchor` is not read in this file; presumably the reason for this
// allow — NOTE(review): confirm.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    // Buffer the request was made for.
    buffer_id: EntityId,
    // Cursor position at request time.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Concatenation of the `Text` items received for this request.
    text: String,
    // Concatenation of the `Dedent` items; stripped from the front of `text` when matching.
    dedent: String,
    // Written to each time a response for this request is applied.
    updates_tx: watch::Sender<()>,
}
407
/// Handle returned by `Supermaven::complete` for a pending completion request.
pub struct SupermavenCompletion {
    /// Id of the state that accumulates this request's responses.
    pub id: SupermavenCompletionStateId,
    /// Signalled whenever a response for this request is applied. The channel's
    /// initial value was already consumed (via `try_recv`) in `complete`.
    pub updates: watch::Receiver<()>,
}
412
#[cfg(test)]
mod tests {
    use super::*;
    use collections::BTreeMap;
    use gpui::TestAppContext;
    use language::Buffer;

    /// Caches `completion` as if it had been requested at offset 0 of a
    /// `"hello world"` buffer, then looks it up with the cursor at `cursor_offset`.
    fn lookup_after_typing(
        cx: &mut TestAppContext,
        completion: &str,
        cursor_offset: usize,
    ) -> Option<String> {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let (updates_tx, _) = watch::channel();
        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            SupermavenCompletionState {
                buffer_id: buffer.entity_id(),
                prefix_anchor: snapshot.anchor_before(0),
                prefix_offset: 0,
                text: completion.to_string(),
                dedent: String::new(),
                updates_tx,
            },
        );

        find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(cursor_offset),
        )
        .map(String::from)
    }

    #[gpui::test]
    async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) {
        // One character ("h") typed since the request: only it is trimmed.
        assert_eq!(lookup_after_typing(cx, "hello", 1).as_deref(), Some("ello"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) {
        // Three characters ("hel") typed since the request.
        assert_eq!(lookup_after_typing(cx, "hello", 3).as_deref(), Some("lo"));
    }
}