1mod messages;
2mod supermaven_completion_provider;
3
4pub use supermaven_completion_provider::*;
5
6use anyhow::{Context as _, Result};
7#[allow(unused_imports)]
8use client::{proto, Client};
9use collections::BTreeMap;
10
11use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt};
12use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel};
13use language::{
14 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, ToOffset,
15};
16use messages::*;
17use postage::watch;
18use serde::{Deserialize, Serialize};
19use settings::SettingsStore;
20use smol::{
21 io::AsyncWriteExt,
22 process::{Child, ChildStdin, ChildStdout, Command},
23};
24use std::{path::PathBuf, process::Stdio, sync::Arc};
25use ui::prelude::*;
26use util::ResultExt;
27
28pub fn init(client: Arc<Client>, cx: &mut AppContext) {
29 let supermaven = cx.new_model(|_| Supermaven::Starting);
30 Supermaven::set_global(supermaven.clone(), cx);
31
32 let mut provider = all_language_settings(None, cx).inline_completions.provider;
33 if provider == language::language_settings::InlineCompletionProvider::Supermaven {
34 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
35 }
36
37 cx.observe_global::<SettingsStore>(move |cx| {
38 let new_provider = all_language_settings(None, cx).inline_completions.provider;
39 if new_provider != provider {
40 provider = new_provider;
41 if provider == language::language_settings::InlineCompletionProvider::Supermaven {
42 supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx));
43 } else {
44 supermaven.update(cx, |supermaven, _cx| supermaven.stop());
45 }
46 }
47 })
48 .detach();
49}
50
/// Lifecycle state of the Supermaven integration.
///
/// `init` creates the model in `Starting`. `Supermaven::start` replaces it with `Spawned` once
/// the agent binary path has been resolved and the process launched. `Supermaven::stop` resets
/// it to `Starting`.
pub enum Supermaven {
    /// No agent is running: both the initial state and the state after `stop`.
    Starting,
    /// NOTE(review): presumably "the agent binary could not be downloaded"; this variant is never
    /// constructed in this file — confirm whether other code sets it.
    FailedDownload { error: anyhow::Error },
    /// The agent process is running.
    Spawned(SupermavenAgent),
    /// NOTE(review): generic failure state; never constructed in this file — confirm its use.
    Error { error: anyhow::Error },
}
57
/// Account status of the running Supermaven agent.
///
/// `Debug` and equality are derived so the status can be logged and compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AccountStatus {
    /// No account information has been received yet.
    Unknown,
    /// The agent asked for the account to be activated at `activate_url`.
    NeedsActivation { activate_url: String },
    /// The account is ready. This is set after an API key has been sent to the agent, or when
    /// the agent sends an activation request without a URL.
    Ready,
}
64
/// Newtype that stores the shared [`Supermaven`] model as a gpui global, so it can be
/// retrieved with `Supermaven::global`.
#[derive(Clone)]
struct SupermavenGlobal(Model<Supermaven>);

impl Global for SupermavenGlobal {}
69
70impl Supermaven {
71 pub fn global(cx: &AppContext) -> Option<Model<Self>> {
72 cx.try_global::<SupermavenGlobal>()
73 .map(|model| model.0.clone())
74 }
75
76 pub fn set_global(supermaven: Model<Self>, cx: &mut AppContext) {
77 cx.set_global(SupermavenGlobal(supermaven));
78 }
79
80 pub fn start(&mut self, client: Arc<Client>, cx: &mut ModelContext<Self>) {
81 if let Self::Starting = self {
82 cx.spawn(|this, mut cx| async move {
83 let binary_path =
84 supermaven_api::get_supermaven_agent_path(client.http_client()).await?;
85
86 this.update(&mut cx, |this, cx| {
87 if let Self::Starting = this {
88 *this =
89 Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?);
90 }
91 anyhow::Ok(())
92 })
93 })
94 .detach_and_log_err(cx)
95 }
96 }
97
98 pub fn stop(&mut self) {
99 *self = Self::Starting;
100 }
101
102 pub fn is_enabled(&self) -> bool {
103 matches!(self, Self::Spawned { .. })
104 }
105
106 pub fn complete(
107 &mut self,
108 buffer: &Model<Buffer>,
109 cursor_position: Anchor,
110 cx: &AppContext,
111 ) -> Option<SupermavenCompletion> {
112 if let Self::Spawned(agent) = self {
113 let buffer_id = buffer.entity_id();
114 let buffer = buffer.read(cx);
115 let path = buffer
116 .file()
117 .and_then(|file| Some(file.as_local()?.abs_path(cx)))
118 .unwrap_or_else(|| PathBuf::from("untitled"))
119 .to_string_lossy()
120 .to_string();
121 let content = buffer.text();
122 let offset = cursor_position.to_offset(buffer);
123 let state_id = agent.next_state_id;
124 agent.next_state_id.0 += 1;
125
126 let (updates_tx, mut updates_rx) = watch::channel();
127 postage::stream::Stream::try_recv(&mut updates_rx).unwrap();
128
129 agent.states.insert(
130 state_id,
131 SupermavenCompletionState {
132 buffer_id,
133 prefix_anchor: cursor_position,
134 text: String::new(),
135 dedent: String::new(),
136 updates_tx,
137 },
138 );
139 let _ = agent
140 .outgoing_tx
141 .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage {
142 new_id: state_id.0.to_string(),
143 updates: vec![
144 StateUpdate::FileUpdate(FileUpdateMessage {
145 path: path.clone(),
146 content,
147 }),
148 StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }),
149 ],
150 }));
151
152 Some(SupermavenCompletion {
153 id: state_id,
154 updates: updates_rx,
155 })
156 } else {
157 None
158 }
159 }
160
161 pub fn completion(
162 &self,
163 buffer: &Model<Buffer>,
164 cursor_position: Anchor,
165 cx: &AppContext,
166 ) -> Option<&str> {
167 if let Self::Spawned(agent) = self {
168 find_relevant_completion(
169 &agent.states,
170 buffer.entity_id(),
171 &buffer.read(cx).snapshot(),
172 cursor_position,
173 )
174 } else {
175 None
176 }
177 }
178}
179
180fn find_relevant_completion<'a>(
181 states: &'a BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
182 buffer_id: EntityId,
183 buffer: &BufferSnapshot,
184 cursor_position: Anchor,
185) -> Option<&'a str> {
186 let mut best_completion: Option<&str> = None;
187 'completions: for state in states.values() {
188 if state.buffer_id != buffer_id {
189 continue;
190 }
191 let Some(state_completion) = state.text.strip_prefix(&state.dedent) else {
192 continue;
193 };
194
195 let current_cursor_offset = cursor_position.to_offset(buffer);
196 let original_cursor_offset = state.prefix_anchor.to_offset(buffer);
197 if current_cursor_offset < original_cursor_offset {
198 continue;
199 }
200
201 let text_inserted_since_completion_request =
202 buffer.text_for_range(original_cursor_offset..current_cursor_offset);
203 let mut trimmed_completion = state_completion;
204 for chunk in text_inserted_since_completion_request {
205 if let Some(suffix) = trimmed_completion.strip_prefix(chunk) {
206 trimmed_completion = suffix;
207 } else {
208 continue 'completions;
209 }
210 }
211
212 if best_completion.map_or(false, |best| best.len() > trimmed_completion.len()) {
213 continue;
214 }
215
216 best_completion = Some(trimmed_completion);
217 }
218 best_completion
219}
220
/// A running Supermaven agent process together with the bookkeeping for its completion
/// requests.
pub struct SupermavenAgent {
    /// The agent process. It is spawned with `kill_on_drop(true)`, so dropping the agent
    /// terminates the process.
    _process: Child,
    /// Id that will be assigned to the next completion request.
    next_state_id: SupermavenCompletionStateId,
    /// Completion states keyed by request id.
    states: BTreeMap<SupermavenCompletionStateId, SupermavenCompletionState>,
    /// Queue of messages to be written to the agent's stdin.
    outgoing_tx: mpsc::UnboundedSender<OutboundMessage>,
    /// Task that writes queued messages to the agent's stdin. It is stored so it lives as long as
    /// the agent (presumably dropping a gpui `Task` cancels it — confirm).
    _handle_outgoing_messages: Task<Result<()>>,
    /// Task that reads and dispatches messages from the agent's stdout; stored for the same
    /// reason as `_handle_outgoing_messages`.
    _handle_incoming_messages: Task<Result<()>>,
    /// Account status. It is updated from agent messages and after the API key has been sent.
    pub account_status: AccountStatus,
    /// Service tier most recently reported by the agent, if any.
    service_tier: Option<ServiceTier>,
    /// Client handed to the agent at construction. It is not read afterwards, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    client: Arc<Client>,
}
233
234impl SupermavenAgent {
235 fn new(
236 binary_path: PathBuf,
237 client: Arc<Client>,
238 cx: &mut ModelContext<Supermaven>,
239 ) -> Result<Self> {
240 let mut process = Command::new(&binary_path)
241 .arg("stdio")
242 .stdin(Stdio::piped())
243 .stdout(Stdio::piped())
244 .stderr(Stdio::piped())
245 .kill_on_drop(true)
246 .spawn()
247 .context("failed to start the binary")?;
248
249 let stdin = process
250 .stdin
251 .take()
252 .context("failed to get stdin for process")?;
253 let stdout = process
254 .stdout
255 .take()
256 .context("failed to get stdout for process")?;
257
258 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
259
260 cx.spawn({
261 let client = client.clone();
262 let outgoing_tx = outgoing_tx.clone();
263 move |this, mut cx| async move {
264 let mut status = client.status();
265 while let Some(status) = status.next().await {
266 if status.is_connected() {
267 let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key;
268 outgoing_tx
269 .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key }))
270 .ok();
271 this.update(&mut cx, |this, cx| {
272 if let Supermaven::Spawned(this) = this {
273 this.account_status = AccountStatus::Ready;
274 cx.notify();
275 }
276 })?;
277 break;
278 }
279 }
280 return anyhow::Ok(());
281 }
282 })
283 .detach();
284
285 Ok(Self {
286 _process: process,
287 next_state_id: SupermavenCompletionStateId::default(),
288 states: BTreeMap::default(),
289 outgoing_tx,
290 _handle_outgoing_messages: cx
291 .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)),
292 _handle_incoming_messages: cx
293 .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)),
294 account_status: AccountStatus::Unknown,
295 service_tier: None,
296 client,
297 })
298 }
299
300 async fn handle_outgoing_messages(
301 mut outgoing: mpsc::UnboundedReceiver<OutboundMessage>,
302 mut stdin: ChildStdin,
303 ) -> Result<()> {
304 while let Some(message) = outgoing.next().await {
305 let bytes = serde_json::to_vec(&message)?;
306 stdin.write_all(&bytes).await?;
307 stdin.write_all(&[b'\n']).await?;
308 }
309 Ok(())
310 }
311
312 async fn handle_incoming_messages(
313 this: WeakModel<Supermaven>,
314 stdout: ChildStdout,
315 mut cx: AsyncAppContext,
316 ) -> Result<()> {
317 const MESSAGE_PREFIX: &str = "SM-MESSAGE ";
318
319 let stdout = BufReader::new(stdout);
320 let mut lines = stdout.lines();
321 while let Some(line) = lines.next().await {
322 let Some(line) = line.context("failed to read line from stdout").log_err() else {
323 continue;
324 };
325 let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else {
326 continue;
327 };
328 let Some(message) = serde_json::from_str::<SupermavenMessage>(&line)
329 .with_context(|| format!("failed to deserialize line from stdout: {:?}", line))
330 .log_err()
331 else {
332 continue;
333 };
334
335 this.update(&mut cx, |this, _cx| {
336 if let Supermaven::Spawned(this) = this {
337 this.handle_message(message);
338 }
339 Task::ready(anyhow::Ok(()))
340 })?
341 .await?;
342 }
343
344 Ok(())
345 }
346
347 fn handle_message(&mut self, message: SupermavenMessage) {
348 match message {
349 SupermavenMessage::ActivationRequest(request) => {
350 self.account_status = match request.activate_url {
351 Some(activate_url) => AccountStatus::NeedsActivation {
352 activate_url: activate_url.clone(),
353 },
354 None => AccountStatus::Ready,
355 };
356 }
357 SupermavenMessage::ServiceTier { service_tier } => {
358 self.service_tier = Some(service_tier);
359 }
360 SupermavenMessage::Response(response) => {
361 let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap());
362 if let Some(state) = self.states.get_mut(&state_id) {
363 for item in &response.items {
364 match item {
365 ResponseItem::Text { text } => state.text.push_str(text),
366 ResponseItem::Dedent { text } => state.dedent.push_str(text),
367 _ => {}
368 }
369 }
370 *state.updates_tx.borrow_mut() = ();
371 }
372 }
373 SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough),
374 _ => {
375 log::warn!("unhandled message: {:?}", message);
376 }
377 }
378 }
379}
380
/// Identifier of one completion request.
///
/// Ids are handed out sequentially by each agent and sent to the agent process as decimal
/// strings. The derived `Ord` makes a `BTreeMap` keyed by this type iterate oldest-first.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct SupermavenCompletionStateId(usize);
383
/// Accumulated state of one completion request, filled in as the agent streams its response.
// NOTE(review): every field appears to be read in this file, so this `dead_code` allowance may
// be stale — confirm before removing it.
#[allow(dead_code)]
pub struct SupermavenCompletionState {
    /// Entity id of the buffer the completion was requested for.
    buffer_id: EntityId,
    /// Cursor position at the time of the request.
    prefix_anchor: Anchor,
    /// Completion text received so far.
    text: String,
    /// Dedent text received so far. Completions are only offered when `text` starts with it, and
    /// it is stripped from the offered text.
    dedent: String,
    /// Signals the matching `SupermavenCompletion::updates` receiver when new items arrive.
    updates_tx: watch::Sender<()>,
}
392
/// Handle returned by `Supermaven::complete` for one completion request.
pub struct SupermavenCompletion {
    /// Id of the completion state created for this request.
    pub id: SupermavenCompletionStateId,
    /// Signalled when the agent sends response items for this request.
    pub updates: watch::Receiver<()>,
}