sweep_ai.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
  3    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  4    EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
  5};
  6use anyhow::{Result, bail};
  7use client::Client;
  8use edit_prediction_types::SuggestionDisplayType;
  9use futures::{AsyncReadExt as _, channel::mpsc};
 10use gpui::{
 11    App, AppContext as _, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, Method},
 13};
 14use language::language_settings::all_language_settings;
 15use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
 16use language_model::{ApiKeyState, EnvVar, env_var};
 17use lsp::DiagnosticSeverity;
 18use serde::{Deserialize, Serialize};
 19use std::{
 20    fmt::{self, Write as _},
 21    ops::Range,
 22    path::Path,
 23    sync::Arc,
 24};
 25
 26const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 27const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
 28
 29pub struct SweepAi {
 30    pub api_token: Entity<ApiKeyState>,
 31    pub debug_info: Arc<str>,
 32}
 33
 34impl SweepAi {
 35    pub fn new(cx: &mut App) -> Self {
 36        SweepAi {
 37            api_token: sweep_api_token(cx),
 38            debug_info: debug_info(cx),
 39        }
 40    }
 41
 42    pub fn request_prediction_with_sweep(
 43        &self,
 44        inputs: EditPredictionModelInput,
 45        cx: &mut App,
 46    ) -> Task<Result<Option<EditPredictionResult>>> {
 47        let privacy_mode_enabled = all_language_settings(None, cx)
 48            .edit_predictions
 49            .sweep
 50            .privacy_mode;
 51        let debug_info = self.debug_info.clone();
 52        let request_start = cx.background_executor().now();
 53        self.api_token.update(cx, |key_state, cx| {
 54            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 55        });
 56
 57        let buffer = inputs.buffer.clone();
 58        let debug_tx = inputs.debug_tx.clone();
 59
 60        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 61            return Task::ready(Ok(None));
 62        };
 63        let full_path: Arc<Path> = inputs
 64            .snapshot
 65            .file()
 66            .map(|file| file.full_path(cx))
 67            .unwrap_or_else(|| "untitled".into())
 68            .into();
 69
 70        let project_file = project::File::from_dyn(inputs.snapshot.file());
 71        let repo_name = project_file
 72            .map(|file| file.worktree.read(cx).root_name_str())
 73            .unwrap_or("untitled")
 74            .into();
 75        let offset = inputs.position.to_offset(&inputs.snapshot);
 76        let buffer_entity_id = inputs.buffer.entity_id();
 77
 78        let recent_buffers = inputs.recent_paths.iter().cloned();
 79        let http_client = cx.http_client();
 80
 81        let recent_buffer_snapshots = recent_buffers
 82            .filter_map(|project_path| {
 83                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 84                if inputs.buffer == buffer {
 85                    None
 86                } else {
 87                    Some(buffer.read(cx).snapshot())
 88                }
 89            })
 90            .take(3)
 91            .collect::<Vec<_>>();
 92
 93        let result = cx.background_spawn(async move {
 94            let text = inputs.snapshot.text();
 95
 96            let mut recent_changes = String::new();
 97            for event in &inputs.events {
 98                write_event(event.as_ref(), &mut recent_changes).unwrap();
 99            }
100
101            let file_chunks = recent_buffer_snapshots
102                .into_iter()
103                .map(|snapshot| {
104                    let end_point = Point::new(30, 0).min(snapshot.max_point());
105                    FileChunk {
106                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
107                        file_path: snapshot
108                            .file()
109                            .map(|f| f.path().as_unix_str())
110                            .unwrap_or("untitled")
111                            .to_string(),
112                        start_line: 0,
113                        end_line: end_point.row as usize,
114                        timestamp: snapshot.file().and_then(|file| {
115                            Some(
116                                file.disk_state()
117                                    .mtime()?
118                                    .to_seconds_and_nanos_for_persistence()?
119                                    .0,
120                            )
121                        }),
122                    }
123                })
124                .collect::<Vec<_>>();
125
126            let mut retrieval_chunks: Vec<FileChunk> = inputs
127                .related_files
128                .iter()
129                .flat_map(|related_file| {
130                    related_file.excerpts.iter().map(|excerpt| FileChunk {
131                        file_path: related_file.path.to_string_lossy().to_string(),
132                        start_line: excerpt.row_range.start as usize,
133                        end_line: excerpt.row_range.end as usize,
134                        content: excerpt.text.to_string(),
135                        timestamp: None,
136                    })
137                })
138                .collect();
139
140            let diagnostic_entries = inputs
141                .snapshot
142                .diagnostics_in_range(inputs.diagnostic_search_range, false);
143            let mut diagnostic_content = String::new();
144            let mut diagnostic_count = 0;
145
146            for entry in diagnostic_entries {
147                let start_point: Point = entry.range.start;
148
149                let severity = match entry.diagnostic.severity {
150                    DiagnosticSeverity::ERROR => "error",
151                    DiagnosticSeverity::WARNING => "warning",
152                    DiagnosticSeverity::INFORMATION => "info",
153                    DiagnosticSeverity::HINT => "hint",
154                    _ => continue,
155                };
156
157                diagnostic_count += 1;
158
159                writeln!(
160                    &mut diagnostic_content,
161                    "{}:{}:{}: {}: {}",
162                    full_path.display(),
163                    start_point.row + 1,
164                    start_point.column + 1,
165                    severity,
166                    entry.diagnostic.message
167                )?;
168            }
169
170            if !diagnostic_content.is_empty() {
171                retrieval_chunks.push(FileChunk {
172                    file_path: "diagnostics".to_string(),
173                    start_line: 1,
174                    end_line: diagnostic_count,
175                    content: diagnostic_content,
176                    timestamp: None,
177                });
178            }
179
180            let file_path_str = full_path.display().to_string();
181            let recent_user_actions = inputs
182                .user_actions
183                .iter()
184                .filter(|r| r.buffer_id == buffer_entity_id)
185                .map(|r| to_sweep_user_action(r, &file_path_str))
186                .collect();
187
188            let request_body = AutocompleteRequest {
189                debug_info,
190                repo_name,
191                file_path: full_path.clone(),
192                file_contents: text.clone(),
193                original_file_contents: text,
194                cursor_position: offset,
195                recent_changes: recent_changes.clone(),
196                changes_above_cursor: true,
197                multiple_suggestions: false,
198                branch: None,
199                file_chunks,
200                retrieval_chunks,
201                recent_user_actions,
202                use_bytes: true,
203                privacy_mode_enabled,
204            };
205
206            let mut buf: Vec<u8> = Vec::new();
207            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
208            serde_json::to_writer(writer, &request_body)?;
209            let body: AsyncBody = buf.into();
210
211            let ep_inputs = zeta_prompt::ZetaPromptInput {
212                events: inputs.events,
213                related_files: Some(inputs.related_files.clone()),
214                active_buffer_diagnostics: vec![],
215                cursor_path: full_path.clone(),
216                cursor_excerpt: request_body.file_contents.clone().into(),
217                cursor_offset_in_excerpt: request_body.cursor_position,
218                excerpt_start_row: Some(0),
219                excerpt_ranges: zeta_prompt::ExcerptRanges {
220                    editable_150: 0..inputs.snapshot.len(),
221                    editable_180: 0..inputs.snapshot.len(),
222                    editable_350: 0..inputs.snapshot.len(),
223                    editable_150_context_350: 0..inputs.snapshot.len(),
224                    editable_180_context_350: 0..inputs.snapshot.len(),
225                    editable_350_context_150: 0..inputs.snapshot.len(),
226                    ..Default::default()
227                },
228                syntax_ranges: None,
229                experiment: None,
230                in_open_source_repo: false,
231                can_collect_data: false,
232                repo_url: None,
233            };
234
235            send_started_event(
236                &debug_tx,
237                &buffer,
238                inputs.position,
239                serde_json::to_string(&request_body).unwrap_or_default(),
240            );
241
242            let request = http_client::Request::builder()
243                .uri(SWEEP_API_URL)
244                .header("Content-Type", "application/json")
245                .header("Authorization", format!("Bearer {}", api_token))
246                .header("Connection", "keep-alive")
247                .header("Content-Encoding", "br")
248                .method(Method::POST)
249                .body(body)?;
250
251            let mut response = http_client.send(request).await?;
252
253            let mut body = String::new();
254            response.body_mut().read_to_string(&mut body).await?;
255
256            if !response.status().is_success() {
257                let message = format!(
258                    "Request failed with status: {:?}\nBody: {}",
259                    response.status(),
260                    body,
261                );
262                send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
263                bail!(message);
264            };
265
266            let response: AutocompleteResponse = serde_json::from_str(&body)?;
267
268            send_finished_event(&debug_tx, &buffer, inputs.position, body);
269
270            let old_text = inputs
271                .snapshot
272                .text_for_range(response.start_index..response.end_index)
273                .collect::<String>();
274            let edits = language::text_diff(&old_text, &response.completion)
275                .into_iter()
276                .map(|(range, text)| {
277                    (
278                        inputs
279                            .snapshot
280                            .anchor_after(response.start_index + range.start)
281                            ..inputs
282                                .snapshot
283                                .anchor_before(response.start_index + range.end),
284                        text,
285                    )
286                })
287                .collect::<Vec<_>>();
288
289            anyhow::Ok((response.autocomplete_id, edits, inputs.snapshot, ep_inputs))
290        });
291
292        let buffer = inputs.buffer.clone();
293
294        cx.spawn(async move |cx| {
295            let (id, edits, old_snapshot, inputs) = result.await?;
296            anyhow::Ok(Some(
297                EditPredictionResult::new(
298                    EditPredictionId(id.into()),
299                    &buffer,
300                    &old_snapshot,
301                    edits.into(),
302                    None,
303                    inputs,
304                    None,
305                    cx.background_executor().now() - request_start,
306                    cx,
307                )
308                .await,
309            ))
310        })
311    }
312}
313
314fn send_started_event(
315    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
316    buffer: &Entity<Buffer>,
317    position: Anchor,
318    prompt: String,
319) {
320    if let Some(debug_tx) = debug_tx {
321        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
322            EditPredictionStartedDebugEvent {
323                buffer: buffer.downgrade(),
324                position,
325                prompt: Some(prompt),
326            },
327        ));
328    }
329}
330
331fn send_finished_event(
332    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
333    buffer: &Entity<Buffer>,
334    position: Anchor,
335    model_output: String,
336) {
337    if let Some(debug_tx) = debug_tx {
338        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
339            EditPredictionFinishedDebugEvent {
340                buffer: buffer.downgrade(),
341                position,
342                model_output: Some(model_output),
343            },
344        ));
345    }
346}
347
348pub const SWEEP_CREDENTIALS_URL: SharedString =
349    SharedString::new_static("https://autocomplete.sweep.dev");
350pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
351pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
352
353struct GlobalSweepApiKey(Entity<ApiKeyState>);
354
355impl Global for GlobalSweepApiKey {}
356
357pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
358    if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
359        return global.0.clone();
360    }
361    let entity =
362        cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
363    cx.set_global(GlobalSweepApiKey(entity.clone()));
364    entity
365}
366
367pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
368    sweep_api_token(cx).update(cx, |key_state, cx| {
369        key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
370    })
371}
372
373#[derive(Debug, Clone, Serialize)]
374struct AutocompleteRequest {
375    pub debug_info: Arc<str>,
376    pub repo_name: String,
377    pub branch: Option<String>,
378    pub file_path: Arc<Path>,
379    pub file_contents: String,
380    pub recent_changes: String,
381    pub cursor_position: usize,
382    pub original_file_contents: String,
383    pub file_chunks: Vec<FileChunk>,
384    pub retrieval_chunks: Vec<FileChunk>,
385    pub recent_user_actions: Vec<UserAction>,
386    pub multiple_suggestions: bool,
387    pub privacy_mode_enabled: bool,
388    pub changes_above_cursor: bool,
389    pub use_bytes: bool,
390}
391
392#[derive(Debug, Clone, Serialize)]
393struct FileChunk {
394    pub file_path: String,
395    pub start_line: usize,
396    pub end_line: usize,
397    pub content: String,
398    pub timestamp: Option<u64>,
399}
400
401#[derive(Debug, Clone, Serialize)]
402struct UserAction {
403    pub action_type: ActionType,
404    pub line_number: usize,
405    pub offset: usize,
406    pub file_path: String,
407    pub timestamp: u64,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
411#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
412enum ActionType {
413    CursorMovement,
414    InsertChar,
415    DeleteChar,
416    InsertSelection,
417    DeleteSelection,
418}
419
420fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
421    UserAction {
422        action_type: match record.action_type {
423            UserActionType::InsertChar => ActionType::InsertChar,
424            UserActionType::InsertSelection => ActionType::InsertSelection,
425            UserActionType::DeleteChar => ActionType::DeleteChar,
426            UserActionType::DeleteSelection => ActionType::DeleteSelection,
427            UserActionType::CursorMovement => ActionType::CursorMovement,
428        },
429        line_number: record.line_number as usize,
430        offset: record.offset,
431        file_path: file_path.to_string(),
432        timestamp: record.timestamp_epoch_ms,
433    }
434}
435
436#[derive(Debug, Clone, Deserialize)]
437struct AutocompleteResponse {
438    pub autocomplete_id: String,
439    pub start_index: usize,
440    pub end_index: usize,
441    pub completion: String,
442    #[allow(dead_code)]
443    pub confidence: f64,
444    #[allow(dead_code)]
445    pub logprobs: Option<serde_json::Value>,
446    #[allow(dead_code)]
447    pub finish_reason: Option<String>,
448    #[allow(dead_code)]
449    pub elapsed_time_ms: u64,
450    #[allow(dead_code)]
451    #[serde(default, rename = "completions")]
452    pub additional_completions: Vec<AdditionalCompletion>,
453}
454
455#[allow(dead_code)]
456#[derive(Debug, Clone, Deserialize)]
457struct AdditionalCompletion {
458    pub start_index: usize,
459    pub end_index: usize,
460    pub completion: String,
461    pub confidence: f64,
462    pub autocomplete_id: String,
463    pub logprobs: Option<serde_json::Value>,
464    pub finish_reason: Option<String>,
465}
466
467fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
468    match event {
469        zeta_prompt::Event::BufferChange {
470            old_path,
471            path,
472            diff,
473            ..
474        } => {
475            if old_path != path {
476                // TODO confirm how to do this for sweep
477                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
478            }
479
480            if !diff.is_empty() {
481                write!(f, "File: {}:\n{}\n", path.display(), diff)?
482            }
483
484            fmt::Result::Ok(())
485        }
486    }
487}
488
489fn debug_info(cx: &gpui::App) -> Arc<str> {
490    format!(
491        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
492        version = release_channel::AppVersion::global(cx),
493        sha = release_channel::AppCommitSha::try_global(cx)
494            .map_or("unknown".to_string(), |sha| sha.full()),
495        os = client::telemetry::os_name(),
496    )
497    .into()
498}
499
500#[derive(Debug, Clone, Copy, Serialize)]
501#[serde(rename_all = "snake_case")]
502pub enum SweepEventType {
503    AutocompleteSuggestionShown,
504    AutocompleteSuggestionAccepted,
505}
506
507#[derive(Debug, Clone, Copy, Serialize)]
508#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
509pub enum SweepSuggestionType {
510    GhostText,
511    Popup,
512    JumpToEdit,
513}
514
515#[derive(Debug, Clone, Serialize)]
516struct AutocompleteMetricsRequest {
517    event_type: SweepEventType,
518    suggestion_type: SweepSuggestionType,
519    additions: u32,
520    deletions: u32,
521    autocomplete_id: String,
522    edit_tracking: String,
523    edit_tracking_line: Option<u32>,
524    lifespan: Option<u64>,
525    debug_info: Arc<str>,
526    device_id: String,
527    privacy_mode_enabled: bool,
528}
529
530fn send_autocomplete_metrics_request(
531    cx: &App,
532    client: Arc<Client>,
533    api_token: Arc<str>,
534    request_body: AutocompleteMetricsRequest,
535) {
536    let http_client = client.http_client();
537    cx.background_spawn(async move {
538        let body: AsyncBody = serde_json::to_string(&request_body)?.into();
539
540        let request = http_client::Request::builder()
541            .uri(SWEEP_METRICS_URL)
542            .header("Content-Type", "application/json")
543            .header("Authorization", format!("Bearer {}", api_token))
544            .method(Method::POST)
545            .body(body)?;
546
547        let mut response = http_client.send(request).await?;
548
549        if !response.status().is_success() {
550            let mut body = String::new();
551            response.body_mut().read_to_string(&mut body).await?;
552            anyhow::bail!(
553                "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
554                response.status(),
555                body,
556            );
557        }
558
559        Ok(())
560    })
561    .detach_and_log_err(cx);
562}
563
564pub(crate) fn edit_prediction_accepted(
565    store: &EditPredictionStore,
566    current_prediction: CurrentEditPrediction,
567    cx: &App,
568) {
569    let Some(api_token) = store
570        .sweep_ai
571        .api_token
572        .read(cx)
573        .key(&SWEEP_CREDENTIALS_URL)
574    else {
575        return;
576    };
577    let debug_info = store.sweep_ai.debug_info.clone();
578
579    let prediction = current_prediction.prediction;
580
581    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
582    let autocomplete_id = prediction.id.to_string();
583
584    let device_id = store
585        .client
586        .user_id()
587        .as_ref()
588        .map(ToString::to_string)
589        .unwrap_or_default();
590
591    let suggestion_type = match current_prediction.shown_with {
592        Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
593        Some(SuggestionDisplayType::Jump) => return, // should'nt happen
594        Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
595    };
596
597    let request_body = AutocompleteMetricsRequest {
598        event_type: SweepEventType::AutocompleteSuggestionAccepted,
599        suggestion_type,
600        additions,
601        deletions,
602        autocomplete_id,
603        edit_tracking: String::new(),
604        edit_tracking_line: None,
605        lifespan: None,
606        debug_info,
607        device_id,
608        privacy_mode_enabled: false,
609    };
610
611    send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
612}
613
614pub fn edit_prediction_shown(
615    sweep_ai: &SweepAi,
616    client: Arc<Client>,
617    prediction: &EditPrediction,
618    display_type: SuggestionDisplayType,
619    cx: &App,
620) {
621    let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
622        return;
623    };
624    let debug_info = sweep_ai.debug_info.clone();
625
626    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
627    let autocomplete_id = prediction.id.to_string();
628
629    let suggestion_type = match display_type {
630        SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
631        SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
632        SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
633    };
634
635    let request_body = AutocompleteMetricsRequest {
636        event_type: SweepEventType::AutocompleteSuggestionShown,
637        suggestion_type,
638        additions,
639        deletions,
640        autocomplete_id,
641        edit_tracking: String::new(),
642        edit_tracking_line: None,
643        lifespan: None,
644        debug_info,
645        device_id: String::new(),
646        privacy_mode_enabled: false,
647    };
648
649    send_autocomplete_metrics_request(cx, client, api_token, request_body);
650}
651
652fn compute_edit_metrics(
653    edits: &[(Range<Anchor>, Arc<str>)],
654    snapshot: &BufferSnapshot,
655) -> (u32, u32) {
656    let mut additions = 0u32;
657    let mut deletions = 0u32;
658
659    for (range, new_text) in edits {
660        let old_text = snapshot.text_for_range(range.clone());
661        deletions += old_text
662            .map(|chunk| chunk.lines().count())
663            .sum::<usize>()
664            .max(1) as u32;
665        additions += new_text.lines().count().max(1) as u32;
666    }
667
668    (additions, deletions)
669}