supermaven.rs

  1mod messages;
  2mod supermaven_completion_provider;
  3
  4pub use supermaven_completion_provider::*;
  5
  6use anyhow::{Context as _, Result};
  7#[allow(unused_imports)]
  8use client::{Client, proto};
  9use collections::BTreeMap;
 10
 11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
 12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
 13use language::{
 14    Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
 15};
 16use messages::*;
 17use postage::watch;
 18use serde::{Deserialize, Serialize};
 19use settings::SettingsStore;
 20use smol::{
 21    io::AsyncWriteExt,
 22    process::{Child, ChildStdin, ChildStdout},
 23};
 24use std::{path::PathBuf, process::Stdio, sync::Arc};
 25use ui::prelude::*;
 26use util::ResultExt;
 27
// Actions registered under the `supermaven` namespace. The handler for `SignOut` is
// installed in `init` and forwards to `Supermaven::sign_out` on the global entity.
actions!(
    supermaven,
    [
        /// Signs out of Supermaven.
        SignOut
    ]
);
 35
 36pub fn init(client: Arc<Client>, cx: &mut App) {
 37    let supermaven = cx.new(|_| Supermaven::Starting);
 38    Supermaven::set_global(supermaven.clone(), cx);
 39
 40    let mut provider = all_language_settings(None, cx).edit_predictions.provider;
 41    if provider == language::language_settings::EditPredictionProvider::Supermaven {
 42        supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 43    }
 44
 45    cx.observe_global::<SettingsStore>(move |cx| {
 46        let new_provider = all_language_settings(None, cx).edit_predictions.provider;
 47        if new_provider != provider {
 48            provider = new_provider;
 49            if provider == language::language_settings::EditPredictionProvider::Supermaven {
 50                supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 51            } else {
 52                supermaven.update(cx, |supermaven, _cx| supermaven.stop());
 53            }
 54        }
 55    })
 56    .detach();
 57
 58    cx.on_action(|_: &SignOut, cx| {
 59        if let Some(supermaven) = Supermaven::global(cx) {
 60            supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
 61        }
 62    });
 63}
 64
 65pub enum Supermaven {
 66    Starting,
 67    FailedDownload { error: anyhow::Error },
 68    Spawned(SupermavenAgent),
 69    Error { error: anyhow::Error },
 70}
 71
 72#[derive(Clone)]
 73pub enum AccountStatus {
 74    Unknown,
 75    NeedsActivation { activate_url: String },
 76    Ready,
 77}
 78
/// Newtype that lets the `Supermaven` entity be stored as a gpui global
/// (see `Supermaven::global` / `Supermaven::set_global`).
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
 83
 84impl Supermaven {
 85    pub fn global(cx: &App) -> Option<Entity<Self>> {
 86        cx.try_global::<SupermavenGlobal>()
 87            .map(|model| model.0.clone())
 88    }
 89
 90    pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
 91        cx.set_global(SupermavenGlobal(supermaven));
 92    }
 93
 94    pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
 95        if let Self::Starting = self {
 96            cx.spawn(async move |this, cx| {
 97                let binary_path =
 98                    supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
 99
100                this.update(cx, |this, cx| {
101                    if let Self::Starting = this {
102                        *this =
103                            Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
104                    }
105                    anyhow::Ok(())
106                })
107            })
108            .detach_and_log_err(cx)
109        }
110    }
111
112    pub fn stop(&mut self) {
113        *self = Self::Starting;
114    }
115
116    pub fn is_enabled(&self) -> bool {
117        matches!(self, Self::Spawned { .. })
118    }
119
120    pub fn complete(
121        &mut self,
122        buffer: &Entity<Buffer>,
123        cursor_position: Anchor,
124        cx: &App,
125    ) -> Option<SupermavenCompletion> {
126        if let Self::Spawned(agent) = self {
127            let buffer_id = buffer.entity_id();
128            let buffer = buffer.read(cx);
129            let path = buffer
130                .file()
131                .and_then(|file| Some(file.as_local()?.abs_path(cx)))
132                .unwrap_or_else(|| PathBuf::from("untitled"))
133                .to_string_lossy()
134                .to_string();
135            let content = buffer.text();
136            let offset = cursor_position.to_offset(buffer);
137            let state_id = agent.next_state_id;
138            agent.next_state_id.0 += 1;
139
140            let (updates_tx, mut updates_rx) = watch::channel();
141            postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
142
143            agent.states.insert(
144                state_id,
145                SupermavenCompletionState {
146                    buffer_id,
147                    prefix_anchor: cursor_position,
148                    prefix_offset: offset,
149                    text: String::new(),
150                    dedent: String::new(),
151                    updates_tx,
152                },
153            );
154            // ensure the states map is max 1000 elements
155            if agent.states.len() > 1000 {
156                // state id is monotonic so it's sufficient to remove the first element
157                agent
158                    .states
159                    .remove(&agent.states.keys().next().unwrap().clone());
160            }
161
162            let _ = agent
163                .outgoing_tx
164                .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
165                    new_id: state_id.0.to_string(),
166                    updates: vec![
167                        StateUpdate::FileUpdate(FileUpdateMessage {
168                            path: path.clone(),
169                            content,
170                        }),
171                        StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
172                    ],
173                }));
174
175            Some(SupermavenCompletion {
176                id: state_id,
177                updates: updates_rx,
178            })
179        } else {
180            None
181        }
182    }
183
184    pub fn completion(
185        &self,
186        buffer: &Entity<Buffer>,
187        cursor_position: Anchor,
188        cx: &App,
189    ) -> Option<&str> {
190        if let Self::Spawned(agent) = self {
191            find_relevant_completion(
192                &agent.states,
193                buffer.entity_id(),
194                &buffer.read(cx).snapshot(),
195                cursor_position,
196            )
197        } else {
198            None
199        }
200    }
201
202    pub fn sign_out(&mut self) {
203        if let Self::Spawned(agent) = self {
204            agent
205                .outgoing_tx
206                .unbounded_send(OutboundMessage::Logout)
207                .ok();
208            // The account status will get set to RequiresActivation or Ready when the next
209            // message from the agent comes in. Until that happens, set the status to Unknown
210            // to disable the button.
211            agent.account_status = AccountStatus::Unknown;
212        }
213    }
214}
215
216fn find_relevant_completion<'a>(
217    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
218    buffer_id: EntityId,
219    buffer: &BufferSnapshot,
220    cursor_position: Anchor,
221) -> Option<&'a str> {
222    let mut best_completion: Option<&str> = None;
223    'completions: for state in states.values() {
224        if state.buffer_id != buffer_id {
225            continue;
226        }
227        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
228            continue;
229        };
230
231        let current_cursor_offset = cursor_position.to_offset(buffer);
232        if current_cursor_offset < state.prefix_offset {
233            continue;
234        }
235
236        let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
237        let text_inserted_since_completion_request: String = buffer
238            .text_for_range(original_cursor_offset..current_cursor_offset)
239            .collect();
240        let trimmed_completion =
241            match state_completion.strip_prefix(&text_inserted_since_completion_request) {
242                Some(suffix) => suffix,
243                None => continue 'completions,
244            };
245
246        if best_completion.is_some_and(|best| best.len() > trimmed_completion.len()) {
247            continue;
248        }
249
250        best_completion = Some(trimmed_completion);
251    }
252    best_completion
253}
254
/// A spawned Supermaven agent process plus the bookkeeping used to talk to it.
///
/// Fields are dropped in declaration order, so `_process` (spawned with
/// `kill_on_drop(true)`) is dropped, and the process killed, before the I/O tasks.
pub struct SupermavenAgent {
    /// Handle to the agent process; dropping it kills the process.
    _process: Child,
    /// Id assigned to the next completion request; incremented per request.
    next_state_id: SupermavenCompletionStateId,
    /// Per-request completion state keyed by id (trimmed to 1000 entries in `Supermaven::complete`).
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    /// Queue of messages that `_handle_outgoing_messages` writes to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    /// Task forwarding `outgoing_tx` messages to stdin as newline-delimited JSON.
    _handle_outgoing_messages: Task<Result<()>>,
    /// Task reading `SM-MESSAGE` lines from stdout and applying them to this agent.
    _handle_incoming_messages: Task<Result<()>>,
    /// Account state, updated from agent messages and when the API key is sent.
    pub account_status: AccountStatus,
    /// Service tier most recently reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // Not read anywhere in this file; NOTE(review): confirm whether it is still needed.
    #[allow(dead_code)]
    client: Arc<Client>,
}
267
268impl SupermavenAgent {
269    fn new(
270        binary_path: PathBuf,
271        client: Arc<Client>,
272        cx: &mut Context<Supermaven>,
273    ) -> Result<Self> {
274        let mut process = util::command::new_smol_command(&binary_path)
275            .arg("stdio")
276            .stdin(Stdio::piped())
277            .stdout(Stdio::piped())
278            .stderr(Stdio::piped())
279            .kill_on_drop(true)
280            .spawn()
281            .context("failed to start the binary")?;
282
283        let stdin = process
284            .stdin
285            .take()
286            .context("failed to get stdin for process")?;
287        let stdout = process
288            .stdout
289            .take()
290            .context("failed to get stdout for process")?;
291
292        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
293
294        cx.spawn({
295            let client = client.clone();
296            let outgoing_tx = outgoing_tx.clone();
297            async move |this, cx| {
298                let mut status = client.status();
299                while let Some(status) = status.next().await {
300                    if status.is_connected() {
301                        let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
302                        outgoing_tx
303                            .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
304                            .ok();
305                        this.update(cx, |this, cx| {
306                            if let Supermaven::Spawned(this) = this {
307                                this.account_status = AccountStatus::Ready;
308                                cx.notify();
309                            }
310                        })?;
311                        break;
312                    }
313                }
314                anyhow::Ok(())
315            }
316        })
317        .detach();
318
319        Ok(Self {
320            _process: process,
321            next_state_id: SupermavenCompletionStateId::default(),
322            states: BTreeMap::default(),
323            outgoing_tx,
324            _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
325                Self::handle_outgoing_messages(outgoing_rx, stdin).await
326            }),
327            _handle_incoming_messages: cx.spawn(async move |this, cx| {
328                Self::handle_incoming_messages(this, stdout, cx).await
329            }),
330            account_status: AccountStatus::Unknown,
331            service_tier: None,
332            client,
333        })
334    }
335
336    async fn handle_outgoing_messages(
337        mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
338        mut stdin: ChildStdin,
339    ) -> Result<()> {
340        while let Some(message) = outgoing.next().await {
341            let bytes = serde_json::to_vec(&message)?;
342            stdin.write_all(&bytes).await?;
343            stdin.write_all(&[b'\n']).await?;
344        }
345        Ok(())
346    }
347
348    async fn handle_incoming_messages(
349        this: WeakEntity<Supermaven>,
350        stdout: ChildStdout,
351        cx: &mut AsyncApp,
352    ) -> Result<()> {
353        const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
354
355        let stdout = BufReader::new(stdout);
356        let mut lines = stdout.lines();
357        while let Some(line) = lines.next().await {
358            let Some(line) = line.context("failed to read line from stdout").log_err() else {
359                continue;
360            };
361            let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
362                continue;
363            };
364            let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
365                .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
366                .log_err()
367            else {
368                continue;
369            };
370
371            this.update(cx, |this, _cx| {
372                if let Supermaven::Spawned(this) = this {
373                    this.handle_message(message);
374                }
375                Task::ready(anyhow::Ok(()))
376            })?
377            .await?;
378        }
379
380        Ok(())
381    }
382
383    fn handle_message(&mut self, message: SupermavenMessage) {
384        match message {
385            SupermavenMessage::ActivationRequest(request) => {
386                self.account_status = match request.activate_url {
387                    Some(activate_url) => AccountStatus::NeedsActivation { activate_url },
388                    None => AccountStatus::Ready,
389                };
390            }
391            SupermavenMessage::ActivationSuccess => {
392                self.account_status = AccountStatus::Ready;
393            }
394            SupermavenMessage::ServiceTier { service_tier } => {
395                self.account_status = AccountStatus::Ready;
396                self.service_tier = Some(service_tier);
397            }
398            SupermavenMessage::Response(response) => {
399                let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
400                if let Some(state) = self.states.get_mut(&state_id) {
401                    for item in &response.items {
402                        match item {
403                            ResponseItem::Text { text } => state.text.push_str(text),
404                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
405                            _ => {}
406                        }
407                    }
408                    *state.updates_tx.borrow_mut() = ();
409                }
410            }
411            SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
412            _ => {
413                log::warn!("unhandled message: {:?}", message);
414            }
415        }
416    }
417}
418
419#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
420pub struct SupermavenCompletionStateId(usize);
421
422#[allow(dead_code)]
423pub struct SupermavenCompletionState {
424    buffer_id: EntityId,
425    prefix_anchor: Anchor,
426    // prefix_offset is tracked independently because the anchor biases left which
427    // doesn't allow us to determine if the prior text has been deleted.
428    prefix_offset: usize,
429    text: String,
430    dedent: String,
431    updates_tx: watch::Sender<()>,
432}
433
434pub struct SupermavenCompletion {
435    pub id: SupermavenCompletionStateId,
436    pub updates: watch::Receiver<()>,
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use collections::BTreeMap;
    use gpui::TestAppContext;
    use language::Buffer;

    /// Builds a completion state for `buffer` whose request was made at offset 0.
    fn state_at_start(
        buffer: &Entity<Buffer>,
        snapshot: &BufferSnapshot,
        text: &str,
        dedent: &str,
    ) -> SupermavenCompletionState {
        let (updates_tx, _) = watch::channel();
        SupermavenCompletionState {
            buffer_id: buffer.entity_id(),
            prefix_anchor: snapshot.anchor_before(0), // Start of buffer
            prefix_offset: 0,
            text: text.to_string(),
            dedent: dedent.to_string(),
            updates_tx,
        }
    }

    #[gpui::test]
    async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "hello", ""),
        );

        let result = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(1),
        );

        assert_eq!(result, Some("ello"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "hello", ""),
        );

        let result = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(3),
        );

        assert_eq!(result, Some("lo"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_rejects_mismatched_text(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "world", ""),
        );

        // "h" was typed since the request, but the completion starts with "w".
        let result = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(1),
        );

        assert_eq!(result, None);
    }

    #[gpui::test]
    async fn test_find_relevant_completion_strips_dedent(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "  hello", "  "),
        );

        let result = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(1),
        );

        assert_eq!(result, Some("ello"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_prefers_longest(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "hello", ""),
        );
        states.insert(
            SupermavenCompletionStateId(2),
            state_at_start(&buffer, &snapshot, "hello there", ""),
        );

        let result = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(1),
        );

        assert_eq!(result, Some("ello there"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_ignores_other_buffers(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let other_buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let mut states = BTreeMap::new();
        states.insert(
            SupermavenCompletionStateId(1),
            state_at_start(&buffer, &snapshot, "hello", ""),
        );

        let result = find_relevant_completion(
            &states,
            other_buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(1),
        );

        assert_eq!(result, None);
    }
}