supermaven.rs

  1mod messages;
  2mod supermaven_completion_provider;
  3
  4pub use supermaven_completion_provider::*;
  5
  6use anyhow::{Context as _, Result};
  7#[allow(unused_imports)]
  8use client::{proto, Client};
  9use collections::BTreeMap;
 10
 11use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
 12use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
 13use language::{
 14    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
 15};
 16use messages::*;
 17use postage::watch;
 18use serde::{Deserialize, Serialize};
 19use settings::SettingsStore;
 20use smol::{
 21    io::AsyncWriteExt,
 22    process::{Child, ChildStdin, ChildStdout, Command},
 23};
 24use std::{path::PathBuf, process::Stdio, sync::Arc};
 25use ui::prelude::*;
 26use util::ResultExt;
 27
 28pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 29    let supermaven = cx.new_model(|_| Supermaven::Starting);
 30    Supermaven::set_global(supermaven.clone(), cx);
 31
 32    let mut provider = all_language_settings(None, cx).inline_completions.provider;
 33    if provider == language::language_settings::InlineCompletionProvider::Supermaven {
 34        supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 35    }
 36
 37    cx.observe_global::<SettingsStore>(move |cx| {
 38        let new_provider = all_language_settings(None, cx).inline_completions.provider;
 39        if new_provider != provider {
 40            provider = new_provider;
 41            if provider == language::language_settings::InlineCompletionProvider::Supermaven {
 42                supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 43            } else {
 44                supermaven.update(cx, |supermaven, _cx| supermaven.stop());
 45            }
 46        }
 47    })
 48    .detach();
 49}
 50
/// Lifecycle state of the Supermaven integration.
pub enum Supermaven {
    /// No agent is running. This is the initial state and the state `stop` returns to;
    /// `start` only does anything while the model is in this state.
    Starting,
    /// NOTE(review): not constructed anywhere in this file; presumably intended to record a
    /// failure to obtain the agent binary — confirm.
    FailedDownload { error: anyhow::Error },
    /// The agent process has been spawned and completion requests are forwarded to it.
    Spawned(SupermavenAgent),
    /// NOTE(review): not constructed anywhere in this file — confirm where, if anywhere,
    /// this state is set.
    Error { error: anyhow::Error },
}
 57
/// Status of the user's Supermaven account, as last observed by the running agent.
#[derive(Clone)]
pub enum AccountStatus {
    /// Nothing has been reported yet; a freshly spawned agent starts here.
    Unknown,
    /// The agent sent an activation request carrying `activate_url`.
    NeedsActivation { activate_url: String },
    /// The account is usable: an API key was forwarded to the agent, or the agent sent an
    /// activation request without a URL.
    Ready,
}
 64
/// Newtype around the Supermaven model so it can be registered as a gpui global
/// (read via `Supermaven::global`, written via `Supermaven::set_global`).
#[derive(Clone)]
struct SupermavenGlobal(Model<Supermaven>);

impl Global for SupermavenGlobal {}
 69
 70impl Supermaven {
 71    pub fn global(cx: &AppContext) -> Option<Model<Self>> {
 72        cx.try_global::<SupermavenGlobal>()
 73            .map(|model| model.0.clone())
 74    }
 75
 76    pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
 77        cx.set_global(SupermavenGlobal(supermaven));
 78    }
 79
 80    pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
 81        if let Self::Starting = self {
 82            cx.spawn(|this, mut cx| async move {
 83                let binary_path =
 84                    supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
 85
 86                this.update(&mut cx, |this, cx| {
 87                    if let Self::Starting = this {
 88                        *this =
 89                            Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
 90                    }
 91                    anyhow::Ok(())
 92                })
 93            })
 94            .detach_and_log_err(cx)
 95        }
 96    }
 97
 98    pub fn stop(&mut self) {
 99        *self = Self::Starting;
100    }
101
102    pub fn is_enabled(&self) -> bool {
103        matches!(self, Self::Spawned { .. })
104    }
105
106    pub fn complete(
107        &mut self,
108        buffer: &Model<Buffer>,
109        cursor_position: Anchor,
110        cx: &AppContext,
111    ) -> Option<SupermavenCompletion> {
112        if let Self::Spawned(agent) = self {
113            let buffer_id = buffer.entity_id();
114            let buffer = buffer.read(cx);
115            let path = buffer
116                .file()
117                .and_then(|file| Some(file.as_local()?.abs_path(cx)))
118                .unwrap_or_else(|| PathBuf::from("untitled"))
119                .to_string_lossy()
120                .to_string();
121            let content = buffer.text();
122            let offset = cursor_position.to_offset(buffer);
123            let state_id = agent.next_state_id;
124            agent.next_state_id.0 += 1;
125
126            let (updates_tx, mut updates_rx) = watch::channel();
127            postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
128
129            agent.states.insert(
130                state_id,
131                SupermavenCompletionState {
132                    buffer_id,
133                    prefix_anchor: cursor_position,
134                    text: String::new(),
135                    dedent: String::new(),
136                    updates_tx,
137                },
138            );
139            let _ = agent
140                .outgoing_tx
141                .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
142                    new_id: state_id.0.to_string(),
143                    updates: vec![
144                        StateUpdate::FileUpdate(FileUpdateMessage {
145                            path: path.clone(),
146                            content,
147                        }),
148                        StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
149                    ],
150                }));
151
152            Some(SupermavenCompletion {
153                id: state_id,
154                updates: updates_rx,
155            })
156        } else {
157            None
158        }
159    }
160
161    pub fn completion(
162        &self,
163        buffer: &Model<Buffer>,
164        cursor_position: Anchor,
165        cx: &AppContext,
166    ) -> Option<&str> {
167        if let Self::Spawned(agent) = self {
168            find_relevant_completion(
169                &agent.states,
170                buffer.entity_id(),
171                &buffer.read(cx).snapshot(),
172                cursor_position,
173            )
174        } else {
175            None
176        }
177    }
178}
179
180fn find_relevant_completion<'a>(
181    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
182    buffer_id: EntityId,
183    buffer: &BufferSnapshot,
184    cursor_position: Anchor,
185) -> Option<&'a str> {
186    let mut best_completion: Option<&str> = None;
187    'completions: for state in states.values() {
188        if state.buffer_id != buffer_id {
189            continue;
190        }
191        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
192            continue;
193        };
194
195        let current_cursor_offset = cursor_position.to_offset(buffer);
196        let original_cursor_offset = state.prefix_anchor.to_offset(buffer);
197        if current_cursor_offset < original_cursor_offset {
198            continue;
199        }
200
201        let text_inserted_since_completion_request =
202            buffer.text_for_range(original_cursor_offset..current_cursor_offset);
203        let mut trimmed_completion = state_completion;
204        for chunk in text_inserted_since_completion_request {
205            if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
206                trimmed_completion = suffix;
207            } else {
208                continue 'completions;
209            }
210        }
211
212        if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
213            continue;
214        }
215
216        best_completion = Some(trimmed_completion);
217    }
218    best_completion
219}
220
/// A spawned Supermaven agent process together with the completion requests sent to it.
pub struct SupermavenAgent {
    /// The agent child process. `SupermavenAgent::new` spawns it with `kill_on_drop(true)`,
    /// so dropping the agent terminates the process.
    _process: Child,
    /// Id the next completion request will use.
    next_state_id: SupermavenCompletionStateId,
    /// Per-request completion state, keyed by request id.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    /// Queue of messages to be written to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    /// Handle of the task writing queued messages to the agent, kept for the agent's lifetime.
    _handle_outgoing_messages: Task<Result<()>>,
    /// Handle of the task reading and dispatching the agent's stdout, kept for the agent's
    /// lifetime.
    _handle_incoming_messages: Task<Result<()>>,
    /// Account status; starts as `Unknown` and is updated from agent messages and once an
    /// API key has been forwarded.
    pub account_status: AccountStatus,
    /// Service tier most recently reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // Stored but not read anywhere in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    client: Arc<Client>,
}
233
234impl SupermavenAgent {
235    fn new(
236        binary_path: PathBuf,
237        client: Arc<Client>,
238        cx: &mut ModelContext<Supermaven>,
239    ) -> Result<Self> {
240        let mut process = Command::new(&binary_path)
241            .arg("stdio")
242            .stdin(Stdio::piped())
243            .stdout(Stdio::piped())
244            .stderr(Stdio::piped())
245            .kill_on_drop(true)
246            .spawn()
247            .context("failed to start the binary")?;
248
249        let stdin = process
250            .stdin
251            .take()
252            .context("failed to get stdin for process")?;
253        let stdout = process
254            .stdout
255            .take()
256            .context("failed to get stdout for process")?;
257
258        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
259
260        cx.spawn({
261            let client = client.clone();
262            let outgoing_tx = outgoing_tx.clone();
263            move |this, mut cx| async move {
264                let mut status = client.status();
265                while let Some(status) = status.next().await {
266                    if status.is_connected() {
267                        let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
268                        outgoing_tx
269                            .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
270                            .ok();
271                        this.update(&mut cx, |this, cx| {
272                            if let Supermaven::Spawned(this) = this {
273                                this.account_status = AccountStatus::Ready;
274                                cx.notify();
275                            }
276                        })?;
277                        break;
278                    }
279                }
280                return anyhow::Ok(());
281            }
282        })
283        .detach();
284
285        Ok(Self {
286            _process: process,
287            next_state_id: SupermavenCompletionStateId::default(),
288            states: BTreeMap::default(),
289            outgoing_tx,
290            _handle_outgoing_messages: cx
291                .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
292            _handle_incoming_messages: cx
293                .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
294            account_status: AccountStatus::Unknown,
295            service_tier: None,
296            client,
297        })
298    }
299
300    async fn handle_outgoing_messages(
301        mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
302        mut stdin: ChildStdin,
303    ) -> Result<()> {
304        while let Some(message) = outgoing.next().await {
305            let bytes = serde_json::to_vec(&message)?;
306            stdin.write_all(&bytes).await?;
307            stdin.write_all(&[b'\n']).await?;
308        }
309        Ok(())
310    }
311
312    async fn handle_incoming_messages(
313        this: WeakModel<Supermaven>,
314        stdout: ChildStdout,
315        mut cx: AsyncAppContext,
316    ) -> Result<()> {
317        const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
318
319        let stdout = BufReader::new(stdout);
320        let mut lines = stdout.lines();
321        while let Some(line) = lines.next().await {
322            let Some(line) = line.context("failed to read line from stdout").log_err() else {
323                continue;
324            };
325            let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
326                continue;
327            };
328            let Some(message) = serde_json::from_str::<SupermavenMessage>(&line)
329                .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
330                .log_err()
331            else {
332                continue;
333            };
334
335            this.update(&mut cx, |this, _cx| {
336                if let Supermaven::Spawned(this) = this {
337                    this.handle_message(message);
338                }
339                Task::ready(anyhow::Ok(()))
340            })?
341            .await?;
342        }
343
344        Ok(())
345    }
346
347    fn handle_message(&mut self, message: SupermavenMessage) {
348        match message {
349            SupermavenMessage::ActivationRequest(request) => {
350                self.account_status = match request.activate_url {
351                    Some(activate_url) => AccountStatus::NeedsActivation {
352                        activate_url: activate_url.clone(),
353                    },
354                    None => AccountStatus::Ready,
355                };
356            }
357            SupermavenMessage::ServiceTier { service_tier } => {
358                self.service_tier = Some(service_tier);
359            }
360            SupermavenMessage::Response(response) => {
361                let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
362                if let Some(state) = self.states.get_mut(&state_id) {
363                    for item in &response.items {
364                        match item {
365                            ResponseItem::Text { text } => state.text.push_str(text),
366                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
367                            _ => {}
368                        }
369                    }
370                    *state.updates_tx.borrow_mut() = ();
371                }
372            }
373            SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
374            _ => {
375                log::warn!("unhandled message: {:?}", message);
376            }
377        }
378    }
379}
380
/// Identifier of a single completion request.
///
/// `Supermaven::complete` hands out ids sequentially, starting from the default (0), and
/// sends them to the agent as decimal strings.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
383
/// Accumulated result of one completion request sent to the agent.
// NOTE(review): every field appears to be read somewhere in this file, so this
// `dead_code` allowance may no longer be needed — confirm before removing it.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    /// Buffer the request was issued for.
    buffer_id: EntityId,
    /// Cursor position at the time of the request.
    prefix_anchor: Anchor,
    /// Completion text received so far (concatenated `Text` response items).
    text: String,
    /// Concatenated `Dedent` response items; this prefix is stripped from `text` when the
    /// completion is matched against the buffer.
    dedent: String,
    /// Notified each time a response for this request is processed.
    updates_tx: watch::Sender<()>,
}
392
/// Handle returned by `Supermaven::complete` for an in-flight completion request.
pub struct SupermavenCompletion {
    /// Id of the request.
    pub id: SupermavenCompletionStateId,
    /// Receives a notification each time a response for this request is processed.
    pub updates: watch::Receiver<()>,
}