1mod messages;
2mod supermaven_completion_provider;
3
4pub use supermaven_completion_provider::*;
5
6use anyhow::{Context as _, Result};
7#[allow(unused_imports)]
8use client::{Client, proto};
9use collections::BTreeMap;
10
11use futures::{AsyncBufReadExt, StreamExt, channel::mpsc, io::BufReader};
12use gpui::{App, AsyncApp, Context, Entity, EntityId, Global, Task, WeakEntity, actions};
13use language::{
14 Anchor, Buffer, BufferSnapshot, ToOffset, language_settings::all_language_settings,
15};
16use messages::*;
17use postage::watch;
18use serde::{Deserialize, Serialize};
19use settings::SettingsStore;
20use smol::{
21 io::AsyncWriteExt,
22 process::{Child, ChildStdin, ChildStdout},
23};
24use std::{path::PathBuf, process::Stdio, sync::Arc};
25use ui::prelude::*;
26use util::ResultExt;
27
// Declares the `SignOut` action through gpui's `actions!` macro (first argument:
// `supermaven`). The handler for it is registered in `init` below.
actions!(supermaven, [SignOut]);
29
30pub fn init(client: Arc<Client>, cx: &mut App) {
31 let supermaven = cx.new(|_| Supermaven::Starting);
32 Supermaven::set_global(supermaven.clone(), cx);
33
34 let mut provider = all_language_settings(None, cx).edit_predictions.provider;
35 if provider == language::language_settings::EditPredictionProvider::Supermaven {
36 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
37 }
38
39 cx.observe_global::<SettingsStore>(move |cx| {
40 let new_provider = all_language_settings(None, cx).edit_predictions.provider;
41 if new_provider != provider {
42 provider = new_provider;
43 if provider == language::language_settings::EditPredictionProvider::Supermaven {
44 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
45 } else {
46 supermaven.update(cx, |supermaven, _cx| supermaven.stop());
47 }
48 }
49 })
50 .detach();
51
52 cx.on_action(|_: &SignOut, cx| {
53 if let Some(supermaven) = Supermaven::global(cx) {
54 supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
55 }
56 });
57}
58
/// Lifecycle state of the Supermaven integration.
pub enum Supermaven {
    /// No agent is running. This is the initial state created by `init`, and the
    /// state `stop` resets to; `start` only does work from this state.
    Starting,
    /// NOTE(review): not constructed anywhere in this file — presumably intended
    /// for a failed agent-binary download; confirm against other modules.
    FailedDownload { error: anyhow::Error },
    /// The agent process has been spawned and is owned by the contained value.
    Spawned(SupermavenAgent),
    /// NOTE(review): not constructed anywhere in this file — confirm usage elsewhere.
    Error { error: anyhow::Error },
}
65
/// Account state as last reported for a running agent.
#[derive(Clone)]
pub enum AccountStatus {
    /// Status not (yet) known: the value a new agent starts with, and the value
    /// set by `Supermaven::sign_out` until the agent reports again.
    Unknown,
    /// The agent sent an activation request carrying a URL.
    NeedsActivation { activate_url: String },
    /// Set on activation success, on a service-tier message, on an activation
    /// request without a URL, and after the API key has been sent to the agent.
    Ready,
}
72
/// Newtype wrapper that lets the `Supermaven` entity be stored as a gpui global
/// (see `Supermaven::global` / `Supermaven::set_global`).
#[derive(Clone)]
struct SupermavenGlobal(Entity<Supermaven>);

impl Global for SupermavenGlobal {}
77
78impl Supermaven {
79 pub fn global(cx: &App) -> Option<Entity<Self>> {
80 cx.try_global::<SupermavenGlobal>()
81 .map(|model| model.0.clone())
82 }
83
84 pub fn set_global(supermaven: Entity<Self>, cx: &mut App) {
85 cx.set_global(SupermavenGlobal(supermaven));
86 }
87
88 pub fn start(&mut self, client: Arc<Client>, cx: &mut Context<Self>) {
89 if let Self::Starting = self {
90 cx.spawn(async move |this, cx| {
91 let binary_path =
92 supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
93
94 this.update(cx, |this, cx| {
95 if let Self::Starting = this {
96 *this =
97 Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
98 }
99 anyhow::Ok(())
100 })
101 })
102 .detach_and_log_err(cx)
103 }
104 }
105
106 pub fn stop(&mut self) {
107 *self = Self::Starting;
108 }
109
110 pub fn is_enabled(&self) -> bool {
111 matches!(self, Self::Spawned { .. })
112 }
113
114 pub fn complete(
115 &mut self,
116 buffer: &Entity<Buffer>,
117 cursor_position: Anchor,
118 cx: &App,
119 ) -> Option<SupermavenCompletion> {
120 if let Self::Spawned(agent) = self {
121 let buffer_id = buffer.entity_id();
122 let buffer = buffer.read(cx);
123 let path = buffer
124 .file()
125 .and_then(|file| Some(file.as_local()?.abs_path(cx)))
126 .unwrap_or_else(|| PathBuf::from("untitled"))
127 .to_string_lossy()
128 .to_string();
129 let content = buffer.text();
130 let offset = cursor_position.to_offset(buffer);
131 let state_id = agent.next_state_id;
132 agent.next_state_id.0 += 1;
133
134 let (updates_tx, mut updates_rx) = watch::channel();
135 postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
136
137 agent.states.insert(
138 state_id,
139 SupermavenCompletionState {
140 buffer_id,
141 prefix_anchor: cursor_position,
142 prefix_offset: offset,
143 text: String::new(),
144 dedent: String::new(),
145 updates_tx,
146 },
147 );
148 // ensure the states map is max 1000 elements
149 if agent.states.len() > 1000 {
150 // state id is monotonic so it's sufficient to remove the first element
151 agent
152 .states
153 .remove(&agent.states.keys().next().unwrap().clone());
154 }
155
156 let _ = agent
157 .outgoing_tx
158 .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
159 new_id: state_id.0.to_string(),
160 updates: vec![
161 StateUpdate::FileUpdate(FileUpdateMessage {
162 path: path.clone(),
163 content,
164 }),
165 StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
166 ],
167 }));
168
169 Some(SupermavenCompletion {
170 id: state_id,
171 updates: updates_rx,
172 })
173 } else {
174 None
175 }
176 }
177
178 pub fn completion(
179 &self,
180 buffer: &Entity<Buffer>,
181 cursor_position: Anchor,
182 cx: &App,
183 ) -> Option<&str> {
184 if let Self::Spawned(agent) = self {
185 find_relevant_completion(
186 &agent.states,
187 buffer.entity_id(),
188 &buffer.read(cx).snapshot(),
189 cursor_position,
190 )
191 } else {
192 None
193 }
194 }
195
196 pub fn sign_out(&mut self) {
197 if let Self::Spawned(agent) = self {
198 agent
199 .outgoing_tx
200 .unbounded_send(OutboundMessage::Logout)
201 .ok();
202 // The account status will get set to RequiresActivation or Ready when the next
203 // message from the agent comes in. Until that happens, set the status to Unknown
204 // to disable the button.
205 agent.account_status = AccountStatus::Unknown;
206 }
207 }
208}
209
210fn find_relevant_completion<'a>(
211 states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
212 buffer_id: EntityId,
213 buffer: &BufferSnapshot,
214 cursor_position: Anchor,
215) -> Option<&'a str> {
216 let mut best_completion: Option<&str> = None;
217 'completions: for state in states.values() {
218 if state.buffer_id != buffer_id {
219 continue;
220 }
221 let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
222 continue;
223 };
224
225 let current_cursor_offset = cursor_position.to_offset(buffer);
226 if current_cursor_offset < state.prefix_offset {
227 continue;
228 }
229
230 let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
231 let text_inserted_since_completion_request =
232 buffer.text_for_range(original_cursor_offset..current_cursor_offset);
233 let mut trimmed_completion = state_completion;
234 for chunk in text_inserted_since_completion_request {
235 if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
236 trimmed_completion = suffix;
237 } else {
238 continue 'completions;
239 }
240 }
241
242 if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
243 continue;
244 }
245
246 best_completion = Some(trimmed_completion);
247 }
248 best_completion
249}
250
/// A spawned Supermaven agent process together with the state needed to talk to it.
pub struct SupermavenAgent {
    // Child process handle; it is spawned with `kill_on_drop(true)` in `new`.
    _process: Child,
    // Id that the next call to `Supermaven::complete` will use.
    next_state_id: SupermavenCompletionStateId,
    // Completion states keyed by id; `Supermaven::complete` caps this at 1000 entries.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Queue of messages that `handle_outgoing_messages` writes to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Background tasks for the stdin writer and stdout reader; never read, only
    // held by the agent (presumably so they are dropped together with it).
    _handle_outgoing_messages: Task<Result<()>>,
    _handle_incoming_messages: Task<Result<()>>,
    /// Account status as last derived from agent messages.
    pub account_status: AccountStatus,
    // Service tier from the most recent `ServiceTier` message, if any.
    service_tier: Option<ServiceTier>,
    #[allow(dead_code)]
    client: Arc<Client>,
}
263
264impl SupermavenAgent {
265 fn new(
266 binary_path: PathBuf,
267 client: Arc<Client>,
268 cx: &mut Context<Supermaven>,
269 ) -> Result<Self> {
270 let mut process = util::command::new_smol_command(&binary_path)
271 .arg("stdio")
272 .stdin(Stdio::piped())
273 .stdout(Stdio::piped())
274 .stderr(Stdio::piped())
275 .kill_on_drop(true)
276 .spawn()
277 .context("failed to start the binary")?;
278
279 let stdin = process
280 .stdin
281 .take()
282 .context("failed to get stdin for process")?;
283 let stdout = process
284 .stdout
285 .take()
286 .context("failed to get stdout for process")?;
287
288 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
289
290 cx.spawn({
291 let client = client.clone();
292 let outgoing_tx = outgoing_tx.clone();
293 async move |this, cx| {
294 let mut status = client.status();
295 while let Some(status) = status.next().await {
296 if status.is_connected() {
297 let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
298 outgoing_tx
299 .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
300 .ok();
301 this.update(cx, |this, cx| {
302 if let Supermaven::Spawned(this) = this {
303 this.account_status = AccountStatus::Ready;
304 cx.notify();
305 }
306 })?;
307 break;
308 }
309 }
310 anyhow::Ok(())
311 }
312 })
313 .detach();
314
315 Ok(Self {
316 _process: process,
317 next_state_id: SupermavenCompletionStateId::default(),
318 states: BTreeMap::default(),
319 outgoing_tx,
320 _handle_outgoing_messages: cx.spawn(async move |_, _cx| {
321 Self::handle_outgoing_messages(outgoing_rx, stdin).await
322 }),
323 _handle_incoming_messages: cx.spawn(async move |this, cx| {
324 Self::handle_incoming_messages(this, stdout, cx).await
325 }),
326 account_status: AccountStatus::Unknown,
327 service_tier: None,
328 client,
329 })
330 }
331
332 async fn handle_outgoing_messages(
333 mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
334 mut stdin: ChildStdin,
335 ) -> Result<()> {
336 while let Some(message) = outgoing.next().await {
337 let bytes = serde_json::to_vec(&message)?;
338 stdin.write_all(&bytes).await?;
339 stdin.write_all(&[b'\n']).await?;
340 }
341 Ok(())
342 }
343
344 async fn handle_incoming_messages(
345 this: WeakEntity<Supermaven>,
346 stdout: ChildStdout,
347 cx: &mut AsyncApp,
348 ) -> Result<()> {
349 const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
350
351 let stdout = BufReader::new(stdout);
352 let mut lines = stdout.lines();
353 while let Some(line) = lines.next().await {
354 let Some(line) = line.context("failed to read line from stdout").log_err() else {
355 continue;
356 };
357 let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
358 continue;
359 };
360 let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
361 .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
362 .log_err()
363 else {
364 continue;
365 };
366
367 this.update(cx, |this, _cx| {
368 if let Supermaven::Spawned(this) = this {
369 this.handle_message(message);
370 }
371 Task::ready(anyhow::Ok(()))
372 })?
373 .await?;
374 }
375
376 Ok(())
377 }
378
379 fn handle_message(&mut self, message: SupermavenMessage) {
380 match message {
381 SupermavenMessage::ActivationRequest(request) => {
382 self.account_status = match request.activate_url {
383 Some(activate_url) => AccountStatus::NeedsActivation {
384 activate_url: activate_url.clone(),
385 },
386 None => AccountStatus::Ready,
387 };
388 }
389 SupermavenMessage::ActivationSuccess => {
390 self.account_status = AccountStatus::Ready;
391 }
392 SupermavenMessage::ServiceTier { service_tier } => {
393 self.account_status = AccountStatus::Ready;
394 self.service_tier = Some(service_tier);
395 }
396 SupermavenMessage::Response(response) => {
397 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
398 if let Some(state) = self.states.get_mut(&state_id) {
399 for item in &response.items {
400 match item {
401 ResponseItem::Text { text } => state.text.push_str(text),
402 ResponseItem::Dedent { text } => state.dedent.push_str(text),
403 _ => {}
404 }
405 }
406 *state.updates_tx.borrow_mut() = ();
407 }
408 }
409 SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
410 _ => {
411 log::warn!("unhandled message: {:?}", message);
412 }
413 }
414 }
415}
416
/// Identifier of a completion request. `Supermaven::complete` hands these out in
/// increasing order and sends the number to the agent as a decimal string.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
419
/// Everything tracked for one completion request while responses stream in.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    // Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    // Cursor anchor at the time of the request.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Accumulated `Text` response items.
    text: String,
    // Accumulated `Dedent` response items; stripped from the front of `text`
    // before a completion is offered.
    dedent: String,
    // Signalled each time a response for this state has been applied.
    updates_tx: watch::Sender<()>,
}
431
/// Handle returned by `Supermaven::complete` for a pending completion request.
pub struct SupermavenCompletion {
    /// Id of the completion state this handle refers to.
    pub id: SupermavenCompletionStateId,
    /// Receives a signal whenever new response items were applied to the state.
    pub updates: watch::Receiver<()>,
}