sweep_ai.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
  3    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  4    EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
  5};
  6use anyhow::{Result, bail};
  7use client::Client;
  8use edit_prediction_types::SuggestionDisplayType;
  9use futures::{AsyncReadExt as _, channel::mpsc};
 10use gpui::{
 11    App, AppContext as _, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, Method},
 13};
 14use language::language_settings::all_language_settings;
 15use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
 16use language_model::{ApiKeyState, EnvVar, env_var};
 17use lsp::DiagnosticSeverity;
 18use serde::{Deserialize, Serialize};
 19use std::{
 20    fmt::{self, Write as _},
 21    ops::Range,
 22    path::Path,
 23    sync::Arc,
 24    time::Instant,
 25};
 26
/// Endpoint that next-edit prediction requests are POSTed to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that suggestion "shown"/"accepted" metrics are POSTed to.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
 29
/// State for the Sweep edit-prediction provider.
pub struct SweepAi {
    /// API-key state, looked up under `SWEEP_CREDENTIALS_URL`. Obtained from
    /// `sweep_api_token`, so it is the app-wide shared entity.
    pub api_token: Entity<ApiKeyState>,
    /// Build/OS description sent as `debug_info` with prediction and metrics requests.
    pub debug_info: Arc<str>,
}
 34
 35impl SweepAi {
 36    pub fn new(cx: &mut App) -> Self {
 37        SweepAi {
 38            api_token: sweep_api_token(cx),
 39            debug_info: debug_info(cx),
 40        }
 41    }
 42
 43    pub fn request_prediction_with_sweep(
 44        &self,
 45        inputs: EditPredictionModelInput,
 46        cx: &mut App,
 47    ) -> Task<Result<Option<EditPredictionResult>>> {
 48        let privacy_mode_enabled = all_language_settings(None, cx)
 49            .edit_predictions
 50            .sweep
 51            .privacy_mode;
 52        let debug_info = self.debug_info.clone();
 53        self.api_token.update(cx, |key_state, cx| {
 54            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 55        });
 56
 57        let buffer = inputs.buffer.clone();
 58        let debug_tx = inputs.debug_tx.clone();
 59
 60        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 61            return Task::ready(Ok(None));
 62        };
 63        let full_path: Arc<Path> = inputs
 64            .snapshot
 65            .file()
 66            .map(|file| file.full_path(cx))
 67            .unwrap_or_else(|| "untitled".into())
 68            .into();
 69
 70        let project_file = project::File::from_dyn(inputs.snapshot.file());
 71        let repo_name = project_file
 72            .map(|file| file.worktree.read(cx).root_name_str())
 73            .unwrap_or("untitled")
 74            .into();
 75        let offset = inputs.position.to_offset(&inputs.snapshot);
 76        let buffer_entity_id = inputs.buffer.entity_id();
 77
 78        let recent_buffers = inputs.recent_paths.iter().cloned();
 79        let http_client = cx.http_client();
 80
 81        let recent_buffer_snapshots = recent_buffers
 82            .filter_map(|project_path| {
 83                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 84                if inputs.buffer == buffer {
 85                    None
 86                } else {
 87                    Some(buffer.read(cx).snapshot())
 88                }
 89            })
 90            .take(3)
 91            .collect::<Vec<_>>();
 92
 93        let buffer_snapshotted_at = Instant::now();
 94
 95        let result = cx.background_spawn(async move {
 96            let text = inputs.snapshot.text();
 97
 98            let mut recent_changes = String::new();
 99            for event in &inputs.events {
100                write_event(event.as_ref(), &mut recent_changes).unwrap();
101            }
102
103            let file_chunks = recent_buffer_snapshots
104                .into_iter()
105                .map(|snapshot| {
106                    let end_point = Point::new(30, 0).min(snapshot.max_point());
107                    FileChunk {
108                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
109                        file_path: snapshot
110                            .file()
111                            .map(|f| f.path().as_unix_str())
112                            .unwrap_or("untitled")
113                            .to_string(),
114                        start_line: 0,
115                        end_line: end_point.row as usize,
116                        timestamp: snapshot.file().and_then(|file| {
117                            Some(
118                                file.disk_state()
119                                    .mtime()?
120                                    .to_seconds_and_nanos_for_persistence()?
121                                    .0,
122                            )
123                        }),
124                    }
125                })
126                .collect::<Vec<_>>();
127
128            let mut retrieval_chunks: Vec<FileChunk> = inputs
129                .related_files
130                .iter()
131                .flat_map(|related_file| {
132                    related_file.excerpts.iter().map(|excerpt| FileChunk {
133                        file_path: related_file.path.to_string_lossy().to_string(),
134                        start_line: excerpt.row_range.start as usize,
135                        end_line: excerpt.row_range.end as usize,
136                        content: excerpt.text.to_string(),
137                        timestamp: None,
138                    })
139                })
140                .collect();
141
142            let diagnostic_entries = inputs
143                .snapshot
144                .diagnostics_in_range(inputs.diagnostic_search_range, false);
145            let mut diagnostic_content = String::new();
146            let mut diagnostic_count = 0;
147
148            for entry in diagnostic_entries {
149                let start_point: Point = entry.range.start;
150
151                let severity = match entry.diagnostic.severity {
152                    DiagnosticSeverity::ERROR => "error",
153                    DiagnosticSeverity::WARNING => "warning",
154                    DiagnosticSeverity::INFORMATION => "info",
155                    DiagnosticSeverity::HINT => "hint",
156                    _ => continue,
157                };
158
159                diagnostic_count += 1;
160
161                writeln!(
162                    &mut diagnostic_content,
163                    "{}:{}:{}: {}: {}",
164                    full_path.display(),
165                    start_point.row + 1,
166                    start_point.column + 1,
167                    severity,
168                    entry.diagnostic.message
169                )?;
170            }
171
172            if !diagnostic_content.is_empty() {
173                retrieval_chunks.push(FileChunk {
174                    file_path: "diagnostics".to_string(),
175                    start_line: 1,
176                    end_line: diagnostic_count,
177                    content: diagnostic_content,
178                    timestamp: None,
179                });
180            }
181
182            let file_path_str = full_path.display().to_string();
183            let recent_user_actions = inputs
184                .user_actions
185                .iter()
186                .filter(|r| r.buffer_id == buffer_entity_id)
187                .map(|r| to_sweep_user_action(r, &file_path_str))
188                .collect();
189
190            let request_body = AutocompleteRequest {
191                debug_info,
192                repo_name,
193                file_path: full_path.clone(),
194                file_contents: text.clone(),
195                original_file_contents: text,
196                cursor_position: offset,
197                recent_changes: recent_changes.clone(),
198                changes_above_cursor: true,
199                multiple_suggestions: false,
200                branch: None,
201                file_chunks,
202                retrieval_chunks,
203                recent_user_actions,
204                use_bytes: true,
205                privacy_mode_enabled,
206            };
207
208            let mut buf: Vec<u8> = Vec::new();
209            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
210            serde_json::to_writer(writer, &request_body)?;
211            let body: AsyncBody = buf.into();
212
213            let ep_inputs = zeta_prompt::ZetaPromptInput {
214                events: inputs.events,
215                related_files: Some(inputs.related_files.clone()),
216                active_buffer_diagnostics: vec![],
217                cursor_path: full_path.clone(),
218                cursor_excerpt: request_body.file_contents.clone().into(),
219                cursor_offset_in_excerpt: request_body.cursor_position,
220                excerpt_start_row: Some(0),
221                excerpt_ranges: zeta_prompt::ExcerptRanges {
222                    editable_150: 0..inputs.snapshot.len(),
223                    editable_180: 0..inputs.snapshot.len(),
224                    editable_350: 0..inputs.snapshot.len(),
225                    editable_150_context_350: 0..inputs.snapshot.len(),
226                    editable_180_context_350: 0..inputs.snapshot.len(),
227                    editable_350_context_150: 0..inputs.snapshot.len(),
228                    ..Default::default()
229                },
230                syntax_ranges: None,
231                experiment: None,
232                in_open_source_repo: false,
233                can_collect_data: false,
234                repo_url: None,
235            };
236
237            send_started_event(
238                &debug_tx,
239                &buffer,
240                inputs.position,
241                serde_json::to_string(&request_body).unwrap_or_default(),
242            );
243
244            let request = http_client::Request::builder()
245                .uri(SWEEP_API_URL)
246                .header("Content-Type", "application/json")
247                .header("Authorization", format!("Bearer {}", api_token))
248                .header("Connection", "keep-alive")
249                .header("Content-Encoding", "br")
250                .method(Method::POST)
251                .body(body)?;
252
253            let mut response = http_client.send(request).await?;
254
255            let mut body = String::new();
256            response.body_mut().read_to_string(&mut body).await?;
257
258            let response_received_at = Instant::now();
259            if !response.status().is_success() {
260                let message = format!(
261                    "Request failed with status: {:?}\nBody: {}",
262                    response.status(),
263                    body,
264                );
265                send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
266                bail!(message);
267            };
268
269            let response: AutocompleteResponse = serde_json::from_str(&body)?;
270
271            send_finished_event(&debug_tx, &buffer, inputs.position, body);
272
273            let old_text = inputs
274                .snapshot
275                .text_for_range(response.start_index..response.end_index)
276                .collect::<String>();
277            let edits = language::text_diff(&old_text, &response.completion)
278                .into_iter()
279                .map(|(range, text)| {
280                    (
281                        inputs
282                            .snapshot
283                            .anchor_after(response.start_index + range.start)
284                            ..inputs
285                                .snapshot
286                                .anchor_before(response.start_index + range.end),
287                        text,
288                    )
289                })
290                .collect::<Vec<_>>();
291
292            anyhow::Ok((
293                response.autocomplete_id,
294                edits,
295                inputs.snapshot,
296                response_received_at,
297                ep_inputs,
298            ))
299        });
300
301        let buffer = inputs.buffer.clone();
302
303        cx.spawn(async move |cx| {
304            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
305            anyhow::Ok(Some(
306                EditPredictionResult::new(
307                    EditPredictionId(id.into()),
308                    &buffer,
309                    &old_snapshot,
310                    edits.into(),
311                    None,
312                    buffer_snapshotted_at,
313                    response_received_at,
314                    inputs,
315                    None,
316                    cx,
317                )
318                .await,
319            ))
320        })
321    }
322}
323
324fn send_started_event(
325    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
326    buffer: &Entity<Buffer>,
327    position: Anchor,
328    prompt: String,
329) {
330    if let Some(debug_tx) = debug_tx {
331        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
332            EditPredictionStartedDebugEvent {
333                buffer: buffer.downgrade(),
334                position,
335                prompt: Some(prompt),
336            },
337        ));
338    }
339}
340
341fn send_finished_event(
342    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
343    buffer: &Entity<Buffer>,
344    position: Anchor,
345    model_output: String,
346) {
347    if let Some(debug_tx) = debug_tx {
348        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
349            EditPredictionFinishedDebugEvent {
350                buffer: buffer.downgrade(),
351                position,
352                model_output: Some(model_output),
353            },
354        ));
355    }
356}
357
/// URL the Sweep API key is loaded for and looked up under (`load_if_needed` / `key`).
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username for the Sweep API key credential. Not referenced in this file; presumably
/// used by the code that stores the key — confirm.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// Environment variable handed to `ApiKeyState::new` as a source for the Sweep API key.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
362
/// App-global holder for the Sweep key-state entity, so `sweep_api_token` hands out the
/// same entity to every caller.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
366
367pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
368    if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
369        return global.0.clone();
370    }
371    let entity =
372        cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
373    cx.set_global(GlobalSweepApiKey(entity.clone()));
374    entity
375}
376
377pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
378    sweep_api_token(cx).update(cx, |key_state, cx| {
379        key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
380    })
381}
382
/// JSON body of a prediction request to `SWEEP_API_URL`. It is brotli-compressed
/// before sending.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Build/OS identification string (see `debug_info`).
    pub debug_info: Arc<str>,
    /// Worktree root name of the active file, or "untitled".
    pub repo_name: String,
    /// Git branch. Currently always sent as `None`.
    pub branch: Option<String>,
    /// Full path of the active buffer's file, or "untitled".
    pub file_path: Arc<Path>,
    /// Entire text of the active buffer.
    pub file_contents: String,
    /// Recent buffer-change diffs rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`.
    pub cursor_position: usize,
    /// Currently the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading lines of other recently used open buffers.
    pub file_chunks: Vec<FileChunk>,
    /// Related-file excerpts, plus a "diagnostics" chunk when any diagnostics exist.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Recent user actions in the active buffer.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always `false`; only the primary completion is used.
    pub multiple_suggestions: bool,
    /// Mirrors the `edit_predictions.sweep.privacy_mode` setting.
    pub privacy_mode_enabled: bool,
    /// Currently always `true`.
    pub changes_above_cursor: bool,
    /// Currently always `true`. Presumably tells the server that offsets are byte
    /// offsets — confirm with Sweep's API.
    pub use_bytes: bool,
}
401
/// A span of context text sent with a prediction request.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path label for the chunk (a file path, "untitled", or "diagnostics").
    pub file_path: String,
    /// First row covered by `content`. The base differs by producer: 0 for
    /// recent-buffer chunks, 1 for the diagnostics chunk.
    pub start_line: usize,
    /// Last row covered by `content` (for diagnostics: the number of entries).
    pub end_line: usize,
    /// The chunk text.
    pub content: String,
    /// File mtime in seconds, when known. Only set for recent-buffer chunks.
    pub timestamp: Option<u64>,
}
410
/// A recorded user action, in the shape Sweep expects (built by `to_sweep_user_action`).
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action.
    pub action_type: ActionType,
    /// Line of the action, copied from the record.
    pub line_number: usize,
    /// Offset of the action, copied from the record.
    pub offset: usize,
    /// Display path of the active buffer's file.
    pub file_path: String,
    /// Time of the action in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
419
/// Kind of user action. Serialized as `SCREAMING_SNAKE_CASE` (e.g. `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
429
430fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
431    UserAction {
432        action_type: match record.action_type {
433            UserActionType::InsertChar => ActionType::InsertChar,
434            UserActionType::InsertSelection => ActionType::InsertSelection,
435            UserActionType::DeleteChar => ActionType::DeleteChar,
436            UserActionType::DeleteSelection => ActionType::DeleteSelection,
437            UserActionType::CursorMovement => ActionType::CursorMovement,
438        },
439        line_number: record.line_number as usize,
440        offset: record.offset,
441        file_path: file_path.to_string(),
442        timestamp: record.timestamp_epoch_ms,
443    }
444}
445
/// Response body from `SWEEP_API_URL`.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Server-assigned id, used as the `EditPredictionId` and echoed in metrics.
    pub autocomplete_id: String,
    /// Start offset into the request's file contents of the span to replace.
    pub start_index: usize,
    /// End offset (exclusive) of the span to replace.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The fields below are deserialized but not read by this file.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative completions (JSON key `completions`). Defaults to empty; unused here.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
464
/// An alternative completion returned alongside the primary one. Its fields mirror
/// `AutocompleteResponse`. Deserialized only; nothing in this file reads it.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
476
477fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
478    match event {
479        zeta_prompt::Event::BufferChange {
480            old_path,
481            path,
482            diff,
483            ..
484        } => {
485            if old_path != path {
486                // TODO confirm how to do this for sweep
487                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
488            }
489
490            if !diff.is_empty() {
491                write!(f, "File: {}:\n{}\n", path.display(), diff)?
492            }
493
494            fmt::Result::Ok(())
495        }
496    }
497}
498
499fn debug_info(cx: &gpui::App) -> Arc<str> {
500    format!(
501        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
502        version = release_channel::AppVersion::global(cx),
503        sha = release_channel::AppCommitSha::try_global(cx)
504            .map_or("unknown".to_string(), |sha| sha.full()),
505        os = client::telemetry::os_name(),
506    )
507    .into()
508}
509
/// Kind of metrics event reported to `SWEEP_METRICS_URL`. Serialized in snake_case,
/// e.g. `autocomplete_suggestion_shown`.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    AutocompleteSuggestionShown,
    AutocompleteSuggestionAccepted,
}
516
/// How a suggestion was presented. Serialized as `SCREAMING_SNAKE_CASE`, e.g.
/// `GHOST_TEXT`.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}
524
/// JSON body POSTed to `SWEEP_METRICS_URL` when a suggestion is shown or accepted.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    event_type: SweepEventType,
    suggestion_type: SweepSuggestionType,
    /// Approximate added line count (see `compute_edit_metrics`).
    additions: u32,
    /// Approximate removed line count (see `compute_edit_metrics`).
    deletions: u32,
    /// Id from the corresponding `AutocompleteResponse`.
    autocomplete_id: String,
    // The next three fields are currently always sent empty / `None` by this file.
    edit_tracking: String,
    edit_tracking_line: Option<u32>,
    lifespan: Option<u64>,
    debug_info: Arc<str>,
    /// User id string, or empty when not available.
    device_id: String,
    privacy_mode_enabled: bool,
}
539
540fn send_autocomplete_metrics_request(
541    cx: &App,
542    client: Arc<Client>,
543    api_token: Arc<str>,
544    request_body: AutocompleteMetricsRequest,
545) {
546    let http_client = client.http_client();
547    cx.background_spawn(async move {
548        let body: AsyncBody = serde_json::to_string(&request_body)?.into();
549
550        let request = http_client::Request::builder()
551            .uri(SWEEP_METRICS_URL)
552            .header("Content-Type", "application/json")
553            .header("Authorization", format!("Bearer {}", api_token))
554            .method(Method::POST)
555            .body(body)?;
556
557        let mut response = http_client.send(request).await?;
558
559        if !response.status().is_success() {
560            let mut body = String::new();
561            response.body_mut().read_to_string(&mut body).await?;
562            anyhow::bail!(
563                "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
564                response.status(),
565                body,
566            );
567        }
568
569        Ok(())
570    })
571    .detach_and_log_err(cx);
572}
573
574pub(crate) fn edit_prediction_accepted(
575    store: &EditPredictionStore,
576    current_prediction: CurrentEditPrediction,
577    cx: &App,
578) {
579    let Some(api_token) = store
580        .sweep_ai
581        .api_token
582        .read(cx)
583        .key(&SWEEP_CREDENTIALS_URL)
584    else {
585        return;
586    };
587    let debug_info = store.sweep_ai.debug_info.clone();
588
589    let prediction = current_prediction.prediction;
590
591    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
592    let autocomplete_id = prediction.id.to_string();
593
594    let device_id = store
595        .client
596        .user_id()
597        .as_ref()
598        .map(ToString::to_string)
599        .unwrap_or_default();
600
601    let suggestion_type = match current_prediction.shown_with {
602        Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
603        Some(SuggestionDisplayType::Jump) => return, // should'nt happen
604        Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
605    };
606
607    let request_body = AutocompleteMetricsRequest {
608        event_type: SweepEventType::AutocompleteSuggestionAccepted,
609        suggestion_type,
610        additions,
611        deletions,
612        autocomplete_id,
613        edit_tracking: String::new(),
614        edit_tracking_line: None,
615        lifespan: None,
616        debug_info,
617        device_id,
618        privacy_mode_enabled: false,
619    };
620
621    send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
622}
623
624pub fn edit_prediction_shown(
625    sweep_ai: &SweepAi,
626    client: Arc<Client>,
627    prediction: &EditPrediction,
628    display_type: SuggestionDisplayType,
629    cx: &App,
630) {
631    let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
632        return;
633    };
634    let debug_info = sweep_ai.debug_info.clone();
635
636    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
637    let autocomplete_id = prediction.id.to_string();
638
639    let suggestion_type = match display_type {
640        SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
641        SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
642        SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
643    };
644
645    let request_body = AutocompleteMetricsRequest {
646        event_type: SweepEventType::AutocompleteSuggestionShown,
647        suggestion_type,
648        additions,
649        deletions,
650        autocomplete_id,
651        edit_tracking: String::new(),
652        edit_tracking_line: None,
653        lifespan: None,
654        debug_info,
655        device_id: String::new(),
656        privacy_mode_enabled: false,
657    };
658
659    send_autocomplete_metrics_request(cx, client, api_token, request_body);
660}
661
662fn compute_edit_metrics(
663    edits: &[(Range<Anchor>, Arc<str>)],
664    snapshot: &BufferSnapshot,
665) -> (u32, u32) {
666    let mut additions = 0u32;
667    let mut deletions = 0u32;
668
669    for (range, new_text) in edits {
670        let old_text = snapshot.text_for_range(range.clone());
671        deletions += old_text
672            .map(|chunk| chunk.lines().count())
673            .sum::<usize>()
674            .max(1) as u32;
675        additions += new_text.lines().count().max(1) as u32;
676    }
677
678    (additions, deletions)
679}