1use crate::{
2 CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
3 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
4 EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
5};
6use anyhow::{Result, bail};
7use client::Client;
8use edit_prediction_types::SuggestionDisplayType;
9use futures::{AsyncReadExt as _, channel::mpsc};
10use gpui::{
11 App, AppContext as _, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, Method},
13};
14use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
15use language_model::{ApiKeyState, EnvVar, env_var};
16use lsp::DiagnosticSeverity;
17use serde::{Deserialize, Serialize};
18use std::{
19 fmt::{self, Write as _},
20 ops::Range,
21 path::Path,
22 sync::Arc,
23 time::Instant,
24};
25
/// Endpoint that receives next-edit prediction requests (`AutocompleteRequest`).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that receives suggestion shown/accepted metrics (`AutocompleteMetricsRequest`).
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
28
29pub struct SweepAi {
30 pub api_token: Entity<ApiKeyState>,
31 pub debug_info: Arc<str>,
32}
33
34impl SweepAi {
35 pub fn new(cx: &mut App) -> Self {
36 SweepAi {
37 api_token: sweep_api_token(cx),
38 debug_info: debug_info(cx),
39 }
40 }
41
42 pub fn request_prediction_with_sweep(
43 &self,
44 inputs: EditPredictionModelInput,
45 cx: &mut App,
46 ) -> Task<Result<Option<EditPredictionResult>>> {
47 let debug_info = self.debug_info.clone();
48 self.api_token.update(cx, |key_state, cx| {
49 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
50 });
51
52 let buffer = inputs.buffer.clone();
53 let debug_tx = inputs.debug_tx.clone();
54
55 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
56 return Task::ready(Ok(None));
57 };
58 let full_path: Arc<Path> = inputs
59 .snapshot
60 .file()
61 .map(|file| file.full_path(cx))
62 .unwrap_or_else(|| "untitled".into())
63 .into();
64
65 let project_file = project::File::from_dyn(inputs.snapshot.file());
66 let repo_name = project_file
67 .map(|file| file.worktree.read(cx).root_name_str())
68 .unwrap_or("untitled")
69 .into();
70 let offset = inputs.position.to_offset(&inputs.snapshot);
71 let buffer_entity_id = inputs.buffer.entity_id();
72
73 let recent_buffers = inputs.recent_paths.iter().cloned();
74 let http_client = cx.http_client();
75
76 let recent_buffer_snapshots = recent_buffers
77 .filter_map(|project_path| {
78 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
79 if inputs.buffer == buffer {
80 None
81 } else {
82 Some(buffer.read(cx).snapshot())
83 }
84 })
85 .take(3)
86 .collect::<Vec<_>>();
87
88 let buffer_snapshotted_at = Instant::now();
89
90 let result = cx.background_spawn(async move {
91 let text = inputs.snapshot.text();
92
93 let mut recent_changes = String::new();
94 for event in &inputs.events {
95 write_event(event.as_ref(), &mut recent_changes).unwrap();
96 }
97
98 let file_chunks = recent_buffer_snapshots
99 .into_iter()
100 .map(|snapshot| {
101 let end_point = Point::new(30, 0).min(snapshot.max_point());
102 FileChunk {
103 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
104 file_path: snapshot
105 .file()
106 .map(|f| f.path().as_unix_str())
107 .unwrap_or("untitled")
108 .to_string(),
109 start_line: 0,
110 end_line: end_point.row as usize,
111 timestamp: snapshot.file().and_then(|file| {
112 Some(
113 file.disk_state()
114 .mtime()?
115 .to_seconds_and_nanos_for_persistence()?
116 .0,
117 )
118 }),
119 }
120 })
121 .collect::<Vec<_>>();
122
123 let mut retrieval_chunks: Vec<FileChunk> = inputs
124 .related_files
125 .iter()
126 .flat_map(|related_file| {
127 related_file.excerpts.iter().map(|excerpt| FileChunk {
128 file_path: related_file.path.to_string_lossy().to_string(),
129 start_line: excerpt.row_range.start as usize,
130 end_line: excerpt.row_range.end as usize,
131 content: excerpt.text.to_string(),
132 timestamp: None,
133 })
134 })
135 .collect();
136
137 let diagnostic_entries = inputs
138 .snapshot
139 .diagnostics_in_range(inputs.diagnostic_search_range, false);
140 let mut diagnostic_content = String::new();
141 let mut diagnostic_count = 0;
142
143 for entry in diagnostic_entries {
144 let start_point: Point = entry.range.start;
145
146 let severity = match entry.diagnostic.severity {
147 DiagnosticSeverity::ERROR => "error",
148 DiagnosticSeverity::WARNING => "warning",
149 DiagnosticSeverity::INFORMATION => "info",
150 DiagnosticSeverity::HINT => "hint",
151 _ => continue,
152 };
153
154 diagnostic_count += 1;
155
156 writeln!(
157 &mut diagnostic_content,
158 "{}:{}:{}: {}: {}",
159 full_path.display(),
160 start_point.row + 1,
161 start_point.column + 1,
162 severity,
163 entry.diagnostic.message
164 )?;
165 }
166
167 if !diagnostic_content.is_empty() {
168 retrieval_chunks.push(FileChunk {
169 file_path: "diagnostics".to_string(),
170 start_line: 1,
171 end_line: diagnostic_count,
172 content: diagnostic_content,
173 timestamp: None,
174 });
175 }
176
177 let file_path_str = full_path.display().to_string();
178 let recent_user_actions = inputs
179 .user_actions
180 .iter()
181 .filter(|r| r.buffer_id == buffer_entity_id)
182 .map(|r| to_sweep_user_action(r, &file_path_str))
183 .collect();
184
185 let request_body = AutocompleteRequest {
186 debug_info,
187 repo_name,
188 file_path: full_path.clone(),
189 file_contents: text.clone(),
190 original_file_contents: text,
191 cursor_position: offset,
192 recent_changes: recent_changes.clone(),
193 changes_above_cursor: true,
194 multiple_suggestions: false,
195 branch: None,
196 file_chunks,
197 retrieval_chunks,
198 recent_user_actions,
199 use_bytes: true,
200 // TODO
201 privacy_mode_enabled: false,
202 };
203
204 let mut buf: Vec<u8> = Vec::new();
205 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
206 serde_json::to_writer(writer, &request_body)?;
207 let body: AsyncBody = buf.into();
208
209 let ep_inputs = zeta_prompt::ZetaPromptInput {
210 events: inputs.events,
211 related_files: inputs.related_files.clone(),
212 cursor_path: full_path.clone(),
213 cursor_excerpt: request_body.file_contents.clone().into(),
214 // we actually don't know
215 editable_range_in_excerpt: 0..inputs.snapshot.len(),
216 cursor_offset_in_excerpt: request_body.cursor_position,
217 };
218
219 send_started_event(
220 &debug_tx,
221 &buffer,
222 inputs.position,
223 serde_json::to_string(&request_body).unwrap_or_default(),
224 );
225
226 let request = http_client::Request::builder()
227 .uri(SWEEP_API_URL)
228 .header("Content-Type", "application/json")
229 .header("Authorization", format!("Bearer {}", api_token))
230 .header("Connection", "keep-alive")
231 .header("Content-Encoding", "br")
232 .method(Method::POST)
233 .body(body)?;
234
235 let mut response = http_client.send(request).await?;
236
237 let mut body = String::new();
238 response.body_mut().read_to_string(&mut body).await?;
239
240 let response_received_at = Instant::now();
241 if !response.status().is_success() {
242 let message = format!(
243 "Request failed with status: {:?}\nBody: {}",
244 response.status(),
245 body,
246 );
247 send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
248 bail!(message);
249 };
250
251 let response: AutocompleteResponse = serde_json::from_str(&body)?;
252
253 send_finished_event(&debug_tx, &buffer, inputs.position, body);
254
255 let old_text = inputs
256 .snapshot
257 .text_for_range(response.start_index..response.end_index)
258 .collect::<String>();
259 let edits = language::text_diff(&old_text, &response.completion)
260 .into_iter()
261 .map(|(range, text)| {
262 (
263 inputs
264 .snapshot
265 .anchor_after(response.start_index + range.start)
266 ..inputs
267 .snapshot
268 .anchor_before(response.start_index + range.end),
269 text,
270 )
271 })
272 .collect::<Vec<_>>();
273
274 anyhow::Ok((
275 response.autocomplete_id,
276 edits,
277 inputs.snapshot,
278 response_received_at,
279 ep_inputs,
280 ))
281 });
282
283 let buffer = inputs.buffer.clone();
284
285 cx.spawn(async move |cx| {
286 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
287 anyhow::Ok(Some(
288 EditPredictionResult::new(
289 EditPredictionId(id.into()),
290 &buffer,
291 &old_snapshot,
292 edits.into(),
293 buffer_snapshotted_at,
294 response_received_at,
295 inputs,
296 cx,
297 )
298 .await,
299 ))
300 })
301 }
302}
303
304fn send_started_event(
305 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
306 buffer: &Entity<Buffer>,
307 position: Anchor,
308 prompt: String,
309) {
310 if let Some(debug_tx) = debug_tx {
311 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
312 EditPredictionStartedDebugEvent {
313 buffer: buffer.downgrade(),
314 position,
315 prompt: Some(prompt),
316 },
317 ));
318 }
319}
320
321fn send_finished_event(
322 debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
323 buffer: &Entity<Buffer>,
324 position: Anchor,
325 model_output: String,
326) {
327 if let Some(debug_tx) = debug_tx {
328 _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
329 EditPredictionFinishedDebugEvent {
330 buffer: buffer.downgrade(),
331 position,
332 model_output: Some(model_output),
333 },
334 ));
335 }
336}
337
/// URL used as the key when loading and reading the Sweep API token from `ApiKeyState`.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Credential username for the Sweep API token; not referenced in this file
/// (presumably used where the token is stored — confirm).
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// `SWEEP_AI_TOKEN` environment variable, passed to `ApiKeyState::new`
/// (presumably an alternative source for the token — confirm).
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
342
/// gpui global holding the single Sweep `ApiKeyState` entity, so all callers of
/// `sweep_api_token` get the same handle.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
346
347pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
348 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
349 return global.0.clone();
350 }
351 let entity =
352 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
353 cx.set_global(GlobalSweepApiKey(entity.clone()));
354 entity
355}
356
357pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
358 sweep_api_token(cx).update(cx, |key_state, cx| {
359 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
360 })
361}
362
/// JSON body sent (Brotli-compressed) to `SWEEP_API_URL`.
///
/// serde serializes the fields in declaration order.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Zed version / commit / OS description of the client.
    pub debug_info: Arc<str>,
    /// Worktree root name of the file, or `"untitled"`.
    pub repo_name: String,
    /// Git branch; the request builder in this file sends `None`.
    pub branch: Option<String>,
    /// Full path of the file being edited (`"untitled"` when it has no file).
    pub file_path: Arc<Path>,
    /// Current text of the buffer.
    pub file_contents: String,
    /// Recent buffer diffs rendered by `write_event`.
    pub recent_changes: String,
    /// Offset of the cursor in `file_contents`.
    pub cursor_position: usize,
    /// The request builder in this file sets this to the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading lines of other recently used open buffers.
    pub file_chunks: Vec<FileChunk>,
    /// Related-file excerpts, plus a synthetic `"diagnostics"` chunk when any exist.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Recorded user actions for the current buffer.
    pub recent_user_actions: Vec<UserAction>,
    pub multiple_suggestions: bool,
    pub privacy_mode_enabled: bool,
    pub changes_above_cursor: bool,
    /// Presumably asks the server to use byte offsets for positions — confirm
    /// against the Sweep API.
    pub use_bytes: bool,
}
381
/// A span of file text sent as context (in `file_chunks` or `retrieval_chunks`).
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the chunk's file, or a synthetic name such as `"diagnostics"`.
    pub file_path: String,
    /// First row of the span (0 for recent-buffer chunks, 1 for the diagnostics chunk).
    pub start_line: usize,
    /// Last row of the span, or the diagnostic count for the diagnostics chunk.
    pub end_line: usize,
    /// The chunk's text.
    pub content: String,
    /// Seconds component of the file's on-disk mtime, when known.
    pub timestamp: Option<u64>,
}
390
/// A recorded user action, in the shape the Sweep API receives it.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    /// Copied from `UserActionRecord::timestamp_epoch_ms`; presumably
    /// milliseconds since the Unix epoch.
    pub timestamp: u64,
}
399
/// Kind of user action, serialized in SCREAMING_SNAKE_CASE (e.g. `"CURSOR_MOVEMENT"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
409
410fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
411 UserAction {
412 action_type: match record.action_type {
413 UserActionType::InsertChar => ActionType::InsertChar,
414 UserActionType::InsertSelection => ActionType::InsertSelection,
415 UserActionType::DeleteChar => ActionType::DeleteChar,
416 UserActionType::DeleteSelection => ActionType::DeleteSelection,
417 UserActionType::CursorMovement => ActionType::CursorMovement,
418 },
419 line_number: record.line_number as usize,
420 offset: record.offset,
421 file_path: file_path.to_string(),
422 timestamp: record.timestamp_epoch_ms,
423 }
424}
425
/// Successful response body from `SWEEP_API_URL`.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier later reported back in metrics events.
    pub autocomplete_id: String,
    /// Start of the replaced region, used as an offset into the buffer snapshot.
    pub start_index: usize,
    /// End (exclusive) of the replaced region.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The remaining fields are deserialized but not read yet; `dead_code` is
    // allowed so they can stay documented here.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternatives from the JSON `completions` key (empty when absent).
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
444
/// An alternative completion from the response's `completions` list.
///
/// It is deserialized but not used in this file, hence `dead_code` is allowed.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
456
457fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
458 match event {
459 zeta_prompt::Event::BufferChange {
460 old_path,
461 path,
462 diff,
463 ..
464 } => {
465 if old_path != path {
466 // TODO confirm how to do this for sweep
467 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
468 }
469
470 if !diff.is_empty() {
471 write!(f, "File: {}:\n{}\n", path.display(), diff)?
472 }
473
474 fmt::Result::Ok(())
475 }
476 }
477}
478
479fn debug_info(cx: &gpui::App) -> Arc<str> {
480 format!(
481 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
482 version = release_channel::AppVersion::global(cx),
483 sha = release_channel::AppCommitSha::try_global(cx)
484 .map_or("unknown".to_string(), |sha| sha.full()),
485 os = client::telemetry::os_name(),
486 )
487 .into()
488}
489
/// Metrics event kind, serialized in snake_case (e.g. `"autocomplete_suggestion_shown"`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    AutocompleteSuggestionShown,
    AutocompleteSuggestionAccepted,
}
496
/// How a suggestion was displayed, serialized in SCREAMING_SNAKE_CASE (e.g. `"GHOST_TEXT"`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}
504
/// JSON body sent to `SWEEP_METRICS_URL` when a suggestion is shown or accepted.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    event_type: SweepEventType,
    suggestion_type: SweepSuggestionType,
    /// Line-based estimate from `compute_edit_metrics`.
    additions: u32,
    /// Line-based estimate from `compute_edit_metrics`.
    deletions: u32,
    /// The prediction id (Sweep's `autocomplete_id`).
    autocomplete_id: String,
    edit_tracking: String,
    edit_tracking_line: Option<u32>,
    lifespan: Option<u64>,
    debug_info: Arc<str>,
    device_id: String,
    privacy_mode_enabled: bool,
}
519
520fn send_autocomplete_metrics_request(
521 cx: &App,
522 client: Arc<Client>,
523 api_token: Arc<str>,
524 request_body: AutocompleteMetricsRequest,
525) {
526 let http_client = client.http_client();
527 cx.background_spawn(async move {
528 let body: AsyncBody = serde_json::to_string(&request_body)?.into();
529
530 let request = http_client::Request::builder()
531 .uri(SWEEP_METRICS_URL)
532 .header("Content-Type", "application/json")
533 .header("Authorization", format!("Bearer {}", api_token))
534 .method(Method::POST)
535 .body(body)?;
536
537 let mut response = http_client.send(request).await?;
538
539 if !response.status().is_success() {
540 let mut body = String::new();
541 response.body_mut().read_to_string(&mut body).await?;
542 anyhow::bail!(
543 "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
544 response.status(),
545 body,
546 );
547 }
548
549 Ok(())
550 })
551 .detach_and_log_err(cx);
552}
553
554pub(crate) fn edit_prediction_accepted(
555 store: &EditPredictionStore,
556 current_prediction: CurrentEditPrediction,
557 cx: &App,
558) {
559 let Some(api_token) = store
560 .sweep_ai
561 .api_token
562 .read(cx)
563 .key(&SWEEP_CREDENTIALS_URL)
564 else {
565 return;
566 };
567 let debug_info = store.sweep_ai.debug_info.clone();
568
569 let prediction = current_prediction.prediction;
570
571 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
572 let autocomplete_id = prediction.id.to_string();
573
574 let device_id = store
575 .client
576 .user_id()
577 .as_ref()
578 .map(ToString::to_string)
579 .unwrap_or_default();
580
581 let suggestion_type = match current_prediction.shown_with {
582 Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
583 Some(SuggestionDisplayType::Jump) => return, // should'nt happen
584 Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
585 };
586
587 let request_body = AutocompleteMetricsRequest {
588 event_type: SweepEventType::AutocompleteSuggestionAccepted,
589 suggestion_type,
590 additions,
591 deletions,
592 autocomplete_id,
593 edit_tracking: String::new(),
594 edit_tracking_line: None,
595 lifespan: None,
596 debug_info,
597 device_id,
598 privacy_mode_enabled: false,
599 };
600
601 send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
602}
603
604pub fn edit_prediction_shown(
605 sweep_ai: &SweepAi,
606 client: Arc<Client>,
607 prediction: &EditPrediction,
608 display_type: SuggestionDisplayType,
609 cx: &App,
610) {
611 let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
612 return;
613 };
614 let debug_info = sweep_ai.debug_info.clone();
615
616 let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
617 let autocomplete_id = prediction.id.to_string();
618
619 let suggestion_type = match display_type {
620 SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
621 SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
622 SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
623 };
624
625 let request_body = AutocompleteMetricsRequest {
626 event_type: SweepEventType::AutocompleteSuggestionShown,
627 suggestion_type,
628 additions,
629 deletions,
630 autocomplete_id,
631 edit_tracking: String::new(),
632 edit_tracking_line: None,
633 lifespan: None,
634 debug_info,
635 device_id: String::new(),
636 privacy_mode_enabled: false,
637 };
638
639 send_autocomplete_metrics_request(cx, client, api_token, request_body);
640}
641
642fn compute_edit_metrics(
643 edits: &[(Range<Anchor>, Arc<str>)],
644 snapshot: &BufferSnapshot,
645) -> (u32, u32) {
646 let mut additions = 0u32;
647 let mut deletions = 0u32;
648
649 for (range, new_text) in edits {
650 let old_text = snapshot.text_for_range(range.clone());
651 deletions += old_text
652 .map(|chunk| chunk.lines().count())
653 .sum::<usize>()
654 .max(1) as u32;
655 additions += new_text.lines().count().max(1) as u32;
656 }
657
658 (additions, deletions)
659}