1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
3 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
4 EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
5};
6use anyhow::{Result, bail};
7use client::Client;
8use edit_prediction_types::SuggestionDisplayType;
9use futures::{AsyncReadExt as _, channel::mpsc};
10use gpui::{
11 App, AppContext as _, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, Method},
13};
14use language::language_settings::all_language_settings;
15use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
16use language_model::{ApiKeyState, EnvVar, env_var};
17use lsp::DiagnosticSeverity;
18use serde::{Deserialize, Serialize};
19use std::{
20 fmt::{self, Write as _},
21 ops::Range,
22 path::Path,
23 sync::Arc,
24 time::Instant,
25};
26
/// Endpoint that returns next-edit autocomplete suggestions (POSTed by
/// `SweepAi::request_prediction_with_sweep`).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that receives "suggestion shown" / "suggestion accepted" metrics.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
29
/// State for the Sweep edit-prediction provider.
pub struct SweepAi {
    /// Shared Sweep API key state (the same entity `sweep_api_token` returns).
    /// Prediction and metrics requests are skipped while no key is available.
    pub api_token: Entity<ApiKeyState>,
    /// Build/OS description included in every prediction and metrics request.
    pub debug_info: Arc<str>,
}
34
35impl SweepAi {
36 pub fn new(cx: &mut App) -> Self {
37 SweepAi {
38 api_token: sweep_api_token(cx),
39 debug_info: debug_info(cx),
40 }
41 }
42
43 pub fn request_prediction_with_sweep(
44 &self,
45 inputs: EditPredictionModelInput,
46 cx: &mut App,
47 ) -> Task<Result<Option<EditPredictionResult>>> {
48 let privacy_mode_enabled = all_language_settings(None, cx)
49 .edit_predictions
50 .sweep
51 .privacy_mode;
52 let debug_info = self.debug_info.clone();
53 self.api_token.update(cx, |key_state, cx| {
54 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
55 });
56
57 let buffer = inputs.buffer.clone();
58 let debug_tx = inputs.debug_tx.clone();
59
60 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
61 return Task::ready(Ok(None));
62 };
63 let full_path: Arc<Path> = inputs
64 .snapshot
65 .file()
66 .map(|file| file.full_path(cx))
67 .unwrap_or_else(|| "untitled".into())
68 .into();
69
70 let project_file = project::File::from_dyn(inputs.snapshot.file());
71 let repo_name = project_file
72 .map(|file| file.worktree.read(cx).root_name_str())
73 .unwrap_or("untitled")
74 .into();
75 let offset = inputs.position.to_offset(&inputs.snapshot);
76 let buffer_entity_id = inputs.buffer.entity_id();
77
78 let recent_buffers = inputs.recent_paths.iter().cloned();
79 let http_client = cx.http_client();
80
81 let recent_buffer_snapshots = recent_buffers
82 .filter_map(|project_path| {
83 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
84 if inputs.buffer == buffer {
85 None
86 } else {
87 Some(buffer.read(cx).snapshot())
88 }
89 })
90 .take(3)
91 .collect::<Vec<_>>();
92
93 let buffer_snapshotted_at = Instant::now();
94
95 let result = cx.background_spawn(async move {
96 let text = inputs.snapshot.text();
97
98 let mut recent_changes = String::new();
99 for event in &inputs.events {
100 write_event(event.as_ref(), &mut recent_changes).unwrap();
101 }
102
103 let file_chunks = recent_buffer_snapshots
104 .into_iter()
105 .map(|snapshot| {
106 let end_point = Point::new(30, 0).min(snapshot.max_point());
107 FileChunk {
108 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
109 file_path: snapshot
110 .file()
111 .map(|f| f.path().as_unix_str())
112 .unwrap_or("untitled")
113 .to_string(),
114 start_line: 0,
115 end_line: end_point.row as usize,
116 timestamp: snapshot.file().and_then(|file| {
117 Some(
118 file.disk_state()
119 .mtime()?
120 .to_seconds_and_nanos_for_persistence()?
121 .0,
122 )
123 }),
124 }
125 })
126 .collect::<Vec<_>>();
127
128 let mut retrieval_chunks: Vec<FileChunk> = inputs
129 .related_files
130 .iter()
131 .flat_map(|related_file| {
132 related_file.excerpts.iter().map(|excerpt| FileChunk {
133 file_path: related_file.path.to_string_lossy().to_string(),
134 start_line: excerpt.row_range.start as usize,
135 end_line: excerpt.row_range.end as usize,
136 content: excerpt.text.to_string(),
137 timestamp: None,
138 })
139 })
140 .collect();
141
142 let diagnostic_entries = inputs
143 .snapshot
144 .diagnostics_in_range(inputs.diagnostic_search_range, false);
145 let mut diagnostic_content = String::new();
146 let mut diagnostic_count = 0;
147
148 for entry in diagnostic_entries {
149 let start_point: Point = entry.range.start;
150
151 let severity = match entry.diagnostic.severity {
152 DiagnosticSeverity::ERROR => "error",
153 DiagnosticSeverity::WARNING => "warning",
154 DiagnosticSeverity::INFORMATION => "info",
155 DiagnosticSeverity::HINT => "hint",
156 _ => continue,
157 };
158
159 diagnostic_count += 1;
160
161 writeln!(
162 &mut diagnostic_content,
163 "{}:{}:{}: {}: {}",
164 full_path.display(),
165 start_point.row + 1,
166 start_point.column + 1,
167 severity,
168 entry.diagnostic.message
169 )?;
170 }
171
172 if !diagnostic_content.is_empty() {
173 retrieval_chunks.push(FileChunk {
174 file_path: "diagnostics".to_string(),
175 start_line: 1,
176 end_line: diagnostic_count,
177 content: diagnostic_content,
178 timestamp: None,
179 });
180 }
181
182 let file_path_str = full_path.display().to_string();
183 let recent_user_actions = inputs
184 .user_actions
185 .iter()
186 .filter(|r| r.buffer_id == buffer_entity_id)
187 .map(|r| to_sweep_user_action(r, &file_path_str))
188 .collect();
189
190 let request_body = AutocompleteRequest {
191 debug_info,
192 repo_name,
193 file_path: full_path.clone(),
194 file_contents: text.clone(),
195 original_file_contents: text,
196 cursor_position: offset,
197 recent_changes: recent_changes.clone(),
198 changes_above_cursor: true,
199 multiple_suggestions: false,
200 branch: None,
201 file_chunks,
202 retrieval_chunks,
203 recent_user_actions,
204 use_bytes: true,
205 privacy_mode_enabled,
206 };
207
208 let mut buf: Vec<u8> = Vec::new();
209 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
210 serde_json::to_writer(writer, &request_body)?;
211 let body: AsyncBody = buf.into();
212
213 let ep_inputs = zeta_prompt::ZetaPromptInput {
214 events: inputs.events,
215 related_files: Some(inputs.related_files.clone()),
216 active_buffer_diagnostics: vec![],
217 cursor_path: full_path.clone(),
218 cursor_excerpt: request_body.file_contents.clone().into(),
219 cursor_offset_in_excerpt: request_body.cursor_position,
220 excerpt_start_row: Some(0),
221 excerpt_ranges: zeta_prompt::ExcerptRanges {
222 editable_150: 0..inputs.snapshot.len(),
223 editable_180: 0..inputs.snapshot.len(),
224 editable_350: 0..inputs.snapshot.len(),
225 editable_150_context_350: 0..inputs.snapshot.len(),
226 editable_180_context_350: 0..inputs.snapshot.len(),
227 editable_350_context_150: 0..inputs.snapshot.len(),
228 ..Default::default()
229 },
230 syntax_ranges: None,
231 experiment: None,
232 in_open_source_repo: false,
233 can_collect_data: false,
234 repo_url: None,
235 };
236
237 send_started_event(
238 &debug_tx,
239 &buffer,
240 inputs.position,
241 serde_json::to_string(&request_body).unwrap_or_default(),
242 );
243
244 let request = http_client::Request::builder()
245 .uri(SWEEP_API_URL)
246 .header("Content-Type", "application/json")
247 .header("Authorization", format!("Bearer {}", api_token))
248 .header("Connection", "keep-alive")
249 .header("Content-Encoding", "br")
250 .method(Method::POST)
251 .body(body)?;
252
253 let mut response = http_client.send(request).await?;
254
255 let mut body = String::new();
256 response.body_mut().read_to_string(&mut body).await?;
257
258 let response_received_at = Instant::now();
259 if !response.status().is_success() {
260 let message = format!(
261 "Request failed with status: {:?}\nBody: {}",
262 response.status(),
263 body,
264 );
265 send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
266 bail!(message);
267 };
268
269 let response: AutocompleteResponse = serde_json::from_str(&body)?;
270
271 send_finished_event(&debug_tx, &buffer, inputs.position, body);
272
273 let old_text = inputs
274 .snapshot
275 .text_for_range(response.start_index..response.end_index)
276 .collect::<String>();
277 let edits = language::text_diff(&old_text, &response.completion)
278 .into_iter()
279 .map(|(range, text)| {
280 (
281 inputs
282 .snapshot
283 .anchor_after(response.start_index + range.start)
284 ..inputs
285 .snapshot
286 .anchor_before(response.start_index + range.end),
287 text,
288 )
289 })
290 .collect::<Vec<_>>();
291
292 anyhow::Ok((
293 response.autocomplete_id,
294 edits,
295 inputs.snapshot,
296 response_received_at,
297 ep_inputs,
298 ))
299 });
300
301 let buffer = inputs.buffer.clone();
302
303 cx.spawn(async move |cx| {
304 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
305 anyhow::Ok(Some(
306 EditPredictionResult::new(
307 EditPredictionId(id.into()),
308 &buffer,
309 &old_snapshot,
310 edits.into(),
311 None,
312 buffer_snapshotted_at,
313 response_received_at,
314 inputs,
315 None,
316 cx,
317 )
318 .await,
319 ))
320 })
321 }
322}
323
324fn send_started_event(
325 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
326 buffer: &Entity<Buffer>,
327 position: Anchor,
328 prompt: String,
329) {
330 if let Some(debug_tx) = debug_tx {
331 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
332 EditPredictionStartedDebugEvent {
333 buffer: buffer.downgrade(),
334 position,
335 prompt: Some(prompt),
336 },
337 ));
338 }
339}
340
341fn send_finished_event(
342 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
343 buffer: &Entity<Buffer>,
344 position: Anchor,
345 model_output: String,
346) {
347 if let Some(debug_tx) = debug_tx {
348 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
349 EditPredictionFinishedDebugEvent {
350 buffer: buffer.downgrade(),
351 position,
352 model_output: Some(model_output),
353 },
354 ));
355 }
356}
357
/// URL under which the Sweep API key is loaded and looked up in `ApiKeyState`.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username for the stored Sweep token.
/// NOTE(review): not referenced in this file; presumably used by credential
/// storage elsewhere — confirm.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// Environment variable handed to `ApiKeyState::new` as an alternative source
/// for the Sweep API token.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
362
/// App-global holder so that every caller of `sweep_api_token` shares a single
/// `ApiKeyState` entity.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
366
367pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
368 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
369 return global.0.clone();
370 }
371 let entity =
372 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
373 cx.set_global(GlobalSweepApiKey(entity.clone()));
374 entity
375}
376
377pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
378 sweep_api_token(cx).update(cx, |key_state, cx| {
379 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
380 })
381}
382
/// JSON body POSTed (brotli-compressed) to [`SWEEP_API_URL`].
///
/// Field declaration order determines the key order in the serialized JSON.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Build/OS description produced by `debug_info`.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Always `None` in the request built in this file.
    pub branch: Option<String>,
    /// Full path of the file being edited, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Current text of the buffer.
    pub file_contents: String,
    /// Recent buffer diffs, rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into the buffer, as computed by `to_offset`.
    pub cursor_position: usize,
    /// Set to the same text as `file_contents` by the request built here.
    pub original_file_contents: String,
    /// Leading rows of other recently visited open buffers.
    pub file_chunks: Vec<FileChunk>,
    /// Related-file excerpts plus an optional synthetic `"diagnostics"` chunk.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Recorded user actions for the buffer being edited.
    pub recent_user_actions: Vec<UserAction>,
    /// Always `false` here.
    pub multiple_suggestions: bool,
    /// Mirrors the `edit_predictions.sweep.privacy_mode` setting.
    pub privacy_mode_enabled: bool,
    /// Always `true` here.
    pub changes_above_cursor: bool,
    /// Always `true` here; presumably tells the server offsets are byte-based —
    /// TODO confirm against the Sweep API.
    pub use_bytes: bool,
}
401
/// A span of file text sent to Sweep as additional context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the file the text came from (or `"diagnostics"` for the
    /// synthetic diagnostics chunk).
    pub file_path: String,
    /// Row where the chunk starts. NOTE(review): callers here mix 0-based rows
    /// (file chunks, excerpts) with a 1-based line for the diagnostics chunk.
    pub start_line: usize,
    /// Row where the chunk ends (inclusiveness not established here — TODO confirm).
    pub end_line: usize,
    /// The chunk's text.
    pub content: String,
    /// Seconds component of the file's on-disk mtime, when known.
    pub timestamp: Option<u64>,
}
410
/// A single recorded user action, in the shape Sweep expects.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action performed.
    pub action_type: ActionType,
    /// Line number copied from the recorded action.
    pub line_number: usize,
    /// Offset copied from the recorded action.
    pub offset: usize,
    /// Path of the file the action happened in.
    pub file_path: String,
    /// Epoch timestamp in milliseconds (from `timestamp_epoch_ms`).
    pub timestamp: u64,
}
419
/// Sweep's action kinds; serialized in SCREAMING_SNAKE_CASE (e.g.
/// `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
429
430fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
431 UserAction {
432 action_type: match record.action_type {
433 UserActionType::InsertChar => ActionType::InsertChar,
434 UserActionType::InsertSelection => ActionType::InsertSelection,
435 UserActionType::DeleteChar => ActionType::DeleteChar,
436 UserActionType::DeleteSelection => ActionType::DeleteSelection,
437 UserActionType::CursorMovement => ActionType::CursorMovement,
438 },
439 line_number: record.line_number as usize,
440 offset: record.offset,
441 file_path: file_path.to_string(),
442 timestamp: record.timestamp_epoch_ms,
443 }
444}
445
/// Successful response body from [`SWEEP_API_URL`].
///
/// Only `autocomplete_id`, `start_index`, `end_index` and `completion` are read
/// in this file. The other fields are kept so the full schema is deserialized,
/// hence the `allow(dead_code)` attributes. Non-`Option` fields without
/// `#[serde(default)]` must be present or deserialization fails.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier reused as the `EditPredictionId` and in metrics requests.
    pub autocomplete_id: String,
    /// Start offset of the buffer range to replace.
    pub start_index: usize,
    /// End offset of the buffer range to replace.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative suggestions, read from the `completions` key (empty if absent).
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
464
/// An alternative suggestion returned in `AutocompleteResponse::additional_completions`.
/// It is deserialized but not used in this file, hence `allow(dead_code)`.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
476
477fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
478 match event {
479 zeta_prompt::Event::BufferChange {
480 old_path,
481 path,
482 diff,
483 ..
484 } => {
485 if old_path != path {
486 // TODO confirm how to do this for sweep
487 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
488 }
489
490 if !diff.is_empty() {
491 write!(f, "File: {}:\n{}\n", path.display(), diff)?
492 }
493
494 fmt::Result::Ok(())
495 }
496 }
497}
498
499fn debug_info(cx: &gpui::App) -> Arc<str> {
500 format!(
501 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
502 version = release_channel::AppVersion::global(cx),
503 sha = release_channel::AppCommitSha::try_global(cx)
504 .map_or("unknown".to_string(), |sha| sha.full()),
505 os = client::telemetry::os_name(),
506 )
507 .into()
508}
509
/// Metrics event kinds; serialized in snake_case (e.g.
/// `autocomplete_suggestion_shown`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    AutocompleteSuggestionShown,
    AutocompleteSuggestionAccepted,
}
516
/// How a suggestion was presented; serialized in SCREAMING_SNAKE_CASE (e.g.
/// `GHOST_TEXT`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}
524
/// JSON body POSTed to [`SWEEP_METRICS_URL`] when a suggestion is shown or
/// accepted.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    event_type: SweepEventType,
    suggestion_type: SweepSuggestionType,
    /// Approximate inserted line count (see `compute_edit_metrics`).
    additions: u32,
    /// Approximate replaced line count (see `compute_edit_metrics`).
    deletions: u32,
    /// Id of the prediction being reported.
    autocomplete_id: String,
    /// Always empty in the requests built in this file.
    edit_tracking: String,
    /// Always `None` in the requests built in this file.
    edit_tracking_line: Option<u32>,
    /// Always `None` in the requests built in this file.
    lifespan: Option<u64>,
    debug_info: Arc<str>,
    /// Client user id as a string, or empty when not available.
    device_id: String,
    privacy_mode_enabled: bool,
}
539
540fn send_autocomplete_metrics_request(
541 cx: &App,
542 client: Arc<Client>,
543 api_token: Arc<str>,
544 request_body: AutocompleteMetricsRequest,
545) {
546 let http_client = client.http_client();
547 cx.background_spawn(async move {
548 let body: AsyncBody = serde_json::to_string(&request_body)?.into();
549
550 let request = http_client::Request::builder()
551 .uri(SWEEP_METRICS_URL)
552 .header("Content-Type", "application/json")
553 .header("Authorization", format!("Bearer {}", api_token))
554 .method(Method::POST)
555 .body(body)?;
556
557 let mut response = http_client.send(request).await?;
558
559 if !response.status().is_success() {
560 let mut body = String::new();
561 response.body_mut().read_to_string(&mut body).await?;
562 anyhow::bail!(
563 "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
564 response.status(),
565 body,
566 );
567 }
568
569 Ok(())
570 })
571 .detach_and_log_err(cx);
572}
573
574pub(crate) fn edit_prediction_accepted(
575 store: &EditPredictionStore,
576 current_prediction: CurrentEditPrediction,
577 cx: &App,
578) {
579 let Some(api_token) = store
580 .sweep_ai
581 .api_token
582 .read(cx)
583 .key(&SWEEP_CREDENTIALS_URL)
584 else {
585 return;
586 };
587 let debug_info = store.sweep_ai.debug_info.clone();
588
589 let prediction = current_prediction.prediction;
590
591 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
592 let autocomplete_id = prediction.id.to_string();
593
594 let device_id = store
595 .client
596 .user_id()
597 .as_ref()
598 .map(ToString::to_string)
599 .unwrap_or_default();
600
601 let suggestion_type = match current_prediction.shown_with {
602 Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
603 Some(SuggestionDisplayType::Jump) => return, // should'nt happen
604 Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
605 };
606
607 let request_body = AutocompleteMetricsRequest {
608 event_type: SweepEventType::AutocompleteSuggestionAccepted,
609 suggestion_type,
610 additions,
611 deletions,
612 autocomplete_id,
613 edit_tracking: String::new(),
614 edit_tracking_line: None,
615 lifespan: None,
616 debug_info,
617 device_id,
618 privacy_mode_enabled: false,
619 };
620
621 send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
622}
623
624pub fn edit_prediction_shown(
625 sweep_ai: &SweepAi,
626 client: Arc<Client>,
627 prediction: &EditPrediction,
628 display_type: SuggestionDisplayType,
629 cx: &App,
630) {
631 let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
632 return;
633 };
634 let debug_info = sweep_ai.debug_info.clone();
635
636 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
637 let autocomplete_id = prediction.id.to_string();
638
639 let suggestion_type = match display_type {
640 SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
641 SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
642 SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
643 };
644
645 let request_body = AutocompleteMetricsRequest {
646 event_type: SweepEventType::AutocompleteSuggestionShown,
647 suggestion_type,
648 additions,
649 deletions,
650 autocomplete_id,
651 edit_tracking: String::new(),
652 edit_tracking_line: None,
653 lifespan: None,
654 debug_info,
655 device_id: String::new(),
656 privacy_mode_enabled: false,
657 };
658
659 send_autocomplete_metrics_request(cx, client, api_token, request_body);
660}
661
662fn compute_edit_metrics(
663 edits: &[(Range<Anchor>, Arc<str>)],
664 snapshot: &BufferSnapshot,
665) -> (u32, u32) {
666 let mut additions = 0u32;
667 let mut deletions = 0u32;
668
669 for (range, new_text) in edits {
670 let old_text = snapshot.text_for_range(range.clone());
671 deletions += old_text
672 .map(|chunk| chunk.lines().count())
673 .sum::<usize>()
674 .max(1) as u32;
675 additions += new_text.lines().count().max(1) as u32;
676 }
677
678 (additions, deletions)
679}