1mod messages;
2mod supermaven_edit_prediction_delegate;
3
4pub use supermaven_edit_prediction_delegate::*;
5
6use anyhow::{Context as _, Result};
7#[allow(unused_imports)]
8use client::{Client, proto};
9use collections::BTreeMap;
10
11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
13use language::{
14 Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
15};
16use messages::*;
17use postage::watch;
18use serde::{Deserialize, Serialize};
19use settings::SettingsStore;
20use smol::io::AsyncWriteExt;
21use std::{path::PathBuf, sync::Arc};
22use ui::prelude::*;
23use util::ResultExt;
24use util::command::Child;
25use util::command::Stdio;
26
// Declares the `SignOut` action through the `actions!` macro; the handler for
// it is installed by `init` below.
actions!(
    supermaven,
    [
        /// Signs out of Supermaven.
        SignOut
    ]
);
34
35pub fn init(client: Arc<Client>, cx: &mut App) {
36 let supermaven = cx.new(|_| Supermaven::Starting);
37 Supermaven::set_global(supermaven.clone(), cx);
38
39 let mut provider = all_language_settings(None, cx).edit_predictions.provider;
40 if provider == language::language_settings::EditPredictionProvider::Supermaven {
41 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
42 }
43
44 cx.observe_global::<SettingsStore>(move |cx| {
45 let new_provider = all_language_settings(None, cx).edit_predictions.provider;
46 if new_provider != provider {
47 provider = new_provider;
48 if provider == language::language_settings::EditPredictionProvider::Supermaven {
49 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
50 } else {
51 supermaven.update(cx, |supermaven, _cx| supermaven.stop());
52 }
53 }
54 })
55 .detach();
56
57 cx.on_action(|_: &SignOut, cx| {
58 if let Some(supermaven) = Supermaven::global(cx) {
59 supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
60 }
61 });
62}
63
/// Lifecycle state of the Supermaven integration.
pub enum Supermaven {
    /// Initial state, and the state `stop` resets to. `start` only does work
    /// while the value is in this state.
    Starting,
    /// NOTE(review): not constructed anywhere in this file; presumably meant
    /// for a failed agent download — confirm against other users of the enum.
    FailedDownload { error: anyhow::Error },
    /// The agent process has been spawned and is owned by the contained value.
    Spawned(SupermavenAgent),
    /// NOTE(review): not constructed anywhere in this file — confirm usage.
    Error { error: anyhow::Error },
}
70
/// Account state of the running agent, updated from the messages it sends.
#[derive(Clone)]
pub enum AccountStatus {
    /// Initial status of a freshly spawned agent, and the status set by
    /// `sign_out` until the agent reports again.
    Unknown,
    /// The agent sent an activation request carrying a URL.
    NeedsActivation { activate_url: String },
    /// Set on an activation request without a URL, on activation success and
    /// when a service tier message is received.
    Ready,
}
77
/// Newtype wrapper that lets the `Supermaven` entity be stored as a gpui global.
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
82
83impl Supermaven {
84 pub fn global(cx: &App) -> Option<Entity<Self>> {
85 cx.try_global::<SupermavenGlobal>()
86 .map(|model| model.0.clone())
87 }
88
89 pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
90 cx.set_global(SupermavenGlobal(supermaven));
91 }
92
93 pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
94 if let Self::Starting = self {
95 cx.spawn(async move |this, cx| {
96 let binary_path =
97 supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
98
99 this.update(cx, |this, cx| {
100 if let Self::Starting = this {
101 *this =
102 Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
103 }
104 anyhow::Ok(())
105 })
106 })
107 .detach_and_log_err(cx)
108 }
109 }
110
111 pub fn stop(&mut self) {
112 *self = Self::Starting;
113 }
114
115 pub fn is_enabled(&self) -> bool {
116 matches!(self, Self::Spawned { .. })
117 }
118
119 pub fn complete(
120 &mut self,
121 buffer: &Entity<Buffer>,
122 cursor_position: Anchor,
123 cx: &App,
124 ) -> Option<SupermavenCompletion> {
125 if let Self::Spawned(agent) = self {
126 let buffer_id = buffer.entity_id();
127 let buffer = buffer.read(cx);
128 let path = buffer
129 .file()
130 .and_then(|file| Some(file.as_local()?.abs_path(cx)))
131 .unwrap_or_else(|| PathBuf::from("untitled"))
132 .to_string_lossy()
133 .to_string();
134 let content = buffer.text();
135 let offset = cursor_position.to_offset(buffer);
136 let state_id = agent.next_state_id;
137 agent.next_state_id.0 += 1;
138
139 let (updates_tx, mut updates_rx) = watch::channel();
140 postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
141
142 agent.states.insert(
143 state_id,
144 SupermavenCompletionState {
145 buffer_id,
146 prefix_anchor: cursor_position,
147 prefix_offset: offset,
148 text: String::new(),
149 dedent: String::new(),
150 updates_tx,
151 },
152 );
153 // ensure the states map is max 1000 elements
154 if agent.states.len() > 1000 {
155 // state id is monotonic so it's sufficient to remove the first element
156 agent
157 .states
158 .remove(&agent.states.keys().next().unwrap().clone());
159 }
160
161 let _ = agent
162 .outgoing_tx
163 .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
164 new_id: state_id.0.to_string(),
165 updates: vec![
166 StateUpdate::FileUpdate(FileUpdateMessage {
167 path: path.clone(),
168 content,
169 }),
170 StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
171 ],
172 }));
173
174 Some(SupermavenCompletion {
175 id: state_id,
176 updates: updates_rx,
177 })
178 } else {
179 None
180 }
181 }
182
183 pub fn completion(
184 &self,
185 buffer: &Entity<Buffer>,
186 cursor_position: Anchor,
187 cx: &App,
188 ) -> Option<&str> {
189 if let Self::Spawned(agent) = self {
190 find_relevant_completion(
191 &agent.states,
192 buffer.entity_id(),
193 &buffer.read(cx).snapshot(),
194 cursor_position,
195 )
196 } else {
197 None
198 }
199 }
200
201 pub fn sign_out(&mut self) {
202 if let Self::Spawned(agent) = self {
203 agent
204 .outgoing_tx
205 .unbounded_send(OutboundMessage::Logout)
206 .ok();
207 // The account status will get set to RequiresActivation or Ready when the next
208 // message from the agent comes in. Until that happens, set the status to Unknown
209 // to disable the button.
210 agent.account_status = AccountStatus::Unknown;
211 }
212 }
213}
214
215fn find_relevant_completion<'a>(
216 states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
217 buffer_id: EntityId,
218 buffer: &BufferSnapshot,
219 cursor_position: Anchor,
220) -> Option<&'a str> {
221 let mut best_completion: Option<&str> = None;
222 'completions: for state in states.values() {
223 if state.buffer_id != buffer_id {
224 continue;
225 }
226 let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
227 continue;
228 };
229
230 let current_cursor_offset = cursor_position.to_offset(buffer);
231 if current_cursor_offset < state.prefix_offset {
232 continue;
233 }
234
235 let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
236 let text_inserted_since_completion_request: String = buffer
237 .text_for_range(original_cursor_offset..current_cursor_offset)
238 .collect();
239 let trimmed_completion =
240 match state_completion.strip_prefix(&text_inserted_since_completion_request) {
241 Some(suffix) => suffix,
242 None => continue 'completions,
243 };
244
245 if best_completion.is_some_and(|best| best.len() > trimmed_completion.len()) {
246 continue;
247 }
248
249 best_completion = Some(trimmed_completion);
250 }
251 best_completion
252}
253
/// A running Supermaven agent process together with the bookkeeping needed to
/// talk to it.
pub struct SupermavenAgent {
    // Handle of the child process; it is spawned with `kill_on_drop(true)`.
    _process: Child,
    // Id handed out to the next completion request.
    next_state_id: SupermavenCompletionStateId,
    // Completion states keyed by their id.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Queue of messages that get written to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Background tasks pumping stdin and stdout; stored here so they are held
    // for as long as the agent exists.
    _handle_outgoing_messages: Task<Result<()>>,
    _handle_incoming_messages: Task<Result<()>>,
    /// Last account status derived from the agent's messages.
    pub account_status: AccountStatus,
    // Service tier reported by the agent, if any has been received.
    service_tier: Option<ServiceTier>,
    #[allow(dead_code)]
    client: Arc<Client>,
}
266
267impl SupermavenAgent {
268 fn new(
269 binary_path: PathBuf,
270 client: Arc<Client>,
271 cx: &mut Context<Supermaven>,
272 ) -> Result<Self> {
273 let mut process = util::command::new_command(&binary_path)
274 .arg("stdio")
275 .stdin(Stdio::piped())
276 .stdout(Stdio::piped())
277 .stderr(Stdio::piped())
278 .kill_on_drop(true)
279 .spawn()
280 .context("failed to start the binary")?;
281
282 let stdin = process
283 .stdin
284 .take()
285 .context("failed to get stdin for process")?;
286 let stdout = process
287 .stdout
288 .take()
289 .context("failed to get stdout for process")?;
290
291 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
292
293 Ok(Self {
294 _process: process,
295 next_state_id: SupermavenCompletionStateId::default(),
296 states: BTreeMap::default(),
297 outgoing_tx,
298 _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
299 Self::handle_outgoing_messages(outgoing_rx, stdin).await
300 }),
301 _handle_incoming_messages: cx.spawn(async move |this, cx| {
302 Self::handle_incoming_messages(this, stdout, cx).await
303 }),
304 account_status: AccountStatus::Unknown,
305 service_tier: None,
306 client,
307 })
308 }
309
310 async fn handle_outgoing_messages<W: smol::io::AsyncWrite + Unpin>(
311 mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
312 mut stdin: W,
313 ) -> Result<()> {
314 while let Some(message) = outgoing.next().await {
315 let bytes = serde_json::to_vec(&message)?;
316 stdin.write_all(&bytes).await?;
317 stdin.write_all(&[b'\n']).await?;
318 }
319 Ok(())
320 }
321
322 async fn handle_incoming_messages<R: smol::io::AsyncRead + Unpin>(
323 this: WeakEntity<Supermaven>,
324 stdout: R,
325 cx: &mut AsyncApp,
326 ) -> Result<()> {
327 const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
328
329 let stdout = BufReader::new(stdout);
330 let mut lines = stdout.lines();
331 while let Some(line) = lines.next().await {
332 let Some(line) = line.context("failed to read line from stdout").log_err() else {
333 continue;
334 };
335 let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
336 continue;
337 };
338 let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
339 .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
340 .log_err()
341 else {
342 continue;
343 };
344
345 this.update(cx, |this, _cx| {
346 if let Supermaven::Spawned(this) = this {
347 this.handle_message(message);
348 }
349 Task::ready(anyhow::Ok(()))
350 })?
351 .await?;
352 }
353
354 Ok(())
355 }
356
357 fn handle_message(&mut self, message: SupermavenMessage) {
358 match message {
359 SupermavenMessage::ActivationRequest(request) => {
360 self.account_status = match request.activate_url {
361 Some(activate_url) => AccountStatus::NeedsActivation { activate_url },
362 None => AccountStatus::Ready,
363 };
364 }
365 SupermavenMessage::ActivationSuccess => {
366 self.account_status = AccountStatus::Ready;
367 }
368 SupermavenMessage::ServiceTier { service_tier } => {
369 self.account_status = AccountStatus::Ready;
370 self.service_tier = Some(service_tier);
371 }
372 SupermavenMessage::Response(response) => {
373 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
374 if let Some(state) = self.states.get_mut(&state_id) {
375 for item in &response.items {
376 match item {
377 ResponseItem::Text { text } => state.text.push_str(text),
378 ResponseItem::Dedent { text } => state.dedent.push_str(text),
379 _ => {}
380 }
381 }
382 *state.updates_tx.borrow_mut() = ();
383 }
384 }
385 SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
386 _ => {
387 log::warn!("unhandled message: {:?}", message);
388 }
389 }
390 }
391}
392
/// Identifier of a completion state. Ids are handed out in increasing order,
/// one per completion request.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
395
/// Everything tracked for a single completion request.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    // Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    // Cursor position at the time of the request.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Completion text accumulated from the agent's responses.
    text: String,
    // Dedent text accumulated from the agent's responses; it is stripped from
    // the front of `text` before the completion is offered.
    dedent: String,
    // Written to whenever a response for this state has been applied.
    updates_tx: watch::Sender<()>,
}
407
/// Handle returned for a requested completion.
pub struct SupermavenCompletion {
    /// Id of the completion state created for the request.
    pub id: SupermavenCompletionStateId,
    /// Receiver that is signalled when a response for the state has been
    /// applied.
    pub updates: watch::Receiver<()>,
}
412
#[cfg(test)]
mod tests {
    use super::*;
    use collections::BTreeMap;
    use gpui::TestAppContext;
    use language::Buffer;

    /// Creates a "hello world" buffer with one completion state whose text is
    /// "hello", requested at the start of the buffer, and returns the
    /// completion found for a cursor placed after `cursor_offset` bytes.
    fn completion_with_cursor_at(cursor_offset: usize, cx: &mut TestAppContext) -> Option<String> {
        let buffer = cx.new(|cx| Buffer::local("hello world", cx));
        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let (updates_tx, _) = watch::channel();

        let states = BTreeMap::from([(
            SupermavenCompletionStateId(1),
            SupermavenCompletionState {
                buffer_id: buffer.entity_id(),
                prefix_anchor: snapshot.anchor_before(0), // Start of buffer
                prefix_offset: 0,
                text: "hello".to_string(),
                dedent: String::new(),
                updates_tx,
            },
        )]);

        let found = find_relevant_completion(
            &states,
            buffer.entity_id(),
            &snapshot,
            snapshot.anchor_after(cursor_offset),
        )
        .map(str::to_owned);
        found
    }

    #[gpui::test]
    async fn test_find_relevant_completion_no_first_letter_skip(cx: &mut TestAppContext) {
        let result = completion_with_cursor_at(1, cx);

        assert_eq!(result.as_deref(), Some("ello"));
    }

    #[gpui::test]
    async fn test_find_relevant_completion_with_multiple_chars(cx: &mut TestAppContext) {
        let result = completion_with_cursor_at(3, cx);

        assert_eq!(result.as_deref(), Some("lo"));
    }
}