1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
3 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
4 EditPredictionStore, prediction::EditPredictionResult,
5};
6use anyhow::{Result, bail};
7use client::Client;
8use edit_prediction_types::SuggestionDisplayType;
9use futures::{AsyncReadExt as _, channel::mpsc};
10use gpui::{
11 App, AppContext as _, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, Method},
13};
14use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
15use language_model::{ApiKeyState, EnvVar, env_var};
16use lsp::DiagnosticSeverity;
17use serde::{Deserialize, Serialize};
18use std::{
19 fmt::{self, Write as _},
20 ops::Range,
21 path::Path,
22 sync::Arc,
23 time::Instant,
24};
25
/// Sweep endpoint that returns next-edit predictions; used by `SweepAi::request_prediction_with_sweep`.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Sweep endpoint that receives "suggestion shown" / "suggestion accepted" metrics events.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
28
/// Edit-prediction provider backed by the Sweep autocomplete service.
pub struct SweepAi {
    /// Shared state holding the Sweep API key (the app-wide entity from `sweep_api_token`).
    pub api_token: Entity<ApiKeyState>,
    /// Client description (app version, commit SHA, OS) attached to requests and metrics events.
    pub debug_info: Arc<str>,
}
33
34impl SweepAi {
35 pub fn new(cx: &mut App) -> Self {
36 SweepAi {
37 api_token: sweep_api_token(cx),
38 debug_info: debug_info(cx),
39 }
40 }
41
42 pub fn request_prediction_with_sweep(
43 &self,
44 inputs: EditPredictionModelInput,
45 cx: &mut App,
46 ) -> Task<Result<Option<EditPredictionResult>>> {
47 let debug_info = self.debug_info.clone();
48 self.api_token.update(cx, |key_state, cx| {
49 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
50 });
51
52 let buffer = inputs.buffer.clone();
53 let debug_tx = inputs.debug_tx.clone();
54
55 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
56 return Task::ready(Ok(None));
57 };
58 let full_path: Arc<Path> = inputs
59 .snapshot
60 .file()
61 .map(|file| file.full_path(cx))
62 .unwrap_or_else(|| "untitled".into())
63 .into();
64
65 let project_file = project::File::from_dyn(inputs.snapshot.file());
66 let repo_name = project_file
67 .map(|file| file.worktree.read(cx).root_name_str())
68 .unwrap_or("untitled")
69 .into();
70 let offset = inputs.position.to_offset(&inputs.snapshot);
71
72 let recent_buffers = inputs.recent_paths.iter().cloned();
73 let http_client = cx.http_client();
74
75 let recent_buffer_snapshots = recent_buffers
76 .filter_map(|project_path| {
77 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
78 if inputs.buffer == buffer {
79 None
80 } else {
81 Some(buffer.read(cx).snapshot())
82 }
83 })
84 .take(3)
85 .collect::<Vec<_>>();
86
87 let buffer_snapshotted_at = Instant::now();
88
89 let result = cx.background_spawn(async move {
90 let text = inputs.snapshot.text();
91
92 let mut recent_changes = String::new();
93 for event in &inputs.events {
94 write_event(event.as_ref(), &mut recent_changes).unwrap();
95 }
96
97 let mut file_chunks = recent_buffer_snapshots
98 .into_iter()
99 .map(|snapshot| {
100 let end_point = Point::new(30, 0).min(snapshot.max_point());
101 FileChunk {
102 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
103 file_path: snapshot
104 .file()
105 .map(|f| f.path().as_unix_str())
106 .unwrap_or("untitled")
107 .to_string(),
108 start_line: 0,
109 end_line: end_point.row as usize,
110 timestamp: snapshot.file().and_then(|file| {
111 Some(
112 file.disk_state()
113 .mtime()?
114 .to_seconds_and_nanos_for_persistence()?
115 .0,
116 )
117 }),
118 }
119 })
120 .collect::<Vec<_>>();
121
122 let retrieval_chunks = inputs
123 .related_files
124 .iter()
125 .flat_map(|related_file| {
126 related_file.excerpts.iter().map(|excerpt| FileChunk {
127 file_path: related_file.path.to_string_lossy().to_string(),
128 start_line: excerpt.row_range.start as usize,
129 end_line: excerpt.row_range.end as usize,
130 content: excerpt.text.to_string(),
131 timestamp: None,
132 })
133 })
134 .collect();
135
136 let diagnostic_entries = inputs
137 .snapshot
138 .diagnostics_in_range(inputs.diagnostic_search_range, false);
139 let mut diagnostic_content = String::new();
140 let mut diagnostic_count = 0;
141
142 for entry in diagnostic_entries {
143 let start_point: Point = entry.range.start;
144
145 let severity = match entry.diagnostic.severity {
146 DiagnosticSeverity::ERROR => "error",
147 DiagnosticSeverity::WARNING => "warning",
148 DiagnosticSeverity::INFORMATION => "info",
149 DiagnosticSeverity::HINT => "hint",
150 _ => continue,
151 };
152
153 diagnostic_count += 1;
154
155 writeln!(
156 &mut diagnostic_content,
157 "{} at line {}: {}",
158 severity,
159 start_point.row + 1,
160 entry.diagnostic.message
161 )?;
162 }
163
164 if !diagnostic_content.is_empty() {
165 file_chunks.push(FileChunk {
166 file_path: format!("Diagnostics for {}", full_path.display()),
167 start_line: 0,
168 end_line: diagnostic_count,
169 content: diagnostic_content,
170 timestamp: None,
171 });
172 }
173
174 let request_body = AutocompleteRequest {
175 debug_info,
176 repo_name,
177 file_path: full_path.clone(),
178 file_contents: text.clone(),
179 original_file_contents: text,
180 cursor_position: offset,
181 recent_changes: recent_changes.clone(),
182 changes_above_cursor: true,
183 multiple_suggestions: false,
184 branch: None,
185 file_chunks,
186 retrieval_chunks,
187 recent_user_actions: vec![],
188 use_bytes: true,
189 // TODO
190 privacy_mode_enabled: false,
191 };
192
193 let mut buf: Vec<u8> = Vec::new();
194 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
195 serde_json::to_writer(writer, &request_body)?;
196 let body: AsyncBody = buf.into();
197
198 let ep_inputs = zeta_prompt::ZetaPromptInput {
199 events: inputs.events,
200 related_files: inputs.related_files.clone(),
201 cursor_path: full_path.clone(),
202 cursor_excerpt: request_body.file_contents.clone().into(),
203 // we actually don't know
204 editable_range_in_excerpt: 0..inputs.snapshot.len(),
205 cursor_offset_in_excerpt: request_body.cursor_position,
206 };
207
208 send_started_event(
209 &debug_tx,
210 &buffer,
211 inputs.position,
212 serde_json::to_string(&request_body).unwrap_or_default(),
213 );
214
215 let request = http_client::Request::builder()
216 .uri(SWEEP_API_URL)
217 .header("Content-Type", "application/json")
218 .header("Authorization", format!("Bearer {}", api_token))
219 .header("Connection", "keep-alive")
220 .header("Content-Encoding", "br")
221 .method(Method::POST)
222 .body(body)?;
223
224 let mut response = http_client.send(request).await?;
225
226 let mut body = String::new();
227 response.body_mut().read_to_string(&mut body).await?;
228
229 let response_received_at = Instant::now();
230 if !response.status().is_success() {
231 let message = format!(
232 "Request failed with status: {:?}\nBody: {}",
233 response.status(),
234 body,
235 );
236 send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
237 bail!(message);
238 };
239
240 let response: AutocompleteResponse = serde_json::from_str(&body)?;
241
242 send_finished_event(&debug_tx, &buffer, inputs.position, body);
243
244 let old_text = inputs
245 .snapshot
246 .text_for_range(response.start_index..response.end_index)
247 .collect::<String>();
248 let edits = language::text_diff(&old_text, &response.completion)
249 .into_iter()
250 .map(|(range, text)| {
251 (
252 inputs
253 .snapshot
254 .anchor_after(response.start_index + range.start)
255 ..inputs
256 .snapshot
257 .anchor_before(response.start_index + range.end),
258 text,
259 )
260 })
261 .collect::<Vec<_>>();
262
263 anyhow::Ok((
264 response.autocomplete_id,
265 edits,
266 inputs.snapshot,
267 response_received_at,
268 ep_inputs,
269 ))
270 });
271
272 let buffer = inputs.buffer.clone();
273
274 cx.spawn(async move |cx| {
275 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
276 anyhow::Ok(Some(
277 EditPredictionResult::new(
278 EditPredictionId(id.into()),
279 &buffer,
280 &old_snapshot,
281 edits.into(),
282 buffer_snapshotted_at,
283 response_received_at,
284 inputs,
285 cx,
286 )
287 .await,
288 ))
289 })
290 }
291}
292
293fn send_started_event(
294 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
295 buffer: &Entity<Buffer>,
296 position: Anchor,
297 prompt: String,
298) {
299 if let Some(debug_tx) = debug_tx {
300 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
301 EditPredictionStartedDebugEvent {
302 buffer: buffer.downgrade(),
303 position,
304 prompt: Some(prompt),
305 },
306 ));
307 }
308}
309
310fn send_finished_event(
311 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
312 buffer: &Entity<Buffer>,
313 position: Anchor,
314 model_output: String,
315) {
316 if let Some(debug_tx) = debug_tx {
317 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
318 EditPredictionFinishedDebugEvent {
319 buffer: buffer.downgrade(),
320 position,
321 model_output: Some(model_output),
322 },
323 ));
324 }
325}
326
/// Key under which the Sweep API token is registered in, loaded by, and looked up from the
/// shared `ApiKeyState`.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username associated with the stored Sweep API token.
/// NOTE(review): not referenced in this section; presumably used where credentials are written.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// `SWEEP_AI_TOKEN` environment variable handed to the Sweep `ApiKeyState`; presumably an
/// alternative source for the token — confirm against `ApiKeyState`.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
331
/// App global holding the single Sweep `ApiKeyState` entity shared by all callers of
/// `sweep_api_token`.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
335
336pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
337 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
338 return global.0.clone();
339 }
340 let entity =
341 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
342 cx.set_global(GlobalSweepApiKey(entity.clone()));
343 entity
344}
345
346pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
347 sweep_api_token(cx).update(cx, |key_state, cx| {
348 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
349 })
350}
351
/// JSON body sent (Brotli-compressed) to `SWEEP_API_URL` to request a prediction.
///
/// Serde serializes fields in declaration order, so this order is the key order on the wire.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description produced by `debug_info`.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Set to `None` by `request_prediction_with_sweep`.
    pub branch: Option<String>,
    /// Full path of the file being edited, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Full text of the buffer at request time.
    pub file_contents: String,
    /// Recent buffer-change diffs, formatted by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents` (presumably bytes, since `use_bytes` is set).
    pub cursor_position: usize,
    /// Same text as `file_contents` in the requests built by `request_prediction_with_sweep`.
    pub original_file_contents: String,
    /// Leading lines of recently used buffers, plus an optional diagnostics pseudo-chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Sent empty by `request_prediction_with_sweep`.
    pub recent_user_actions: Vec<UserAction>,
    pub multiple_suggestions: bool,
    pub privacy_mode_enabled: bool,
    pub changes_above_cursor: bool,
    pub use_bytes: bool,
}
370
/// A snippet of text sent to Sweep as context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the source file, or a label such as `"Diagnostics for <path>"`.
    pub file_path: String,
    /// First row covered by `content`.
    pub start_line: usize,
    /// End row of `content` (for the diagnostics chunk, the number of diagnostics).
    /// NOTE(review): whether this bound is inclusive is not established here.
    pub end_line: usize,
    pub content: String,
    /// File modification time, presumably whole seconds since the epoch; `None` when unknown.
    pub timestamp: Option<u64>,
}
379
/// A single editor action reported to Sweep.
/// NOTE(review): the requests built in this file always send an empty list of these.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
388
// Kind of a `UserAction`. No variant is constructed in this file, hence the `dead_code`
// allowance. Variants serialize in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
399
/// Response body returned by `SWEEP_API_URL`.
///
/// Only `autocomplete_id`, `start_index`, `end_index` and `completion` are read. The other
/// fields are deserialized but unused, hence their `dead_code` allowances.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Becomes the prediction's `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start of the replaced range within the text that was sent.
    pub start_index: usize,
    /// End of the replaced range within the text that was sent.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    // Read from the `completions` key; a missing key yields an empty list.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
418
// An alternative completion from the response's `completions` array. It is deserialized but
// never read, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
430
431fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
432 match event {
433 zeta_prompt::Event::BufferChange {
434 old_path,
435 path,
436 diff,
437 ..
438 } => {
439 if old_path != path {
440 // TODO confirm how to do this for sweep
441 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
442 }
443
444 if !diff.is_empty() {
445 write!(f, "File: {}:\n{}\n", path.display(), diff)?
446 }
447
448 fmt::Result::Ok(())
449 }
450 }
451}
452
453fn debug_info(cx: &gpui::App) -> Arc<str> {
454 format!(
455 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
456 version = release_channel::AppVersion::global(cx),
457 sha = release_channel::AppCommitSha::try_global(cx)
458 .map_or("unknown".to_string(), |sha| sha.full()),
459 os = client::telemetry::os_name(),
460 )
461 .into()
462}
463
/// Kind of metrics event reported to Sweep. Serialized in snake_case
/// (e.g. `autocomplete_suggestion_shown`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    AutocompleteSuggestionShown,
    AutocompleteSuggestionAccepted,
}
470
/// How a suggestion was presented to the user. Serialized in SCREAMING_SNAKE_CASE
/// (e.g. `GHOST_TEXT`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}
478
/// JSON body sent to `SWEEP_METRICS_URL` for "shown" and "accepted" events.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    event_type: SweepEventType,
    suggestion_type: SweepSuggestionType,
    /// Inserted line count, as computed by `compute_edit_metrics`.
    additions: u32,
    /// Removed line count, as computed by `compute_edit_metrics`.
    deletions: u32,
    /// String form of the prediction id.
    autocomplete_id: String,
    // The next three fields are sent empty / `None` by the callers in this file.
    edit_tracking: String,
    edit_tracking_line: Option<u32>,
    lifespan: Option<u64>,
    debug_info: Arc<str>,
    device_id: String,
    privacy_mode_enabled: bool,
}
493
494fn send_autocomplete_metrics_request(
495 cx: &App,
496 client: Arc<Client>,
497 api_token: Arc<str>,
498 request_body: AutocompleteMetricsRequest,
499) {
500 let http_client = client.http_client();
501 cx.background_spawn(async move {
502 let body: AsyncBody = serde_json::to_string(&request_body)?.into();
503
504 let request = http_client::Request::builder()
505 .uri(SWEEP_METRICS_URL)
506 .header("Content-Type", "application/json")
507 .header("Authorization", format!("Bearer {}", api_token))
508 .method(Method::POST)
509 .body(body)?;
510
511 let mut response = http_client.send(request).await?;
512
513 if !response.status().is_success() {
514 let mut body = String::new();
515 response.body_mut().read_to_string(&mut body).await?;
516 anyhow::bail!(
517 "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
518 response.status(),
519 body,
520 );
521 }
522
523 Ok(())
524 })
525 .detach_and_log_err(cx);
526}
527
528pub(crate) fn edit_prediction_accepted(
529 store: &EditPredictionStore,
530 current_prediction: CurrentEditPrediction,
531 cx: &App,
532) {
533 let Some(api_token) = store
534 .sweep_ai
535 .api_token
536 .read(cx)
537 .key(&SWEEP_CREDENTIALS_URL)
538 else {
539 return;
540 };
541 let debug_info = store.sweep_ai.debug_info.clone();
542
543 let prediction = current_prediction.prediction;
544
545 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
546 let autocomplete_id = prediction.id.to_string();
547
548 let device_id = store
549 .client
550 .user_id()
551 .as_ref()
552 .map(ToString::to_string)
553 .unwrap_or_default();
554
555 let suggestion_type = match current_prediction.shown_with {
556 Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
557 Some(SuggestionDisplayType::Jump) => return, // should'nt happen
558 Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
559 };
560
561 let request_body = AutocompleteMetricsRequest {
562 event_type: SweepEventType::AutocompleteSuggestionAccepted,
563 suggestion_type,
564 additions,
565 deletions,
566 autocomplete_id,
567 edit_tracking: String::new(),
568 edit_tracking_line: None,
569 lifespan: None,
570 debug_info,
571 device_id,
572 privacy_mode_enabled: false,
573 };
574
575 send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
576}
577
578pub fn edit_prediction_shown(
579 sweep_ai: &SweepAi,
580 client: Arc<Client>,
581 prediction: &EditPrediction,
582 display_type: SuggestionDisplayType,
583 cx: &App,
584) {
585 let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
586 return;
587 };
588 let debug_info = sweep_ai.debug_info.clone();
589
590 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
591 let autocomplete_id = prediction.id.to_string();
592
593 let suggestion_type = match display_type {
594 SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
595 SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
596 SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
597 };
598
599 let request_body = AutocompleteMetricsRequest {
600 event_type: SweepEventType::AutocompleteSuggestionShown,
601 suggestion_type,
602 additions,
603 deletions,
604 autocomplete_id,
605 edit_tracking: String::new(),
606 edit_tracking_line: None,
607 lifespan: None,
608 debug_info,
609 device_id: String::new(),
610 privacy_mode_enabled: false,
611 };
612
613 send_autocomplete_metrics_request(cx, client, api_token, request_body);
614}
615
616fn compute_edit_metrics(
617 edits: &[(Range<Anchor>, Arc<str>)],
618 snapshot: &BufferSnapshot,
619) -> (u32, u32) {
620 let mut additions = 0u32;
621 let mut deletions = 0u32;
622
623 for (range, new_text) in edits {
624 let old_text = snapshot.text_for_range(range.clone());
625 deletions += old_text
626 .map(|chunk| chunk.lines().count())
627 .sum::<usize>()
628 .max(1) as u32;
629 additions += new_text.lines().count().max(1) as u32;
630 }
631
632 (additions, deletions)
633}