supermaven.rs

  1mod messages;
  2mod supermaven_completion_provider;
  3
  4pub use supermaven_completion_provider::*;
  5
  6use anyhow::{Context as _, Result};
  7#[allow(unused_imports)]
  8use client::{proto, Client};
  9use collections::BTreeMap;
 10
 11use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
 12use gpui::{
 13    actions, AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel,
 14};
 15use language::{
 16    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
 17};
 18use messages::*;
 19use postage::watch;
 20use serde::{Deserialize, Serialize};
 21use settings::SettingsStore;
 22use smol::{
 23    io::AsyncWriteExt,
 24    process::{Child, ChildStdin, ChildStdout},
 25};
 26use std::{path::PathBuf, process::Stdio, sync::Arc};
 27use ui::prelude::*;
 28use util::ResultExt;
 29
 30actions!(supermaven, [SignOut]);
 31
 32pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 33    let supermaven = cx.new_model(|_| Supermaven::Starting);
 34    Supermaven::set_global(supermaven.clone(), cx);
 35
 36    let mut provider = all_language_settings(None, cx).inline_completions.provider;
 37    if provider == language::language_settings::InlineCompletionProvider::Supermaven {
 38        supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 39    }
 40
 41    cx.observe_global::<SettingsStore>(move |cx| {
 42        let new_provider = all_language_settings(None, cx).inline_completions.provider;
 43        if new_provider != provider {
 44            provider = new_provider;
 45            if provider == language::language_settings::InlineCompletionProvider::Supermaven {
 46                supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
 47            } else {
 48                supermaven.update(cx, |supermaven, _cx| supermaven.stop());
 49            }
 50        }
 51    })
 52    .detach();
 53
 54    cx.on_action(|_: &SignOut, cx| {
 55        if let Some(supermaven) = Supermaven::global(cx) {
 56            supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
 57        }
 58    });
 59}
 60
 61pub enum Supermaven {
 62    Starting,
 63    FailedDownload { error: anyhow::Error },
 64    Spawned(SupermavenAgent),
 65    Error { error: anyhow::Error },
 66}
 67
 68#[derive(Clone)]
 69pub enum AccountStatus {
 70    Unknown,
 71    NeedsActivation { activate_url: String },
 72    Ready,
 73}
 74
 75#[derive(Clone)]
 76struct SupermavenGlobal(Model<Supermaven>);
 77
 78impl Global for SupermavenGlobal {}
 79
 80impl Supermaven {
 81    pub fn global(cx: &AppContext) -> Option<Model<Self>> {
 82        cx.try_global::<SupermavenGlobal>()
 83            .map(|model| model.0.clone())
 84    }
 85
 86    pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
 87        cx.set_global(SupermavenGlobal(supermaven));
 88    }
 89
 90    pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
 91        if let Self::Starting = self {
 92            cx.spawn(|this, mut cx| async move {
 93                let binary_path =
 94                    supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
 95
 96                this.update(&mut cx, |this, cx| {
 97                    if let Self::Starting = this {
 98                        *this =
 99                            Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
100                    }
101                    anyhow::Ok(())
102                })
103            })
104            .detach_and_log_err(cx)
105        }
106    }
107
108    pub fn stop(&mut self) {
109        *self = Self::Starting;
110    }
111
112    pub fn is_enabled(&self) -> bool {
113        matches!(self, Self::Spawned { .. })
114    }
115
116    pub fn complete(
117        &mut self,
118        buffer: &Model<Buffer>,
119        cursor_position: Anchor,
120        cx: &AppContext,
121    ) -> Option<SupermavenCompletion> {
122        if let Self::Spawned(agent) = self {
123            let buffer_id = buffer.entity_id();
124            let buffer = buffer.read(cx);
125            let path = buffer
126                .file()
127                .and_then(|file| Some(file.as_local()?.abs_path(cx)))
128                .unwrap_or_else(|| PathBuf::from("untitled"))
129                .to_string_lossy()
130                .to_string();
131            let content = buffer.text();
132            let offset = cursor_position.to_offset(buffer);
133            let state_id = agent.next_state_id;
134            agent.next_state_id.0 += 1;
135
136            let (updates_tx, mut updates_rx) = watch::channel();
137            postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
138
139            agent.states.insert(
140                state_id,
141                SupermavenCompletionState {
142                    buffer_id,
143                    prefix_anchor: cursor_position,
144                    prefix_offset: offset,
145                    text: String::new(),
146                    dedent: String::new(),
147                    updates_tx,
148                },
149            );
150            // ensure the states map is max 1000 elements
151            if agent.states.len() > 1000 {
152                // state id is monotonic so it's sufficient to remove the first element
153                agent
154                    .states
155                    .remove(&agent.states.keys().next().unwrap().clone());
156            }
157
158            let _ = agent
159                .outgoing_tx
160                .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
161                    new_id: state_id.0.to_string(),
162                    updates: vec![
163                        StateUpdate::FileUpdate(FileUpdateMessage {
164                            path: path.clone(),
165                            content,
166                        }),
167                        StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
168                    ],
169                }));
170
171            Some(SupermavenCompletion {
172                id: state_id,
173                updates: updates_rx,
174            })
175        } else {
176            None
177        }
178    }
179
180    pub fn completion(
181        &self,
182        buffer: &Model<Buffer>,
183        cursor_position: Anchor,
184        cx: &AppContext,
185    ) -> Option<&str> {
186        if let Self::Spawned(agent) = self {
187            find_relevant_completion(
188                &agent.states,
189                buffer.entity_id(),
190                &buffer.read(cx).snapshot(),
191                cursor_position,
192            )
193        } else {
194            None
195        }
196    }
197
198    pub fn sign_out(&mut self) {
199        if let Self::Spawned(agent) = self {
200            agent
201                .outgoing_tx
202                .unbounded_send(OutboundMessage::Logout)
203                .ok();
204            // The account status will get set to RequiresActivation or Ready when the next
205            // message from the agent comes in. Until that happens, set the status to Unknown
206            // to disable the button.
207            agent.account_status = AccountStatus::Unknown;
208        }
209    }
210}
211
212fn find_relevant_completion<'a>(
213    states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
214    buffer_id: EntityId,
215    buffer: &BufferSnapshot,
216    cursor_position: Anchor,
217) -> Option<&'a str> {
218    let mut best_completion: Option<&str> = None;
219    'completions: for state in states.values() {
220        if state.buffer_id != buffer_id {
221            continue;
222        }
223        let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
224            continue;
225        };
226
227        let current_cursor_offset = cursor_position.to_offset(buffer);
228        if current_cursor_offset < state.prefix_offset {
229            continue;
230        }
231
232        let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
233        let text_inserted_since_completion_request =
234            buffer.text_for_range(original_cursor_offset..current_cursor_offset);
235        let mut trimmed_completion = state_completion;
236        for chunk in text_inserted_since_completion_request {
237            if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
238                trimmed_completion = suffix;
239            } else {
240                continue 'completions;
241            }
242        }
243
244        if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
245            continue;
246        }
247
248        best_completion = Some(trimmed_completion);
249    }
250    best_completion
251}
252
253pub struct SupermavenAgent {
254    _process: Child,
255    next_state_id: SupermavenCompletionStateId,
256    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
257    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
258    _handle_outgoing_messages: Task<Result<()>>,
259    _handle_incoming_messages: Task<Result<()>>,
260    pub account_status: AccountStatus,
261    service_tier: Option<ServiceTier>,
262    #[allow(dead_code)]
263    client: Arc<Client>,
264}
265
266impl SupermavenAgent {
267    fn new(
268        binary_path: PathBuf,
269        client: Arc<Client>,
270        cx: &mut ModelContext<Supermaven>,
271    ) -> Result<Self> {
272        let mut process = util::command::new_smol_command(&binary_path)
273            .arg("stdio")
274            .stdin(Stdio::piped())
275            .stdout(Stdio::piped())
276            .stderr(Stdio::piped())
277            .kill_on_drop(true)
278            .spawn()
279            .context("failed to start the binary")?;
280
281        let stdin = process
282            .stdin
283            .take()
284            .context("failed to get stdin for process")?;
285        let stdout = process
286            .stdout
287            .take()
288            .context("failed to get stdout for process")?;
289
290        let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
291
292        cx.spawn({
293            let client = client.clone();
294            let outgoing_tx = outgoing_tx.clone();
295            move |this, mut cx| async move {
296                let mut status = client.status();
297                while let Some(status) = status.next().await {
298                    if status.is_connected() {
299                        let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
300                        outgoing_tx
301                            .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
302                            .ok();
303                        this.update(&mut cx, |this, cx| {
304                            if let Supermaven::Spawned(this) = this {
305                                this.account_status = AccountStatus::Ready;
306                                cx.notify();
307                            }
308                        })?;
309                        break;
310                    }
311                }
312                anyhow::Ok(())
313            }
314        })
315        .detach();
316
317        Ok(Self {
318            _process: process,
319            next_state_id: SupermavenCompletionStateId::default(),
320            states: BTreeMap::default(),
321            outgoing_tx,
322            _handle_outgoing_messages: cx
323                .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
324            _handle_incoming_messages: cx
325                .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
326            account_status: AccountStatus::Unknown,
327            service_tier: None,
328            client,
329        })
330    }
331
332    async fn handle_outgoing_messages(
333        mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
334        mut stdin: ChildStdin,
335    ) -> Result<()> {
336        while let Some(message) = outgoing.next().await {
337            let bytes = serde_json::to_vec(&message)?;
338            stdin.write_all(&bytes).await?;
339            stdin.write_all(&[b'\n']).await?;
340        }
341        Ok(())
342    }
343
344    async fn handle_incoming_messages(
345        this: WeakModel<Supermaven>,
346        stdout: ChildStdout,
347        mut cx: AsyncAppContext,
348    ) -> Result<()> {
349        const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
350
351        let stdout = BufReader::new(stdout);
352        let mut lines = stdout.lines();
353        while let Some(line) = lines.next().await {
354            let Some(line) = line.context("failed to read line from stdout").log_err() else {
355                continue;
356            };
357            let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
358                continue;
359            };
360            let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
361                .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
362                .log_err()
363            else {
364                continue;
365            };
366
367            this.update(&mut cx, |this, _cx| {
368                if let Supermaven::Spawned(this) = this {
369                    this.handle_message(message);
370                }
371                Task::ready(anyhow::Ok(()))
372            })?
373            .await?;
374        }
375
376        Ok(())
377    }
378
379    fn handle_message(&mut self, message: SupermavenMessage) {
380        match message {
381            SupermavenMessage::ActivationRequest(request) => {
382                self.account_status = match request.activate_url {
383                    Some(activate_url) => AccountStatus::NeedsActivation {
384                        activate_url: activate_url.clone(),
385                    },
386                    None => AccountStatus::Ready,
387                };
388            }
389            SupermavenMessage::ActivationSuccess => {
390                self.account_status = AccountStatus::Ready;
391            }
392            SupermavenMessage::ServiceTier { service_tier } => {
393                self.account_status = AccountStatus::Ready;
394                self.service_tier = Some(service_tier);
395            }
396            SupermavenMessage::Response(response) => {
397                let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
398                if let Some(state) = self.states.get_mut(&state_id) {
399                    for item in &response.items {
400                        match item {
401                            ResponseItem::Text { text } => state.text.push_str(text),
402                            ResponseItem::Dedent { text } => state.dedent.push_str(text),
403                            _ => {}
404                        }
405                    }
406                    *state.updates_tx.borrow_mut() = ();
407                }
408            }
409            SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
410            _ => {
411                log::warn!("unhandled message: {:?}", message);
412            }
413        }
414    }
415}
416
417#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
418pub struct SupermavenCompletionStateId(usize);
419
420#[allow(dead_code)]
421pub struct SupermavenCompletionState {
422    buffer_id: EntityId,
423    prefix_anchor: Anchor,
424    // prefix_offset is tracked independently because the anchor biases left which
425    // doesn't allow us to determine if the prior text has been deleted.
426    prefix_offset: usize,
427    text: String,
428    dedent: String,
429    updates_tx: watch::Sender<()>,
430}
431
432pub struct SupermavenCompletion {
433    pub id: SupermavenCompletionStateId,
434    pub updates: watch::Receiver<()>,
435}