1mod messages;
2mod supermaven_completion_provider;
3
4pub use supermaven_completion_provider::*;
5
6use anyhow::{Context as _, Result};
7#[allow(unused_imports)]
8use client::{Client, proto};
9use collections::BTreeMap;
10
11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
13use language::{
14 Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
15};
16use messages::*;
17use postage::watch;
18use serde::{Deserialize, Serialize};
19use settings::SettingsStore;
20use smol::{
21 io::AsyncWriteExt,
22 process::{Child, ChildStdin, ChildStdout},
23};
24use std::{path::PathBuf, process::Stdio, sync::Arc};
25use ui::prelude::*;
26use util::ResultExt;
27
// Declares the actions exposed by this crate; `SignOut` is handled by the
// `cx.on_action` registration in `init`.
actions!(
    supermaven,
    [
        /// Signs out of Supermaven.
        SignOut
    ]
);
35
36pub fn init(client: Arc<Client>, cx: &mut App) {
37 let supermaven = cx.new(|_| Supermaven::Starting);
38 Supermaven::set_global(supermaven.clone(), cx);
39
40 let mut provider = all_language_settings(None, cx).edit_predictions.provider;
41 if provider == language::language_settings::EditPredictionProvider::Supermaven {
42 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
43 }
44
45 cx.observe_global::<SettingsStore>(move |cx| {
46 let new_provider = all_language_settings(None, cx).edit_predictions.provider;
47 if new_provider != provider {
48 provider = new_provider;
49 if provider == language::language_settings::EditPredictionProvider::Supermaven {
50 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
51 } else {
52 supermaven.update(cx, |supermaven, _cx| supermaven.stop());
53 }
54 }
55 })
56 .detach();
57
58 cx.on_action(|_: &SignOut, cx| {
59 if let Some(supermaven) = Supermaven::global(cx) {
60 supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
61 }
62 });
63}
64
/// State of the Supermaven integration, stored as a global entity (see `Supermaven::global`).
pub enum Supermaven {
    /// No agent is running. This is the initial state and the state `Supermaven::stop`
    /// resets to.
    Starting,
    /// Obtaining the agent binary failed; carries the error.
    FailedDownload { error: anyhow::Error },
    /// The agent process has been spawned and is being communicated with.
    Spawned(SupermavenAgent),
    /// The agent hit an error; carries it (presumably from starting the process —
    /// confirm against the constructors of this variant).
    Error { error: anyhow::Error },
}
71
/// Activation state of the user's Supermaven account, as tracked by the agent handle.
///
/// Derives `Debug`/`PartialEq`/`Eq` so callers can log it and compare states directly
/// instead of pattern-matching.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AccountStatus {
    /// Not known yet, e.g. right after the agent starts or after a sign-out, until the
    /// agent reports a status.
    Unknown,
    /// The agent requested activation of the account at `activate_url`.
    NeedsActivation { activate_url: String },
    /// The account is ready: activation succeeded or an API key was supplied.
    Ready,
}
78
// Newtype wrapper that lets the `Supermaven` entity be stored as a gpui global;
// read via `Supermaven::global` and written via `Supermaven::set_global`.
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
83
84impl Supermaven {
85 pub fn global(cx: &App) -> Option<Entity<Self>> {
86 cx.try_global::<SupermavenGlobal>()
87 .map(|model| model.0.clone())
88 }
89
90 pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
91 cx.set_global(SupermavenGlobal(supermaven));
92 }
93
94 pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
95 if let Self::Starting = self {
96 cx.spawn(async move |this, cx| {
97 let binary_path =
98 supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
99
100 this.update(cx, |this, cx| {
101 if let Self::Starting = this {
102 *this =
103 Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
104 }
105 anyhow::Ok(())
106 })
107 })
108 .detach_and_log_err(cx)
109 }
110 }
111
112 pub fn stop(&mut self) {
113 *self = Self::Starting;
114 }
115
116 pub fn is_enabled(&self) -> bool {
117 matches!(self, Self::Spawned { .. })
118 }
119
120 pub fn complete(
121 &mut self,
122 buffer: &Entity<Buffer>,
123 cursor_position: Anchor,
124 cx: &App,
125 ) -> Option<SupermavenCompletion> {
126 if let Self::Spawned(agent) = self {
127 let buffer_id = buffer.entity_id();
128 let buffer = buffer.read(cx);
129 let path = buffer
130 .file()
131 .and_then(|file| Some(file.as_local()?.abs_path(cx)))
132 .unwrap_or_else(|| PathBuf::from("untitled"))
133 .to_string_lossy()
134 .to_string();
135 let content = buffer.text();
136 let offset = cursor_position.to_offset(buffer);
137 let state_id = agent.next_state_id;
138 agent.next_state_id.0 += 1;
139
140 let (updates_tx, mut updates_rx) = watch::channel();
141 postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
142
143 agent.states.insert(
144 state_id,
145 SupermavenCompletionState {
146 buffer_id,
147 prefix_anchor: cursor_position,
148 prefix_offset: offset,
149 text: String::new(),
150 dedent: String::new(),
151 updates_tx,
152 },
153 );
154 // ensure the states map is max 1000 elements
155 if agent.states.len() > 1000 {
156 // state id is monotonic so it's sufficient to remove the first element
157 agent
158 .states
159 .remove(&agent.states.keys().next().unwrap().clone());
160 }
161
162 let _ = agent
163 .outgoing_tx
164 .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
165 new_id: state_id.0.to_string(),
166 updates: vec![
167 StateUpdate::FileUpdate(FileUpdateMessage {
168 path: path.clone(),
169 content,
170 }),
171 StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
172 ],
173 }));
174
175 Some(SupermavenCompletion {
176 id: state_id,
177 updates: updates_rx,
178 })
179 } else {
180 None
181 }
182 }
183
184 pub fn completion(
185 &self,
186 buffer: &Entity<Buffer>,
187 cursor_position: Anchor,
188 cx: &App,
189 ) -> Option<&str> {
190 if let Self::Spawned(agent) = self {
191 find_relevant_completion(
192 &agent.states,
193 buffer.entity_id(),
194 &buffer.read(cx).snapshot(),
195 cursor_position,
196 )
197 } else {
198 None
199 }
200 }
201
202 pub fn sign_out(&mut self) {
203 if let Self::Spawned(agent) = self {
204 agent
205 .outgoing_tx
206 .unbounded_send(OutboundMessage::Logout)
207 .ok();
208 // The account status will get set to RequiresActivation or Ready when the next
209 // message from the agent comes in. Until that happens, set the status to Unknown
210 // to disable the button.
211 agent.account_status = AccountStatus::Unknown;
212 }
213 }
214}
215
216fn find_relevant_completion<'a>(
217 states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
218 buffer_id: EntityId,
219 buffer: &BufferSnapshot,
220 cursor_position: Anchor,
221) -> Option<&'a str> {
222 let mut best_completion: Option<&str> = None;
223 'completions: for state in states.values() {
224 if state.buffer_id != buffer_id {
225 continue;
226 }
227 let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
228 continue;
229 };
230
231 let current_cursor_offset = cursor_position.to_offset(buffer);
232 if current_cursor_offset < state.prefix_offset {
233 continue;
234 }
235
236 let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
237 let text_inserted_since_completion_request: String = buffer
238 .text_for_range(original_cursor_offset..current_cursor_offset)
239 .collect();
240 let trimmed_completion =
241 match state_completion.strip_prefix(&text_inserted_since_completion_request) {
242 Some(suffix) => suffix,
243 None => continue 'completions,
244 };
245
246 if best_completion.is_some_and(|best| best.len() > trimmed_completion.len()) {
247 continue;
248 }
249
250 best_completion = Some(trimmed_completion);
251 }
252 best_completion
253}
254
/// Handle to a running Supermaven agent process and the bookkeeping needed to talk to it.
pub struct SupermavenAgent {
    // The child process; owned so it lives as long as this agent (it is spawned with
    // `kill_on_drop` in `SupermavenAgent::new`).
    _process: Child,
    // Id to assign to the next completion request.
    next_state_id: SupermavenCompletionStateId,
    // Tracked completion requests, keyed (and therefore ordered) by id.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Queue of messages that `handle_outgoing_messages` serializes onto the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Task writing queued messages to stdin; held for the agent's lifetime
    // (presumably dropping a gpui `Task` cancels it).
    _handle_outgoing_messages: Task<Result<()>>,
    // Task reading and dispatching messages from the agent's stdout.
    _handle_incoming_messages: Task<Result<()>>,
    /// Activation state of the user's account, as last reported.
    pub account_status: AccountStatus,
    // Service tier most recently reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // NOTE(review): not read in this file; kept alive alongside the agent — confirm
    // whether it is still needed.
    #[allow(dead_code)]
    client: Arc<Client>,
}
267
268impl SupermavenAgent {
269 fn new(
270 binary_path: PathBuf,
271 client: Arc<Client>,
272 cx: &mut Context<Supermaven>,
273 ) -> Result<Self> {
274 let mut process = util::command::new_smol_command(&binary_path)
275 .arg("stdio")
276 .stdin(Stdio::piped())
277 .stdout(Stdio::piped())
278 .stderr(Stdio::piped())
279 .kill_on_drop(true)
280 .spawn()
281 .context("failed to start the binary")?;
282
283 let stdin = process
284 .stdin
285 .take()
286 .context("failed to get stdin for process")?;
287 let stdout = process
288 .stdout
289 .take()
290 .context("failed to get stdout for process")?;
291
292 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
293
294 cx.spawn({
295 let client = client.clone();
296 let outgoing_tx = outgoing_tx.clone();
297 async move |this, cx| {
298 let mut status = client.status();
299 while let Some(status) = status.next().await {
300 if status.is_connected() {
301 let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
302 outgoing_tx
303 .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
304 .ok();
305 this.update(cx, |this, cx| {
306 if let Supermaven::Spawned(this) = this {
307 this.account_status = AccountStatus::Ready;
308 cx.notify();
309 }
310 })?;
311 break;
312 }
313 }
314 anyhow::Ok(())
315 }
316 })
317 .detach();
318
319 Ok(Self {
320 _process: process,
321 next_state_id: SupermavenCompletionStateId::default(),
322 states: BTreeMap::default(),
323 outgoing_tx,
324 _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
325 Self::handle_outgoing_messages(outgoing_rx, stdin).await
326 }),
327 _handle_incoming_messages: cx.spawn(async move |this, cx| {
328 Self::handle_incoming_messages(this, stdout, cx).await
329 }),
330 account_status: AccountStatus::Unknown,
331 service_tier: None,
332 client,
333 })
334 }
335
336 async fn handle_outgoing_messages(
337 mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
338 mut stdin: ChildStdin,
339 ) -> Result<()> {
340 while let Some(message) = outgoing.next().await {
341 let bytes = serde_json::to_vec(&message)?;
342 stdin.write_all(&bytes).await?;
343 stdin.write_all(&[b'\n']).await?;
344 }
345 Ok(())
346 }
347
348 async fn handle_incoming_messages(
349 this: WeakEntity<Supermaven>,
350 stdout: ChildStdout,
351 cx: &mut AsyncApp,
352 ) -> Result<()> {
353 const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
354
355 let stdout = BufReader::new(stdout);
356 let mut lines = stdout.lines();
357 while let Some(line) = lines.next().await {
358 let Some(line) = line.context("failed to read line from stdout").log_err() else {
359 continue;
360 };
361 let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
362 continue;
363 };
364 let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
365 .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
366 .log_err()
367 else {
368 continue;
369 };
370
371 this.update(cx, |this, _cx| {
372 if let Supermaven::Spawned(this) = this {
373 this.handle_message(message);
374 }
375 Task::ready(anyhow::Ok(()))
376 })?
377 .await?;
378 }
379
380 Ok(())
381 }
382
383 fn handle_message(&mut self, message: SupermavenMessage) {
384 match message {
385 SupermavenMessage::ActivationRequest(request) => {
386 self.account_status = match request.activate_url {
387 Some(activate_url) => AccountStatus::NeedsActivation {
388 activate_url: activate_url.clone(),
389 },
390 None => AccountStatus::Ready,
391 };
392 }
393 SupermavenMessage::ActivationSuccess => {
394 self.account_status = AccountStatus::Ready;
395 }
396 SupermavenMessage::ServiceTier { service_tier } => {
397 self.account_status = AccountStatus::Ready;
398 self.service_tier = Some(service_tier);
399 }
400 SupermavenMessage::Response(response) => {
401 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
402 if let Some(state) = self.states.get_mut(&state_id) {
403 for item in &response.items {
404 match item {
405 ResponseItem::Text { text } => state.text.push_str(text),
406 ResponseItem::Dedent { text } => state.dedent.push_str(text),
407 _ => {}
408 }
409 }
410 *state.updates_tx.borrow_mut() = ();
411 }
412 }
413 SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
414 _ => {
415 log::warn!("unhandled message: {:?}", message);
416 }
417 }
418 }
419}
420
/// Identifier of a completion request. Ids are handed out sequentially by
/// `Supermaven::complete` and sent to the agent (as a string) as the request's state id.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
423
// NOTE(review): `prefix_anchor` is not read in this file, which is presumably why
// `dead_code` is allowed here — confirm whether it is still needed.
#[allow(dead_code)]
/// Bookkeeping for one completion request: where it was made and the text the agent has
/// streamed back for it so far.
pub struct SupermavenCompletionState {
    // Buffer the completion was requested in.
    buffer_id: EntityId,
    // Cursor position at the time of the request.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Accumulated `Text` response items.
    text: String,
    // Accumulated `Dedent` response items; `find_relevant_completion` only uses `text`
    // when it starts with this, and strips it off.
    dedent: String,
    // Signalled each time new response items for this state are applied.
    updates_tx: watch::Sender<()>,
}
435
/// Handle returned by `Supermaven::complete` for a completion request.
pub struct SupermavenCompletion {
    /// Id of the completion state backing this request.
    pub id: SupermavenCompletionStateId,
    /// Signalled whenever the agent streams more response items for this request.
    pub updates: watch::Receiver<()>,
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use collections::BTreeMap;
    use gpui::TestAppContext;
    use language::Buffer;

    type States = BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>;

    /// Snapshots `buffer` and builds a state map holding a single completion, with text
    /// `completion_text`, that was requested at the very start of the buffer.
    fn snapshot_with_completion_at_start(
        buffer: &Entity<Buffer>,
        completion_text: &str,
        cx: &mut TestAppContext,
    ) -> (BufferSnapshot, States) {
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let (updates_tx, _) = watch::channel();
        let state = SupermavenCompletionState {
            buffer_id: buffer.entity_id(),
            prefix_anchor: snapshot.anchor_before(0),
            prefix_offset: 0,
            text: completion_text.to_string(),
            dedent: String::new(),
            updates_tx,
        };
        let states = States::from([(SupermavenCompletionStateId(1), state)]);
        (snapshot, states)
    }

    #[gpui::test]
    async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let (snapshot, states) = snapshot_with_completion_at_start(&buffer, "hello", cx);

        // A single character ("h") has been typed since the completion was requested.
        let cursor = snapshot.anchor_after(1);
        let relevant = find_relevant_completion(&states, buffer.entity_id(), &snapshot, cursor);

        assert_eq!(relevant, Some("ello"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let (snapshot, states) = snapshot_with_completion_at_start(&buffer, "hello", cx);

        // Three characters ("hel") have been typed since the completion was requested.
        let cursor = snapshot.anchor_after(3);
        let relevant = find_relevant_completion(&states, buffer.entity_id(), &snapshot, cursor);

        assert_eq!(relevant, Some("lo"));
    }
}