1mod messages;
2mod supermaven_completion_provider;
3
4pub use supermaven_completion_provider::*;
5
6use anyhow::{Context as _, Result};
7#[allow(unused_imports)]
8use client::{proto, Client};
9use collections::BTreeMap;
10
11use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
12use gpui::{
13 actions, AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel,
14};
15use language::{
16 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
17};
18use messages::*;
19use postage::watch;
20use serde::{Deserialize, Serialize};
21use settings::SettingsStore;
22use smol::{
23 io::AsyncWriteExt,
24 process::{Child, ChildStdin, ChildStdout, Command},
25};
26use std::{path::PathBuf, process::Stdio, sync::Arc};
27use ui::prelude::*;
28use util::ResultExt;
29
// Declares the `SignOut` action in the `supermaven` action namespace; its handler is
// registered in `init`.
actions!(supermaven, [SignOut]);
31
32pub fn init(client: Arc<Client>, cx: &mut AppContext) {
33 let supermaven = cx.new_model(|_| Supermaven::Starting);
34 Supermaven::set_global(supermaven.clone(), cx);
35
36 let mut provider = all_language_settings(None, cx).inline_completions.provider;
37 if provider == language::language_settings::InlineCompletionProvider::Supermaven {
38 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
39 }
40
41 cx.observe_global::<SettingsStore>(move |cx| {
42 let new_provider = all_language_settings(None, cx).inline_completions.provider;
43 if new_provider != provider {
44 provider = new_provider;
45 if provider == language::language_settings::InlineCompletionProvider::Supermaven {
46 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
47 } else {
48 supermaven.update(cx, |supermaven, _cx| supermaven.stop());
49 }
50 }
51 })
52 .detach();
53
54 cx.on_action(|_: &SignOut, cx| {
55 if let Some(supermaven) = Supermaven::global(cx) {
56 supermaven.update(cx, |supermaven, _cx| supermaven.sign_out());
57 }
58 });
59}
60
/// Lifecycle state of the Supermaven integration, held in a global model
/// (see `Supermaven::global` / `Supermaven::set_global`).
pub enum Supermaven {
    /// Initial state, and the state `stop` resets to. `start` only acts in this state.
    Starting,
    /// NOTE(review): not constructed anywhere in this file; presumably intended to hold
    /// an error from downloading the agent binary — confirm.
    FailedDownload { error: anyhow::Error },
    /// Holds the spawned agent process together with its per-request completion state.
    Spawned(SupermavenAgent),
    /// NOTE(review): not constructed anywhere in this file — confirm where it is set.
    Error { error: anyhow::Error },
}
67
/// Account state of the Supermaven agent, as last observed by this process.
#[derive(Clone)]
pub enum AccountStatus {
    /// Not known yet. Also set right after a sign-out until the agent reports again.
    Unknown,
    /// The agent sent an activation request with this URL for the user to visit.
    NeedsActivation { activate_url: String },
    /// Usable: activation succeeded, no activation URL was required, a service tier was
    /// reported, or an API key was handed to the agent.
    Ready,
}
74
/// Newtype wrapper so the `Supermaven` model can be stored as a gpui global.
#[derive(Clone)]
struct SupermavenGlobal(Model<Supermaven>);

impl Global for SupermavenGlobal {}
79
80impl Supermaven {
81 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
82 cx.try_global::<SupermavenGlobal>()
83 .map(|model| model.0.clone())
84 }
85
86 pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
87 cx.set_global(SupermavenGlobal(supermaven));
88 }
89
90 pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
91 if let Self::Starting = self {
92 cx.spawn(|this, mut cx| async move {
93 let binary_path =
94 supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
95
96 this.update(&mut cx, |this, cx| {
97 if let Self::Starting = this {
98 *this =
99 Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
100 }
101 anyhow::Ok(())
102 })
103 })
104 .detach_and_log_err(cx)
105 }
106 }
107
108 pub fn stop(&mut self) {
109 *self = Self::Starting;
110 }
111
112 pub fn is_enabled(&self) -> bool {
113 matches!(self, Self::Spawned { .. })
114 }
115
116 pub fn complete(
117 &mut self,
118 buffer: &Model<Buffer>,
119 cursor_position: Anchor,
120 cx: &AppContext,
121 ) -> Option<SupermavenCompletion> {
122 if let Self::Spawned(agent) = self {
123 let buffer_id = buffer.entity_id();
124 let buffer = buffer.read(cx);
125 let path = buffer
126 .file()
127 .and_then(|file| Some(file.as_local()?.abs_path(cx)))
128 .unwrap_or_else(|| PathBuf::from("untitled"))
129 .to_string_lossy()
130 .to_string();
131 let content = buffer.text();
132 let offset = cursor_position.to_offset(buffer);
133 let state_id = agent.next_state_id;
134 agent.next_state_id.0 += 1;
135
136 let (updates_tx, mut updates_rx) = watch::channel();
137 postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
138
139 agent.states.insert(
140 state_id,
141 SupermavenCompletionState {
142 buffer_id,
143 prefix_anchor: cursor_position,
144 prefix_offset: offset,
145 text: String::new(),
146 dedent: String::new(),
147 updates_tx,
148 },
149 );
150 // ensure the states map is max 1000 elements
151 if agent.states.len() > 1000 {
152 // state id is monotonic so it's sufficient to remove the first element
153 agent
154 .states
155 .remove(&agent.states.keys().next().unwrap().clone());
156 }
157
158 let _ = agent
159 .outgoing_tx
160 .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
161 new_id: state_id.0.to_string(),
162 updates: vec![
163 StateUpdate::FileUpdate(FileUpdateMessage {
164 path: path.clone(),
165 content,
166 }),
167 StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
168 ],
169 }));
170
171 Some(SupermavenCompletion {
172 id: state_id,
173 updates: updates_rx,
174 })
175 } else {
176 None
177 }
178 }
179
180 pub fn completion(
181 &self,
182 buffer: &Model<Buffer>,
183 cursor_position: Anchor,
184 cx: &AppContext,
185 ) -> Option<&str> {
186 if let Self::Spawned(agent) = self {
187 find_relevant_completion(
188 &agent.states,
189 buffer.entity_id(),
190 &buffer.read(cx).snapshot(),
191 cursor_position,
192 )
193 } else {
194 None
195 }
196 }
197
198 pub fn sign_out(&mut self) {
199 if let Self::Spawned(agent) = self {
200 agent
201 .outgoing_tx
202 .unbounded_send(OutboundMessage::Logout)
203 .ok();
204 // The account status will get set to RequiresActivation or Ready when the next
205 // message from the agent comes in. Until that happens, set the status to Unknown
206 // to disable the button.
207 agent.account_status = AccountStatus::Unknown;
208 }
209 }
210}
211
212fn find_relevant_completion<'a>(
213 states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
214 buffer_id: EntityId,
215 buffer: &BufferSnapshot,
216 cursor_position: Anchor,
217) -> Option<&'a str> {
218 let mut best_completion: Option<&str> = None;
219 'completions: for state in states.values() {
220 if state.buffer_id != buffer_id {
221 continue;
222 }
223 let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
224 continue;
225 };
226
227 let current_cursor_offset = cursor_position.to_offset(buffer);
228 if current_cursor_offset < state.prefix_offset {
229 continue;
230 }
231
232 let original_cursor_offset = buffer.clip_offset(state.prefix_offset, text::Bias::Left);
233 let text_inserted_since_completion_request =
234 buffer.text_for_range(original_cursor_offset..current_cursor_offset);
235 let mut trimmed_completion = state_completion;
236 for chunk in text_inserted_since_completion_request {
237 if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
238 trimmed_completion = suffix;
239 } else {
240 continue 'completions;
241 }
242 }
243
244 if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
245 continue;
246 }
247
248 best_completion = Some(trimmed_completion);
249 }
250 best_completion
251}
252
/// A spawned Supermaven agent process plus the state used to communicate with it.
pub struct SupermavenAgent {
    // The child process, held for the agent's lifetime (it is spawned with `kill_on_drop`).
    _process: Child,
    // Id handed to the next completion request; incremented by `Supermaven::complete`.
    next_state_id: SupermavenCompletionStateId,
    // Completion requests keyed by id; `Supermaven::complete` caps this at 1000 entries.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    // Queue of messages that the outgoing task serializes to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Task writing queued messages to the agent's stdin.
    _handle_outgoing_messages: Task<Result<()>>,
    // Task reading and dispatching messages from the agent's stdout.
    _handle_incoming_messages: Task<Result<()>>,
    pub account_status: AccountStatus,
    // Service tier most recently reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    // Not read in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    client: Arc<Client>,
}
265
266impl SupermavenAgent {
267 fn new(
268 binary_path: PathBuf,
269 client: Arc<Client>,
270 cx: &mut ModelContext<Supermaven>,
271 ) -> Result<Self> {
272 let mut process = Command::new(&binary_path);
273 process
274 .arg("stdio")
275 .stdin(Stdio::piped())
276 .stdout(Stdio::piped())
277 .stderr(Stdio::piped())
278 .kill_on_drop(true);
279
280 #[cfg(target_os = "windows")]
281 {
282 use smol::process::windows::CommandExt;
283 process.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
284 }
285
286 let mut process = process.spawn().context("failed to start the binary")?;
287
288 let stdin = process
289 .stdin
290 .take()
291 .context("failed to get stdin for process")?;
292 let stdout = process
293 .stdout
294 .take()
295 .context("failed to get stdout for process")?;
296
297 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
298
299 cx.spawn({
300 let client = client.clone();
301 let outgoing_tx = outgoing_tx.clone();
302 move |this, mut cx| async move {
303 let mut status = client.status();
304 while let Some(status) = status.next().await {
305 if status.is_connected() {
306 let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
307 outgoing_tx
308 .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
309 .ok();
310 this.update(&mut cx, |this, cx| {
311 if let Supermaven::Spawned(this) = this {
312 this.account_status = AccountStatus::Ready;
313 cx.notify();
314 }
315 })?;
316 break;
317 }
318 }
319 anyhow::Ok(())
320 }
321 })
322 .detach();
323
324 Ok(Self {
325 _process: process,
326 next_state_id: SupermavenCompletionStateId::default(),
327 states: BTreeMap::default(),
328 outgoing_tx,
329 _handle_outgoing_messages: cx
330 .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
331 _handle_incoming_messages: cx
332 .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
333 account_status: AccountStatus::Unknown,
334 service_tier: None,
335 client,
336 })
337 }
338
339 async fn handle_outgoing_messages(
340 mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
341 mut stdin: ChildStdin,
342 ) -> Result<()> {
343 while let Some(message) = outgoing.next().await {
344 let bytes = serde_json::to_vec(&message)?;
345 stdin.write_all(&bytes).await?;
346 stdin.write_all(&[b'\n']).await?;
347 }
348 Ok(())
349 }
350
351 async fn handle_incoming_messages(
352 this: WeakModel<Supermaven>,
353 stdout: ChildStdout,
354 mut cx: AsyncAppContext,
355 ) -> Result<()> {
356 const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
357
358 let stdout = BufReader::new(stdout);
359 let mut lines = stdout.lines();
360 while let Some(line) = lines.next().await {
361 let Some(line) = line.context("failed to read line from stdout").log_err() else {
362 continue;
363 };
364 let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
365 continue;
366 };
367 let Some(message) = serde_json::from_str::<SupermavenMessage>(line)
368 .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
369 .log_err()
370 else {
371 continue;
372 };
373
374 this.update(&mut cx, |this, _cx| {
375 if let Supermaven::Spawned(this) = this {
376 this.handle_message(message);
377 }
378 Task::ready(anyhow::Ok(()))
379 })?
380 .await?;
381 }
382
383 Ok(())
384 }
385
386 fn handle_message(&mut self, message: SupermavenMessage) {
387 match message {
388 SupermavenMessage::ActivationRequest(request) => {
389 self.account_status = match request.activate_url {
390 Some(activate_url) => AccountStatus::NeedsActivation {
391 activate_url: activate_url.clone(),
392 },
393 None => AccountStatus::Ready,
394 };
395 }
396 SupermavenMessage::ActivationSuccess => {
397 self.account_status = AccountStatus::Ready;
398 }
399 SupermavenMessage::ServiceTier { service_tier } => {
400 self.account_status = AccountStatus::Ready;
401 self.service_tier = Some(service_tier);
402 }
403 SupermavenMessage::Response(response) => {
404 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
405 if let Some(state) = self.states.get_mut(&state_id) {
406 for item in &response.items {
407 match item {
408 ResponseItem::Text { text } => state.text.push_str(text),
409 ResponseItem::Dedent { text } => state.dedent.push_str(text),
410 _ => {}
411 }
412 }
413 *state.updates_tx.borrow_mut() = ();
414 }
415 }
416 SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
417 _ => {
418 log::warn!("unhandled message: {:?}", message);
419 }
420 }
421 }
422}
423
/// Id of one completion request.
///
/// Ids come from a per-agent counter incremented on every request. An id is sent to the
/// agent stringified as `new_id`, and parsed back from each response's `state_id`.
/// Because ids grow with each request, the smallest id in the state map is the oldest.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
426
// NOTE(review): `dead_code` is presumably allowed because `prefix_anchor` is written but
// never read in this file — confirm before removing the attribute.
#[allow(dead_code)]
/// Per-request completion state, accumulated from the agent's streamed responses.
pub struct SupermavenCompletionState {
    // Buffer the request was made for.
    buffer_id: EntityId,
    // Cursor position at the time of the request.
    prefix_anchor: Anchor,
    // prefix_offset is tracked independently because the anchor biases left which
    // doesn't allow us to determine if the prior text has been deleted.
    prefix_offset: usize,
    // Concatenation of all `Text` response items received so far.
    text: String,
    // Concatenation of all `Dedent` response items; stripped from the front of `text`
    // when a completion is looked up.
    dedent: String,
    // Signalled every time a response for this state has been applied.
    updates_tx: watch::Sender<()>,
}
438
/// Handle returned by `Supermaven::complete` for a single completion request.
pub struct SupermavenCompletion {
    /// Id of this request's entry in the agent's completion-state map.
    pub id: SupermavenCompletionStateId,
    /// Fires each time a response for this request is applied. The channel's initial
    /// value has already been consumed, so only response-driven updates are observed.
    pub updates: watch::Receiver<()>,
}