supermaven.rs

  1mod messages;
  2mod supermaven_completion_provider;
  3
  4pub use supermaven_completion_provider::*;
  5
  6use anyhow::{Context as _, Result};
  7#[allow(unused_imports)]
  8use client::{Client, proto};
  9use collections::BTreeMap;
 10
 11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
 12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
 13use language::{
 14    Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
 15};
 16use messages::*;
 17use postage::watch;
 18use serde::{Deserialize, Serialize};
 19use settings::SettingsStore;
 20use smol::{
 21    io::AsyncWriteExt,
 22    process::{Child, ChildStdin, ChildStdout},
 23};
 24use std::{path::PathBuf, process::Stdio, sync::Arc};
 25use ui::prelude::*;
 26use util::ResultExt;
 27
// Declares the `SignOut` action under the `supermaven` namespace; its handler is
// registered in `init` and forwards to `Supermaven::sign_out`.
actions!(supermaven, [SignOut]);
 29
 30pub fn init(client: Arc<Client>, cx: &mut App) {
 31    let supermaven = cx.new(|_| Supermaven::Starting);
 32    Supermaven::set_global(supermaven.clone(), cx);
 33
 34    let mut provider = all_language_settings(None, cx).edit_predictions.provider;
 35    if provider == language::language_settings::EditPredictionProvider::Supermaven {
 36        supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 37    }
 38
 39    cx.observe_global::<SettingsStore>(move |cx| {
 40        let new_provider = all_language_settings(None, cx).edit_predictions.provider;
 41        if new_provider != provider {
 42            provider = new_provider;
 43            if provider == language::language_settings::EditPredictionProvider::Supermaven {
 44                supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 45            } else {
 46                supermaven.update(cx, |supermaven, _cx| supermaven.stop());
 47            }
 48        }
 49    })
 50    .detach();
 51
 52    cx.on_action(|_: &SignOut, cx| {
 53        if let Some(supermaven) = Supermaven::global(cx) {
 54            supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
 55        }
 56    });
 57}
 58
/// Lifecycle state of the Supermaven integration.
pub enum Supermaven {
    /// No agent is running: either the agent binary is still being resolved/spawned by
    /// `Supermaven::start`, or the agent was stopped via `Supermaven::stop`.
    Starting,
    /// Named for a failed agent download; not constructed anywhere in this file.
    FailedDownload { error: anyhow::Error },
    /// An agent process has been spawned and is being talked to.
    Spawned(SupermavenAgent),
    /// Generic failure state; not constructed anywhere in this file.
    Error { error: anyhow::Error },
}
 65
/// Supermaven account status, derived from agent messages and the API-key handshake.
#[derive(Clone)]
pub enum AccountStatus {
    /// Not yet known: the initial status, and the status right after signing out.
    Unknown,
    /// The agent requested account activation at `activate_url`.
    NeedsActivation { activate_url: String },
    /// The account is usable.
    Ready,
}
 72
/// Newtype used to store the [`Supermaven`] entity as a gpui global
/// (see `Supermaven::global` / `Supermaven::set_global`).
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
 77
 78impl Supermaven {
 79    pub fn global(cx: &App) -> Option<Entity<Self>> {
 80        cx.try_global::<SupermavenGlobal>()
 81            .map(|model| model.0.clone())
 82    }
 83
 84    pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
 85        cx.set_global(SupermavenGlobal(supermaven));
 86    }
 87
 88    pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
 89        if let Self::Starting = self {
 90            cx.spawn(async move |this, cx| {
 91                let binary_path =
 92                    supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
 93
 94                this.update(cx, |this, cx| {
 95                    if let Self::Starting = this {
 96                        *this =
 97                            Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
 98                    }
 99                    anyhow::Ok(())
100                })
101            })
102            .detach_and_log_err(cx)
103        }
104    }
105
106    pub fn stop(&mut self) {
107        *self = Self::Starting;
108    }
109
110    pub fn is_enabled(&self) -> bool {
111        matches!(self, Self::Spawned { .. })
112    }
113
114    pub fn complete(
115        &mut self,
116        buffer: &Entity<Buffer>,
117        cursor_position: Anchor,
118        cx: &App,
119    ) -> Option<SupermavenCompletion> {
120        if let Self::Spawned(agent) = self {
121            let buffer_id = buffer.entity_id();
122            let buffer = buffer.read(cx);
123            let path = buffer
124                .file()
125                .and_then(|file| Some(file.as_local()?.abs_path(cx)))
126                .unwrap_or_else(|| PathBuf::from("untitled"))
127                .to_string_lossy()
128                .to_string();
129            let content = buffer.text();
130            let offset = cursor_position.to_offset(buffer);
131            let state_id = agent.next_state_id;
132            agent.next_state_id.0 += 1;
133
134            let (updates_tx, mut updates_rx) = watch::channel();
135            postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
136
137            agent.states.insert(
138                state_id,
139                SupermavenCompletionState {
140                    buffer_id,
141                    prefix_anchor: cursor_position,
142                    prefix_offset: offset,
143                    text: String::new(),
144                    dedent: String::new(),
145                    updates_tx,
146                },
147            );
148            // ensure the states map is max 1000 elements
149            if agent.states.len() > 1000 {
150                // state id is monotonic so it's sufficient to remove the first element
151                agent
152                    .states
153                    .remove(&agent.states.keys().next().unwrap().clone());
154            }
155
156            let _ = agent
157                .outgoing_tx
158                .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
159                    new_id: state_id.0.to_string(),
160                    updates: vec![
161                        StateUpdate::FileUpdate(FileUpdateMessage {
162                            path: path.clone(),
163                            content,
164                        }),
165                        StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
166                    ],
167                }));
168
169            Some(SupermavenCompletion {
170                id: state_id,
171                updates: updates_rx,
172            })
173        } else {
174            None
175        }
176    }
177
178    pub fn completion(
179        &self,
180        buffer: &Entity<Buffer>,
181        cursor_position: Anchor,
182        cx: &App,
183    ) -> Option<&str> {
184        if let Self::Spawned(agent) = self {
185            find_relevant_completion(
186                &agent.states,
187                buffer.entity_id(),
188                &buffer.read(cx).snapshot(),
189                cursor_position,
190            )
191        } else {
192            None
193        }
194    }
195
196    pub fn sign_out(&mut self) {
197        if let Self::Spawned(agent) = self {
198            agent
199                .outgoing_tx
200                .unbounded_send(OutboundMessage::Logout)
201                .ok();
202            // The account status will get set to RequiresActivation or Ready when the next
203            // message from the agent comes in. Until that happens, set the status to Unknown
204            // to disable the button.
205            agent.account_status = AccountStatus::Unknown;
206        }
207    }
208}
209
210fn find_relevant_completion<'a>(
211    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
212    buffer_id: EntityId,
213    buffer: &BufferSnapshot,
214    cursor_position: Anchor,
215) -> Option<&'a str> {
216    let mut best_completion: Option<&str> = None;
217    'completions: for state in states.values() {
218        if state.buffer_id != buffer_id {
219            continue;
220        }
221        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
222            continue;
223        };
224
225        let current_cursor_offset = cursor_position.to_offset(buffer);
226        if current_cursor_offset < state.prefix_offset {
227            continue;
228        }
229
230        let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
231        let text_inserted_since_completion_request =
232            buffer.text_for_range(original_cursor_offset..current_cursor_offset);
233        let mut trimmed_completion = state_completion;
234        for chunk in text_inserted_since_completion_request {
235            if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
236                trimmed_completion = suffix;
237            } else {
238                continue 'completions;
239            }
240        }
241
242        if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
243            continue;
244        }
245
246        best_completion = Some(trimmed_completion);
247    }
248    best_completion
249}
250
/// A spawned Supermaven agent process plus the bookkeeping used to talk to it.
pub struct SupermavenAgent {
    // Child process handle; it is spawned with `kill_on_drop(true)`, so the process lives
    // only as long as this agent.
    _process: Child,
    // Id assigned to the next completion request; incremented by `Supermaven::complete`.
    next_state_id: SupermavenCompletionStateId,
    // Completion states keyed by request id. Ids are monotonic, so iteration order is
    // oldest-first, which the eviction in `Supermaven::complete` relies on.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Queue of messages that the outgoing task serializes onto the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Handles of the stdin writer / stdout reader tasks, held for the agent's lifetime.
    _handle_outgoing_messages: Task<Result<()>>,
    _handle_incoming_messages: Task<Result<()>>,
    // Current account status; reset to `Unknown` by `Supermaven::sign_out`.
    pub account_status: AccountStatus,
    // Last service tier reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // Not read after construction in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    client: Arc<Client>,
}
263
264impl SupermavenAgent {
265    fn new(
266        binary_path: PathBuf,
267        client: Arc<Client>,
268        cx: &mut Context<Supermaven>,
269    ) -> Result<Self> {
270        let mut process = util::command::new_smol_command(&binary_path)
271            .arg("stdio")
272            .stdin(Stdio::piped())
273            .stdout(Stdio::piped())
274            .stderr(Stdio::piped())
275            .kill_on_drop(true)
276            .spawn()
277            .context("failed to start the binary")?;
278
279        let stdin = process
280            .stdin
281            .take()
282            .context("failed to get stdin for process")?;
283        let stdout = process
284            .stdout
285            .take()
286            .context("failed to get stdout for process")?;
287
288        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
289
290        cx.spawn({
291            let client = client.clone();
292            let outgoing_tx = outgoing_tx.clone();
293            async move |this, cx| {
294                let mut status = client.status();
295                while let Some(status) = status.next().await {
296                    if status.is_connected() {
297                        let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
298                        outgoing_tx
299                            .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
300                            .ok();
301                        this.update(cx, |this, cx| {
302                            if let Supermaven::Spawned(this) = this {
303                                this.account_status = AccountStatus::Ready;
304                                cx.notify();
305                            }
306                        })?;
307                        break;
308                    }
309                }
310                anyhow::Ok(())
311            }
312        })
313        .detach();
314
315        Ok(Self {
316            _process: process,
317            next_state_id: SupermavenCompletionStateId::default(),
318            states: BTreeMap::default(),
319            outgoing_tx,
320            _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
321                Self::handle_outgoing_messages(outgoing_rx, stdin).await
322            }),
323            _handle_incoming_messages: cx.spawn(async move |this, cx| {
324                Self::handle_incoming_messages(this, stdout, cx).await
325            }),
326            account_status: AccountStatus::Unknown,
327            service_tier: None,
328            client,
329        })
330    }
331
332    async fn handle_outgoing_messages(
333        mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
334        mut stdin: ChildStdin,
335    ) -> Result<()> {
336        while let Some(message) = outgoing.next().await {
337            let bytes = serde_json::to_vec(&message)?;
338            stdin.write_all(&bytes).await?;
339            stdin.write_all(&[b'\n']).await?;
340        }
341        Ok(())
342    }
343
344    async fn handle_incoming_messages(
345        this: WeakEntity<Supermaven>,
346        stdout: ChildStdout,
347        cx: &mut AsyncApp,
348    ) -> Result<()> {
349        const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
350
351        let stdout = BufReader::new(stdout);
352        let mut lines = stdout.lines();
353        while let Some(line) = lines.next().await {
354            let Some(line) = line.context("failed to read line from stdout").log_err() else {
355                continue;
356            };
357            let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
358                continue;
359            };
360            let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
361                .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
362                .log_err()
363            else {
364                continue;
365            };
366
367            this.update(cx, |this, _cx| {
368                if let Supermaven::Spawned(this) = this {
369                    this.handle_message(message);
370                }
371                Task::ready(anyhow::Ok(()))
372            })?
373            .await?;
374        }
375
376        Ok(())
377    }
378
379    fn handle_message(&mut self, message: SupermavenMessage) {
380        match message {
381            SupermavenMessage::ActivationRequest(request) => {
382                self.account_status = match request.activate_url {
383                    Some(activate_url) => AccountStatus::NeedsActivation {
384                        activate_url: activate_url.clone(),
385                    },
386                    None => AccountStatus::Ready,
387                };
388            }
389            SupermavenMessage::ActivationSuccess => {
390                self.account_status = AccountStatus::Ready;
391            }
392            SupermavenMessage::ServiceTier { service_tier } => {
393                self.account_status = AccountStatus::Ready;
394                self.service_tier = Some(service_tier);
395            }
396            SupermavenMessage::Response(response) => {
397                let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
398                if let Some(state) = self.states.get_mut(&state_id) {
399                    for item in &response.items {
400                        match item {
401                            ResponseItem::Text { text } => state.text.push_str(text),
402                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
403                            _ => {}
404                        }
405                    }
406                    *state.updates_tx.borrow_mut() = ();
407                }
408            }
409            SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
410            _ => {
411                log::warn!("unhandled message: {:?}", message);
412            }
413        }
414    }
415}
416
/// Identifier of a completion request. `Supermaven::complete` hands ids out in increasing
/// order and sends them to the agent as decimal strings.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
419
/// Bookkeeping for one completion request sent to the agent.
// NOTE(review): `prefix_anchor` is not read anywhere in this file, which is presumably why
// the struct allows `dead_code`.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    // Buffer the completion was requested for.
    buffer_id: EntityId,
    // Cursor position at request time.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Completion text accumulated from `ResponseItem::Text` items.
    text: String,
    // Accumulated from `ResponseItem::Dedent` items; `text` must start with it, and it is
    // stripped off when looking up the completion.
    dedent: String,
    // Signalled each time a response for this state is handled.
    updates_tx: watch::Sender<()>,
}
431
/// Handle returned by `Supermaven::complete` for a single completion request.
pub struct SupermavenCompletion {
    /// Id of the completion state this handle refers to.
    pub id: SupermavenCompletionStateId,
    /// Signalled each time a response for this request has been handled.
    pub updates: watch::Receiver<()>,
}