1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
3 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
4 EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
5};
6use anyhow::{Result, bail};
7use client::Client;
8use edit_prediction_types::SuggestionDisplayType;
9use futures::{AsyncReadExt as _, channel::mpsc};
10use gpui::{
11 App, AppContext as _, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, Method},
13};
14use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
15use language_model::{ApiKeyState, EnvVar, env_var};
16use lsp::DiagnosticSeverity;
17use serde::{Deserialize, Serialize};
18use std::{
19 fmt::{self, Write as _},
20 ops::Range,
21 path::Path,
22 sync::Arc,
23 time::Instant,
24};
25
/// Endpoint that returns next-edit predictions (used by `request_prediction_with_sweep`).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that receives "shown" / "accepted" metrics events for predictions.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";

/// State for the Sweep edit-prediction provider.
pub struct SweepAi {
    /// API key state; `sweep_api_token` hands out the same entity app-wide.
    pub api_token: Entity<ApiKeyState>,
    /// Client description (Zed version, commit sha, OS) built by `debug_info`
    /// and sent with prediction and metrics requests.
    pub debug_info: Arc<str>,
}
33
34impl SweepAi {
35 pub fn new(cx: &mut App) -> Self {
36 SweepAi {
37 api_token: sweep_api_token(cx),
38 debug_info: debug_info(cx),
39 }
40 }
41
42 pub fn request_prediction_with_sweep(
43 &self,
44 inputs: EditPredictionModelInput,
45 cx: &mut App,
46 ) -> Task<Result<Option<EditPredictionResult>>> {
47 let debug_info = self.debug_info.clone();
48 self.api_token.update(cx, |key_state, cx| {
49 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
50 });
51
52 let buffer = inputs.buffer.clone();
53 let debug_tx = inputs.debug_tx.clone();
54
55 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
56 return Task::ready(Ok(None));
57 };
58 let full_path: Arc<Path> = inputs
59 .snapshot
60 .file()
61 .map(|file| file.full_path(cx))
62 .unwrap_or_else(|| "untitled".into())
63 .into();
64
65 let project_file = project::File::from_dyn(inputs.snapshot.file());
66 let repo_name = project_file
67 .map(|file| file.worktree.read(cx).root_name_str())
68 .unwrap_or("untitled")
69 .into();
70 let offset = inputs.position.to_offset(&inputs.snapshot);
71 let buffer_entity_id = inputs.buffer.entity_id();
72
73 let recent_buffers = inputs.recent_paths.iter().cloned();
74 let http_client = cx.http_client();
75
76 let recent_buffer_snapshots = recent_buffers
77 .filter_map(|project_path| {
78 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
79 if inputs.buffer == buffer {
80 None
81 } else {
82 Some(buffer.read(cx).snapshot())
83 }
84 })
85 .take(3)
86 .collect::<Vec<_>>();
87
88 let buffer_snapshotted_at = Instant::now();
89
90 let result = cx.background_spawn(async move {
91 let text = inputs.snapshot.text();
92
93 let mut recent_changes = String::new();
94 for event in &inputs.events {
95 write_event(event.as_ref(), &mut recent_changes).unwrap();
96 }
97
98 let mut file_chunks = recent_buffer_snapshots
99 .into_iter()
100 .map(|snapshot| {
101 let end_point = Point::new(30, 0).min(snapshot.max_point());
102 FileChunk {
103 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
104 file_path: snapshot
105 .file()
106 .map(|f| f.path().as_unix_str())
107 .unwrap_or("untitled")
108 .to_string(),
109 start_line: 0,
110 end_line: end_point.row as usize,
111 timestamp: snapshot.file().and_then(|file| {
112 Some(
113 file.disk_state()
114 .mtime()?
115 .to_seconds_and_nanos_for_persistence()?
116 .0,
117 )
118 }),
119 }
120 })
121 .collect::<Vec<_>>();
122
123 let retrieval_chunks = inputs
124 .related_files
125 .iter()
126 .flat_map(|related_file| {
127 related_file.excerpts.iter().map(|excerpt| FileChunk {
128 file_path: related_file.path.to_string_lossy().to_string(),
129 start_line: excerpt.row_range.start as usize,
130 end_line: excerpt.row_range.end as usize,
131 content: excerpt.text.to_string(),
132 timestamp: None,
133 })
134 })
135 .collect();
136
137 let diagnostic_entries = inputs
138 .snapshot
139 .diagnostics_in_range(inputs.diagnostic_search_range, false);
140 let mut diagnostic_content = String::new();
141 let mut diagnostic_count = 0;
142
143 for entry in diagnostic_entries {
144 let start_point: Point = entry.range.start;
145
146 let severity = match entry.diagnostic.severity {
147 DiagnosticSeverity::ERROR => "error",
148 DiagnosticSeverity::WARNING => "warning",
149 DiagnosticSeverity::INFORMATION => "info",
150 DiagnosticSeverity::HINT => "hint",
151 _ => continue,
152 };
153
154 diagnostic_count += 1;
155
156 writeln!(
157 &mut diagnostic_content,
158 "{} at line {}: {}",
159 severity,
160 start_point.row + 1,
161 entry.diagnostic.message
162 )?;
163 }
164
165 if !diagnostic_content.is_empty() {
166 file_chunks.push(FileChunk {
167 file_path: format!("Diagnostics for {}", full_path.display()),
168 start_line: 0,
169 end_line: diagnostic_count,
170 content: diagnostic_content,
171 timestamp: None,
172 });
173 }
174
175 let file_path_str = full_path.display().to_string();
176 let recent_user_actions = inputs
177 .user_actions
178 .iter()
179 .filter(|r| r.buffer_id == buffer_entity_id)
180 .map(|r| to_sweep_user_action(r, &file_path_str))
181 .collect();
182
183 let request_body = AutocompleteRequest {
184 debug_info,
185 repo_name,
186 file_path: full_path.clone(),
187 file_contents: text.clone(),
188 original_file_contents: text,
189 cursor_position: offset,
190 recent_changes: recent_changes.clone(),
191 changes_above_cursor: true,
192 multiple_suggestions: false,
193 branch: None,
194 file_chunks,
195 retrieval_chunks,
196 recent_user_actions,
197 use_bytes: true,
198 // TODO
199 privacy_mode_enabled: false,
200 };
201
202 let mut buf: Vec<u8> = Vec::new();
203 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
204 serde_json::to_writer(writer, &request_body)?;
205 let body: AsyncBody = buf.into();
206
207 let ep_inputs = zeta_prompt::ZetaPromptInput {
208 events: inputs.events,
209 related_files: inputs.related_files.clone(),
210 cursor_path: full_path.clone(),
211 cursor_excerpt: request_body.file_contents.clone().into(),
212 // we actually don't know
213 editable_range_in_excerpt: 0..inputs.snapshot.len(),
214 cursor_offset_in_excerpt: request_body.cursor_position,
215 };
216
217 send_started_event(
218 &debug_tx,
219 &buffer,
220 inputs.position,
221 serde_json::to_string(&request_body).unwrap_or_default(),
222 );
223
224 let request = http_client::Request::builder()
225 .uri(SWEEP_API_URL)
226 .header("Content-Type", "application/json")
227 .header("Authorization", format!("Bearer {}", api_token))
228 .header("Connection", "keep-alive")
229 .header("Content-Encoding", "br")
230 .method(Method::POST)
231 .body(body)?;
232
233 let mut response = http_client.send(request).await?;
234
235 let mut body = String::new();
236 response.body_mut().read_to_string(&mut body).await?;
237
238 let response_received_at = Instant::now();
239 if !response.status().is_success() {
240 let message = format!(
241 "Request failed with status: {:?}\nBody: {}",
242 response.status(),
243 body,
244 );
245 send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
246 bail!(message);
247 };
248
249 let response: AutocompleteResponse = serde_json::from_str(&body)?;
250
251 send_finished_event(&debug_tx, &buffer, inputs.position, body);
252
253 let old_text = inputs
254 .snapshot
255 .text_for_range(response.start_index..response.end_index)
256 .collect::<String>();
257 let edits = language::text_diff(&old_text, &response.completion)
258 .into_iter()
259 .map(|(range, text)| {
260 (
261 inputs
262 .snapshot
263 .anchor_after(response.start_index + range.start)
264 ..inputs
265 .snapshot
266 .anchor_before(response.start_index + range.end),
267 text,
268 )
269 })
270 .collect::<Vec<_>>();
271
272 anyhow::Ok((
273 response.autocomplete_id,
274 edits,
275 inputs.snapshot,
276 response_received_at,
277 ep_inputs,
278 ))
279 });
280
281 let buffer = inputs.buffer.clone();
282
283 cx.spawn(async move |cx| {
284 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
285 anyhow::Ok(Some(
286 EditPredictionResult::new(
287 EditPredictionId(id.into()),
288 &buffer,
289 &old_snapshot,
290 edits.into(),
291 buffer_snapshotted_at,
292 response_received_at,
293 inputs,
294 cx,
295 )
296 .await,
297 ))
298 })
299 }
300}
301
302fn send_started_event(
303 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
304 buffer: &Entity<Buffer>,
305 position: Anchor,
306 prompt: String,
307) {
308 if let Some(debug_tx) = debug_tx {
309 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
310 EditPredictionStartedDebugEvent {
311 buffer: buffer.downgrade(),
312 position,
313 prompt: Some(prompt),
314 },
315 ));
316 }
317}
318
319fn send_finished_event(
320 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
321 buffer: &Entity<Buffer>,
322 position: Anchor,
323 model_output: String,
324) {
325 if let Some(debug_tx) = debug_tx {
326 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
327 EditPredictionFinishedDebugEvent {
328 buffer: buffer.downgrade(),
329 position,
330 model_output: Some(model_output),
331 },
332 ));
333 }
334}
335
/// URL handed to `ApiKeyState` when creating, loading and looking up the Sweep
/// API token.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Credential username for the Sweep token (not referenced in this part of the
/// file; presumably used by credential storage elsewhere — verify).
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// `SWEEP_AI_TOKEN` environment variable, passed to `ApiKeyState::new`
/// (presumably as an alternative source of the token — confirm).
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");

/// gpui global holding the single shared Sweep API key entity.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
344
345pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
346 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
347 return global.0.clone();
348 }
349 let entity =
350 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
351 cx.set_global(GlobalSweepApiKey(entity.clone()));
352 entity
353}
354
355pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
356 sweep_api_token(cx).update(cx, |key_state, cx| {
357 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
358 })
359}
360
/// JSON body of a next-edit request to `SWEEP_API_URL`. Field names are the wire
/// names and serde emits them in declaration order.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description built by `debug_info`.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Always `None` in this file.
    pub branch: Option<String>,
    /// Full path of the edited file, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Entire text of the buffer snapshot.
    pub file_contents: String,
    /// Recent buffer changes rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset (from `to_offset`) into `file_contents`.
    pub cursor_position: usize,
    /// Currently set to the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading lines of other recent open buffers, plus a synthetic
    /// diagnostics chunk when there are diagnostics.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Recorded user actions for the edited buffer.
    pub recent_user_actions: Vec<UserAction>,
    pub multiple_suggestions: bool,
    pub privacy_mode_enabled: bool,
    pub changes_above_cursor: bool,
    /// Sent as `true`; presumably tells the server that offsets are byte
    /// offsets — confirm against Sweep's API.
    pub use_bytes: bool,
}

/// A span of context text sent to Sweep.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    pub file_path: String,
    pub start_line: usize,
    /// For the diagnostics chunk this holds the number of diagnostic lines.
    pub end_line: usize,
    pub content: String,
    /// File modification time (seconds component) when known.
    pub timestamp: Option<u64>,
}

/// A recorded editor action in Sweep's wire format.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    /// Copied from `UserActionRecord::timestamp_epoch_ms`.
    pub timestamp: u64,
}

/// Kind of user action; serialized as SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
407
408fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
409 UserAction {
410 action_type: match record.action_type {
411 UserActionType::InsertChar => ActionType::InsertChar,
412 UserActionType::InsertSelection => ActionType::InsertSelection,
413 UserActionType::DeleteChar => ActionType::DeleteChar,
414 UserActionType::DeleteSelection => ActionType::DeleteSelection,
415 UserActionType::CursorMovement => ActionType::CursorMovement,
416 },
417 line_number: record.line_number as usize,
418 offset: record.offset,
419 file_path: file_path.to_string(),
420 timestamp: record.timestamp_epoch_ms,
421 }
422}
423
/// JSON response of the next-edit endpoint. Only the id, the byte range and the
/// replacement text are used. The remaining fields are parsed but unused, hence
/// `allow(dead_code)`. Non-`Option` fields without `serde(default)`
/// (`confidence`, `elapsed_time_ms`) must be present or parsing fails.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    pub autocomplete_id: String,
    /// Start of the replaced range in the text that was sent.
    pub start_index: usize,
    /// End (exclusive) of the replaced range in the text that was sent.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    // Alternative suggestions (wire name `completions`); empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}

/// An alternative suggestion in the response; parsed but currently unused.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
454
455fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
456 match event {
457 zeta_prompt::Event::BufferChange {
458 old_path,
459 path,
460 diff,
461 ..
462 } => {
463 if old_path != path {
464 // TODO confirm how to do this for sweep
465 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
466 }
467
468 if !diff.is_empty() {
469 write!(f, "File: {}:\n{}\n", path.display(), diff)?
470 }
471
472 fmt::Result::Ok(())
473 }
474 }
475}
476
477fn debug_info(cx: &gpui::App) -> Arc<str> {
478 format!(
479 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
480 version = release_channel::AppVersion::global(cx),
481 sha = release_channel::AppCommitSha::try_global(cx)
482 .map_or("unknown".to_string(), |sha| sha.full()),
483 os = client::telemetry::os_name(),
484 )
485 .into()
486}
487
/// Kind of metrics event, serialized in snake_case
/// (`autocomplete_suggestion_shown`, `autocomplete_suggestion_accepted`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    AutocompleteSuggestionShown,
    AutocompleteSuggestionAccepted,
}

/// How a suggestion was displayed, serialized in SCREAMING_SNAKE_CASE
/// (`GHOST_TEXT`, `POPUP`, `JUMP_TO_EDIT`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}

/// JSON body posted to `SWEEP_METRICS_URL` by `send_autocomplete_metrics_request`.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    event_type: SweepEventType,
    suggestion_type: SweepSuggestionType,
    /// Line-based counts produced by `compute_edit_metrics`.
    additions: u32,
    deletions: u32,
    /// `EditPrediction::id` rendered as a string.
    autocomplete_id: String,
    /// Always empty in this file.
    edit_tracking: String,
    /// Always `None` in this file.
    edit_tracking_line: Option<u32>,
    /// Always `None` in this file.
    lifespan: Option<u64>,
    debug_info: Arc<str>,
    /// User id on accepted events; empty on shown events.
    device_id: String,
    /// Always `false` in this file.
    privacy_mode_enabled: bool,
}
517
518fn send_autocomplete_metrics_request(
519 cx: &App,
520 client: Arc<Client>,
521 api_token: Arc<str>,
522 request_body: AutocompleteMetricsRequest,
523) {
524 let http_client = client.http_client();
525 cx.background_spawn(async move {
526 let body: AsyncBody = serde_json::to_string(&request_body)?.into();
527
528 let request = http_client::Request::builder()
529 .uri(SWEEP_METRICS_URL)
530 .header("Content-Type", "application/json")
531 .header("Authorization", format!("Bearer {}", api_token))
532 .method(Method::POST)
533 .body(body)?;
534
535 let mut response = http_client.send(request).await?;
536
537 if !response.status().is_success() {
538 let mut body = String::new();
539 response.body_mut().read_to_string(&mut body).await?;
540 anyhow::bail!(
541 "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
542 response.status(),
543 body,
544 );
545 }
546
547 Ok(())
548 })
549 .detach_and_log_err(cx);
550}
551
552pub(crate) fn edit_prediction_accepted(
553 store: &EditPredictionStore,
554 current_prediction: CurrentEditPrediction,
555 cx: &App,
556) {
557 let Some(api_token) = store
558 .sweep_ai
559 .api_token
560 .read(cx)
561 .key(&SWEEP_CREDENTIALS_URL)
562 else {
563 return;
564 };
565 let debug_info = store.sweep_ai.debug_info.clone();
566
567 let prediction = current_prediction.prediction;
568
569 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
570 let autocomplete_id = prediction.id.to_string();
571
572 let device_id = store
573 .client
574 .user_id()
575 .as_ref()
576 .map(ToString::to_string)
577 .unwrap_or_default();
578
579 let suggestion_type = match current_prediction.shown_with {
580 Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
581 Some(SuggestionDisplayType::Jump) => return, // should'nt happen
582 Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
583 };
584
585 let request_body = AutocompleteMetricsRequest {
586 event_type: SweepEventType::AutocompleteSuggestionAccepted,
587 suggestion_type,
588 additions,
589 deletions,
590 autocomplete_id,
591 edit_tracking: String::new(),
592 edit_tracking_line: None,
593 lifespan: None,
594 debug_info,
595 device_id,
596 privacy_mode_enabled: false,
597 };
598
599 send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
600}
601
602pub fn edit_prediction_shown(
603 sweep_ai: &SweepAi,
604 client: Arc<Client>,
605 prediction: &EditPrediction,
606 display_type: SuggestionDisplayType,
607 cx: &App,
608) {
609 let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
610 return;
611 };
612 let debug_info = sweep_ai.debug_info.clone();
613
614 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
615 let autocomplete_id = prediction.id.to_string();
616
617 let suggestion_type = match display_type {
618 SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
619 SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
620 SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
621 };
622
623 let request_body = AutocompleteMetricsRequest {
624 event_type: SweepEventType::AutocompleteSuggestionShown,
625 suggestion_type,
626 additions,
627 deletions,
628 autocomplete_id,
629 edit_tracking: String::new(),
630 edit_tracking_line: None,
631 lifespan: None,
632 debug_info,
633 device_id: String::new(),
634 privacy_mode_enabled: false,
635 };
636
637 send_autocomplete_metrics_request(cx, client, api_token, request_body);
638}
639
640fn compute_edit_metrics(
641 edits: &[(Range<Anchor>, Arc<str>)],
642 snapshot: &BufferSnapshot,
643) -> (u32, u32) {
644 let mut additions = 0u32;
645 let mut deletions = 0u32;
646
647 for (range, new_text) in edits {
648 let old_text = snapshot.text_for_range(range.clone());
649 deletions += old_text
650 .map(|chunk| chunk.lines().count())
651 .sum::<usize>()
652 .max(1) as u32;
653 additions += new_text.lines().count().max(1) as u32;
654 }
655
656 (additions, deletions)
657}