1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
3 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
4 EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
5};
6use anyhow::{Result, bail};
7use client::Client;
8use edit_prediction_types::SuggestionDisplayType;
9use futures::{AsyncReadExt as _, channel::mpsc};
10use gpui::{
11 App, AppContext as _, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, Method},
13};
14use language::language_settings::all_language_settings;
15use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
16use language_model::{ApiKeyState, EnvVar, env_var};
17use lsp::DiagnosticSeverity;
18use serde::{Deserialize, Serialize};
19use std::{
20 fmt::{self, Write as _},
21 ops::Range,
22 path::Path,
23 sync::Arc,
24};
25
/// Endpoint that serves next-edit predictions; requests are POSTed here by
/// `SweepAi::request_prediction_with_sweep`.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that receives suggestion "shown" / "accepted" metrics events; see
/// `send_autocomplete_metrics_request`.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
28
29pub struct SweepAi {
30 pub api_token: Entity<ApiKeyState>,
31 pub debug_info: Arc<str>,
32}
33
34impl SweepAi {
35 pub fn new(cx: &mut App) -> Self {
36 SweepAi {
37 api_token: sweep_api_token(cx),
38 debug_info: debug_info(cx),
39 }
40 }
41
42 pub fn request_prediction_with_sweep(
43 &self,
44 inputs: EditPredictionModelInput,
45 cx: &mut App,
46 ) -> Task<Result<Option<EditPredictionResult>>> {
47 let privacy_mode_enabled = all_language_settings(None, cx)
48 .edit_predictions
49 .sweep
50 .privacy_mode;
51 let debug_info = self.debug_info.clone();
52 let request_start = cx.background_executor().now();
53 self.api_token.update(cx, |key_state, cx| {
54 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
55 });
56
57 let buffer = inputs.buffer.clone();
58 let debug_tx = inputs.debug_tx.clone();
59
60 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
61 return Task::ready(Ok(None));
62 };
63 let full_path: Arc<Path> = inputs
64 .snapshot
65 .file()
66 .map(|file| file.full_path(cx))
67 .unwrap_or_else(|| "untitled".into())
68 .into();
69
70 let project_file = project::File::from_dyn(inputs.snapshot.file());
71 let repo_name = project_file
72 .map(|file| file.worktree.read(cx).root_name_str())
73 .unwrap_or("untitled")
74 .into();
75 let offset = inputs.position.to_offset(&inputs.snapshot);
76 let buffer_entity_id = inputs.buffer.entity_id();
77
78 let recent_buffers = inputs.recent_paths.iter().cloned();
79 let http_client = cx.http_client();
80
81 let recent_buffer_snapshots = recent_buffers
82 .filter_map(|project_path| {
83 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
84 if inputs.buffer == buffer {
85 None
86 } else {
87 Some(buffer.read(cx).snapshot())
88 }
89 })
90 .take(3)
91 .collect::<Vec<_>>();
92
93 let result = cx.background_spawn(async move {
94 let text = inputs.snapshot.text();
95
96 let mut recent_changes = String::new();
97 for event in &inputs.events {
98 write_event(event.as_ref(), &mut recent_changes).unwrap();
99 }
100
101 let file_chunks = recent_buffer_snapshots
102 .into_iter()
103 .map(|snapshot| {
104 let end_point = Point::new(30, 0).min(snapshot.max_point());
105 FileChunk {
106 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
107 file_path: snapshot
108 .file()
109 .map(|f| f.path().as_unix_str())
110 .unwrap_or("untitled")
111 .to_string(),
112 start_line: 0,
113 end_line: end_point.row as usize,
114 timestamp: snapshot.file().and_then(|file| {
115 Some(
116 file.disk_state()
117 .mtime()?
118 .to_seconds_and_nanos_for_persistence()?
119 .0,
120 )
121 }),
122 }
123 })
124 .collect::<Vec<_>>();
125
126 let mut retrieval_chunks: Vec<FileChunk> = inputs
127 .related_files
128 .iter()
129 .flat_map(|related_file| {
130 related_file.excerpts.iter().map(|excerpt| FileChunk {
131 file_path: related_file.path.to_string_lossy().to_string(),
132 start_line: excerpt.row_range.start as usize,
133 end_line: excerpt.row_range.end as usize,
134 content: excerpt.text.to_string(),
135 timestamp: None,
136 })
137 })
138 .collect();
139
140 let diagnostic_entries = inputs
141 .snapshot
142 .diagnostics_in_range(inputs.diagnostic_search_range, false);
143 let mut diagnostic_content = String::new();
144 let mut diagnostic_count = 0;
145
146 for entry in diagnostic_entries {
147 let start_point: Point = entry.range.start;
148
149 let severity = match entry.diagnostic.severity {
150 DiagnosticSeverity::ERROR => "error",
151 DiagnosticSeverity::WARNING => "warning",
152 DiagnosticSeverity::INFORMATION => "info",
153 DiagnosticSeverity::HINT => "hint",
154 _ => continue,
155 };
156
157 diagnostic_count += 1;
158
159 writeln!(
160 &mut diagnostic_content,
161 "{}:{}:{}: {}: {}",
162 full_path.display(),
163 start_point.row + 1,
164 start_point.column + 1,
165 severity,
166 entry.diagnostic.message
167 )?;
168 }
169
170 if !diagnostic_content.is_empty() {
171 retrieval_chunks.push(FileChunk {
172 file_path: "diagnostics".to_string(),
173 start_line: 1,
174 end_line: diagnostic_count,
175 content: diagnostic_content,
176 timestamp: None,
177 });
178 }
179
180 let file_path_str = full_path.display().to_string();
181 let recent_user_actions = inputs
182 .user_actions
183 .iter()
184 .filter(|r| r.buffer_id == buffer_entity_id)
185 .map(|r| to_sweep_user_action(r, &file_path_str))
186 .collect();
187
188 let request_body = AutocompleteRequest {
189 debug_info,
190 repo_name,
191 file_path: full_path.clone(),
192 file_contents: text.clone(),
193 original_file_contents: text,
194 cursor_position: offset,
195 recent_changes: recent_changes.clone(),
196 changes_above_cursor: true,
197 multiple_suggestions: false,
198 branch: None,
199 file_chunks,
200 retrieval_chunks,
201 recent_user_actions,
202 use_bytes: true,
203 privacy_mode_enabled,
204 };
205
206 let mut buf: Vec<u8> = Vec::new();
207 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
208 serde_json::to_writer(writer, &request_body)?;
209 let body: AsyncBody = buf.into();
210
211 let ep_inputs = zeta_prompt::ZetaPromptInput {
212 events: inputs.events,
213 related_files: Some(inputs.related_files.clone()),
214 active_buffer_diagnostics: vec![],
215 cursor_path: full_path.clone(),
216 cursor_excerpt: request_body.file_contents.clone().into(),
217 cursor_offset_in_excerpt: request_body.cursor_position,
218 excerpt_start_row: Some(0),
219 excerpt_ranges: zeta_prompt::ExcerptRanges {
220 editable_150: 0..inputs.snapshot.len(),
221 editable_180: 0..inputs.snapshot.len(),
222 editable_350: 0..inputs.snapshot.len(),
223 editable_150_context_350: 0..inputs.snapshot.len(),
224 editable_180_context_350: 0..inputs.snapshot.len(),
225 editable_350_context_150: 0..inputs.snapshot.len(),
226 ..Default::default()
227 },
228 syntax_ranges: None,
229 experiment: None,
230 in_open_source_repo: false,
231 can_collect_data: false,
232 repo_url: None,
233 };
234
235 send_started_event(
236 &debug_tx,
237 &buffer,
238 inputs.position,
239 serde_json::to_string(&request_body).unwrap_or_default(),
240 );
241
242 let request = http_client::Request::builder()
243 .uri(SWEEP_API_URL)
244 .header("Content-Type", "application/json")
245 .header("Authorization", format!("Bearer {}", api_token))
246 .header("Connection", "keep-alive")
247 .header("Content-Encoding", "br")
248 .method(Method::POST)
249 .body(body)?;
250
251 let mut response = http_client.send(request).await?;
252
253 let mut body = String::new();
254 response.body_mut().read_to_string(&mut body).await?;
255
256 if !response.status().is_success() {
257 let message = format!(
258 "Request failed with status: {:?}\nBody: {}",
259 response.status(),
260 body,
261 );
262 send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
263 bail!(message);
264 };
265
266 let response: AutocompleteResponse = serde_json::from_str(&body)?;
267
268 send_finished_event(&debug_tx, &buffer, inputs.position, body);
269
270 let old_text = inputs
271 .snapshot
272 .text_for_range(response.start_index..response.end_index)
273 .collect::<String>();
274 let edits = language::text_diff(&old_text, &response.completion)
275 .into_iter()
276 .map(|(range, text)| {
277 (
278 inputs
279 .snapshot
280 .anchor_after(response.start_index + range.start)
281 ..inputs
282 .snapshot
283 .anchor_before(response.start_index + range.end),
284 text,
285 )
286 })
287 .collect::<Vec<_>>();
288
289 anyhow::Ok((response.autocomplete_id, edits, inputs.snapshot, ep_inputs))
290 });
291
292 let buffer = inputs.buffer.clone();
293
294 cx.spawn(async move |cx| {
295 let (id, edits, old_snapshot, inputs) = result.await?;
296 anyhow::Ok(Some(
297 EditPredictionResult::new(
298 EditPredictionId(id.into()),
299 &buffer,
300 &old_snapshot,
301 edits.into(),
302 None,
303 inputs,
304 None,
305 cx.background_executor().now() - request_start,
306 cx,
307 )
308 .await,
309 ))
310 })
311 }
312}
313
314fn send_started_event(
315 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
316 buffer: &Entity<Buffer>,
317 position: Anchor,
318 prompt: String,
319) {
320 if let Some(debug_tx) = debug_tx {
321 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
322 EditPredictionStartedDebugEvent {
323 buffer: buffer.downgrade(),
324 position,
325 prompt: Some(prompt),
326 },
327 ));
328 }
329}
330
331fn send_finished_event(
332 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
333 buffer: &Entity<Buffer>,
334 position: Anchor,
335 model_output: String,
336) {
337 if let Some(debug_tx) = debug_tx {
338 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
339 EditPredictionFinishedDebugEvent {
340 buffer: buffer.downgrade(),
341 position,
342 model_output: Some(model_output),
343 },
344 ));
345 }
346}
347
/// URL under which the Sweep API key state is created, loaded and looked up.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Credential username for the Sweep API token.
/// NOTE(review): not referenced in this section; presumably used elsewhere
/// when storing the key — confirm.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// Environment variable handed to `ApiKeyState::new` for the Sweep key
/// (presumably lets the key be supplied via `SWEEP_AI_TOKEN`).
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
352
353struct GlobalSweepApiKey(Entity<ApiKeyState>);
354
355impl Global for GlobalSweepApiKey {}
356
357pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
358 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
359 return global.0.clone();
360 }
361 let entity =
362 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
363 cx.set_global(GlobalSweepApiKey(entity.clone()));
364 entity
365}
366
367pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
368 sweep_api_token(cx).update(cx, |key_state, cx| {
369 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
370 })
371}
372
373#[derive(Debug, Clone, Serialize)]
374struct AutocompleteRequest {
375 pub debug_info: Arc<str>,
376 pub repo_name: String,
377 pub branch: Option<String>,
378 pub file_path: Arc<Path>,
379 pub file_contents: String,
380 pub recent_changes: String,
381 pub cursor_position: usize,
382 pub original_file_contents: String,
383 pub file_chunks: Vec<FileChunk>,
384 pub retrieval_chunks: Vec<FileChunk>,
385 pub recent_user_actions: Vec<UserAction>,
386 pub multiple_suggestions: bool,
387 pub privacy_mode_enabled: bool,
388 pub changes_above_cursor: bool,
389 pub use_bytes: bool,
390}
391
392#[derive(Debug, Clone, Serialize)]
393struct FileChunk {
394 pub file_path: String,
395 pub start_line: usize,
396 pub end_line: usize,
397 pub content: String,
398 pub timestamp: Option<u64>,
399}
400
401#[derive(Debug, Clone, Serialize)]
402struct UserAction {
403 pub action_type: ActionType,
404 pub line_number: usize,
405 pub offset: usize,
406 pub file_path: String,
407 pub timestamp: u64,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
411#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
412enum ActionType {
413 CursorMovement,
414 InsertChar,
415 DeleteChar,
416 InsertSelection,
417 DeleteSelection,
418}
419
420fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
421 UserAction {
422 action_type: match record.action_type {
423 UserActionType::InsertChar => ActionType::InsertChar,
424 UserActionType::InsertSelection => ActionType::InsertSelection,
425 UserActionType::DeleteChar => ActionType::DeleteChar,
426 UserActionType::DeleteSelection => ActionType::DeleteSelection,
427 UserActionType::CursorMovement => ActionType::CursorMovement,
428 },
429 line_number: record.line_number as usize,
430 offset: record.offset,
431 file_path: file_path.to_string(),
432 timestamp: record.timestamp_epoch_ms,
433 }
434}
435
436#[derive(Debug, Clone, Deserialize)]
437struct AutocompleteResponse {
438 pub autocomplete_id: String,
439 pub start_index: usize,
440 pub end_index: usize,
441 pub completion: String,
442 #[allow(dead_code)]
443 pub confidence: f64,
444 #[allow(dead_code)]
445 pub logprobs: Option<serde_json::Value>,
446 #[allow(dead_code)]
447 pub finish_reason: Option<String>,
448 #[allow(dead_code)]
449 pub elapsed_time_ms: u64,
450 #[allow(dead_code)]
451 #[serde(default, rename = "completions")]
452 pub additional_completions: Vec<AdditionalCompletion>,
453}
454
455#[allow(dead_code)]
456#[derive(Debug, Clone, Deserialize)]
457struct AdditionalCompletion {
458 pub start_index: usize,
459 pub end_index: usize,
460 pub completion: String,
461 pub confidence: f64,
462 pub autocomplete_id: String,
463 pub logprobs: Option<serde_json::Value>,
464 pub finish_reason: Option<String>,
465}
466
467fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
468 match event {
469 zeta_prompt::Event::BufferChange {
470 old_path,
471 path,
472 diff,
473 ..
474 } => {
475 if old_path != path {
476 // TODO confirm how to do this for sweep
477 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
478 }
479
480 if !diff.is_empty() {
481 write!(f, "File: {}:\n{}\n", path.display(), diff)?
482 }
483
484 fmt::Result::Ok(())
485 }
486 }
487}
488
489fn debug_info(cx: &gpui::App) -> Arc<str> {
490 format!(
491 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
492 version = release_channel::AppVersion::global(cx),
493 sha = release_channel::AppCommitSha::try_global(cx)
494 .map_or("unknown".to_string(), |sha| sha.full()),
495 os = client::telemetry::os_name(),
496 )
497 .into()
498}
499
500#[derive(Debug, Clone, Copy, Serialize)]
501#[serde(rename_all = "snake_case")]
502pub enum SweepEventType {
503 AutocompleteSuggestionShown,
504 AutocompleteSuggestionAccepted,
505}
506
507#[derive(Debug, Clone, Copy, Serialize)]
508#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
509pub enum SweepSuggestionType {
510 GhostText,
511 Popup,
512 JumpToEdit,
513}
514
515#[derive(Debug, Clone, Serialize)]
516struct AutocompleteMetricsRequest {
517 event_type: SweepEventType,
518 suggestion_type: SweepSuggestionType,
519 additions: u32,
520 deletions: u32,
521 autocomplete_id: String,
522 edit_tracking: String,
523 edit_tracking_line: Option<u32>,
524 lifespan: Option<u64>,
525 debug_info: Arc<str>,
526 device_id: String,
527 privacy_mode_enabled: bool,
528}
529
530fn send_autocomplete_metrics_request(
531 cx: &App,
532 client: Arc<Client>,
533 api_token: Arc<str>,
534 request_body: AutocompleteMetricsRequest,
535) {
536 let http_client = client.http_client();
537 cx.background_spawn(async move {
538 let body: AsyncBody = serde_json::to_string(&request_body)?.into();
539
540 let request = http_client::Request::builder()
541 .uri(SWEEP_METRICS_URL)
542 .header("Content-Type", "application/json")
543 .header("Authorization", format!("Bearer {}", api_token))
544 .method(Method::POST)
545 .body(body)?;
546
547 let mut response = http_client.send(request).await?;
548
549 if !response.status().is_success() {
550 let mut body = String::new();
551 response.body_mut().read_to_string(&mut body).await?;
552 anyhow::bail!(
553 "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
554 response.status(),
555 body,
556 );
557 }
558
559 Ok(())
560 })
561 .detach_and_log_err(cx);
562}
563
564pub(crate) fn edit_prediction_accepted(
565 store: &EditPredictionStore,
566 current_prediction: CurrentEditPrediction,
567 cx: &App,
568) {
569 let Some(api_token) = store
570 .sweep_ai
571 .api_token
572 .read(cx)
573 .key(&SWEEP_CREDENTIALS_URL)
574 else {
575 return;
576 };
577 let debug_info = store.sweep_ai.debug_info.clone();
578
579 let prediction = current_prediction.prediction;
580
581 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
582 let autocomplete_id = prediction.id.to_string();
583
584 let device_id = store
585 .client
586 .user_id()
587 .as_ref()
588 .map(ToString::to_string)
589 .unwrap_or_default();
590
591 let suggestion_type = match current_prediction.shown_with {
592 Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
593 Some(SuggestionDisplayType::Jump) => return, // should'nt happen
594 Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
595 };
596
597 let request_body = AutocompleteMetricsRequest {
598 event_type: SweepEventType::AutocompleteSuggestionAccepted,
599 suggestion_type,
600 additions,
601 deletions,
602 autocomplete_id,
603 edit_tracking: String::new(),
604 edit_tracking_line: None,
605 lifespan: None,
606 debug_info,
607 device_id,
608 privacy_mode_enabled: false,
609 };
610
611 send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
612}
613
614pub fn edit_prediction_shown(
615 sweep_ai: &SweepAi,
616 client: Arc<Client>,
617 prediction: &EditPrediction,
618 display_type: SuggestionDisplayType,
619 cx: &App,
620) {
621 let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
622 return;
623 };
624 let debug_info = sweep_ai.debug_info.clone();
625
626 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
627 let autocomplete_id = prediction.id.to_string();
628
629 let suggestion_type = match display_type {
630 SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
631 SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
632 SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
633 };
634
635 let request_body = AutocompleteMetricsRequest {
636 event_type: SweepEventType::AutocompleteSuggestionShown,
637 suggestion_type,
638 additions,
639 deletions,
640 autocomplete_id,
641 edit_tracking: String::new(),
642 edit_tracking_line: None,
643 lifespan: None,
644 debug_info,
645 device_id: String::new(),
646 privacy_mode_enabled: false,
647 };
648
649 send_autocomplete_metrics_request(cx, client, api_token, request_body);
650}
651
652fn compute_edit_metrics(
653 edits: &[(Range<Anchor>, Arc<str>)],
654 snapshot: &BufferSnapshot,
655) -> (u32, u32) {
656 let mut additions = 0u32;
657 let mut deletions = 0u32;
658
659 for (range, new_text) in edits {
660 let old_text = snapshot.text_for_range(range.clone());
661 deletions += old_text
662 .map(|chunk| chunk.lines().count())
663 .sum::<usize>()
664 .max(1) as u32;
665 additions += new_text.lines().count().max(1) as u32;
666 }
667
668 (additions, deletions)
669}