sweep_ai.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
  3    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  4    EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
  5};
  6use anyhow::{Result, bail};
  7use client::Client;
  8use edit_prediction_types::SuggestionDisplayType;
  9use futures::{AsyncReadExt as _, channel::mpsc};
 10use gpui::{
 11    App, AppContext as _, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, Method},
 13};
 14use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
 15use language_model::{ApiKeyState, EnvVar, env_var};
 16use lsp::DiagnosticSeverity;
 17use serde::{Deserialize, Serialize};
 18use std::{
 19    fmt::{self, Write as _},
 20    ops::Range,
 21    path::Path,
 22    sync::Arc,
 23    time::Instant,
 24};
 25
/// Endpoint that `SweepAi::request_prediction_with_sweep` POSTs completion requests to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that `send_autocomplete_metrics_request` POSTs shown/accepted metrics to.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
 28
/// State needed to issue requests to the Sweep edit-prediction backend.
pub struct SweepAi {
    /// Entity holding the Sweep API key; the key is looked up with `SWEEP_CREDENTIALS_URL`.
    pub api_token: Entity<ApiKeyState>,
    /// Client description string built by `debug_info`, included in every request body.
    pub debug_info: Arc<str>,
}
 33
 34impl SweepAi {
 35    pub fn new(cx: &mut App) -> Self {
 36        SweepAi {
 37            api_token: sweep_api_token(cx),
 38            debug_info: debug_info(cx),
 39        }
 40    }
 41
 42    pub fn request_prediction_with_sweep(
 43        &self,
 44        inputs: EditPredictionModelInput,
 45        cx: &mut App,
 46    ) -> Task<Result<Option<EditPredictionResult>>> {
 47        let debug_info = self.debug_info.clone();
 48        self.api_token.update(cx, |key_state, cx| {
 49            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 50        });
 51
 52        let buffer = inputs.buffer.clone();
 53        let debug_tx = inputs.debug_tx.clone();
 54
 55        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 56            return Task::ready(Ok(None));
 57        };
 58        let full_path: Arc<Path> = inputs
 59            .snapshot
 60            .file()
 61            .map(|file| file.full_path(cx))
 62            .unwrap_or_else(|| "untitled".into())
 63            .into();
 64
 65        let project_file = project::File::from_dyn(inputs.snapshot.file());
 66        let repo_name = project_file
 67            .map(|file| file.worktree.read(cx).root_name_str())
 68            .unwrap_or("untitled")
 69            .into();
 70        let offset = inputs.position.to_offset(&inputs.snapshot);
 71        let buffer_entity_id = inputs.buffer.entity_id();
 72
 73        let recent_buffers = inputs.recent_paths.iter().cloned();
 74        let http_client = cx.http_client();
 75
 76        let recent_buffer_snapshots = recent_buffers
 77            .filter_map(|project_path| {
 78                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 79                if inputs.buffer == buffer {
 80                    None
 81                } else {
 82                    Some(buffer.read(cx).snapshot())
 83                }
 84            })
 85            .take(3)
 86            .collect::<Vec<_>>();
 87
 88        let buffer_snapshotted_at = Instant::now();
 89
 90        let result = cx.background_spawn(async move {
 91            let text = inputs.snapshot.text();
 92
 93            let mut recent_changes = String::new();
 94            for event in &inputs.events {
 95                write_event(event.as_ref(), &mut recent_changes).unwrap();
 96            }
 97
 98            let file_chunks = recent_buffer_snapshots
 99                .into_iter()
100                .map(|snapshot| {
101                    let end_point = Point::new(30, 0).min(snapshot.max_point());
102                    FileChunk {
103                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
104                        file_path: snapshot
105                            .file()
106                            .map(|f| f.path().as_unix_str())
107                            .unwrap_or("untitled")
108                            .to_string(),
109                        start_line: 0,
110                        end_line: end_point.row as usize,
111                        timestamp: snapshot.file().and_then(|file| {
112                            Some(
113                                file.disk_state()
114                                    .mtime()?
115                                    .to_seconds_and_nanos_for_persistence()?
116                                    .0,
117                            )
118                        }),
119                    }
120                })
121                .collect::<Vec<_>>();
122
123            let mut retrieval_chunks: Vec<FileChunk> = inputs
124                .related_files
125                .iter()
126                .flat_map(|related_file| {
127                    related_file.excerpts.iter().map(|excerpt| FileChunk {
128                        file_path: related_file.path.to_string_lossy().to_string(),
129                        start_line: excerpt.row_range.start as usize,
130                        end_line: excerpt.row_range.end as usize,
131                        content: excerpt.text.to_string(),
132                        timestamp: None,
133                    })
134                })
135                .collect();
136
137            let diagnostic_entries = inputs
138                .snapshot
139                .diagnostics_in_range(inputs.diagnostic_search_range, false);
140            let mut diagnostic_content = String::new();
141            let mut diagnostic_count = 0;
142
143            for entry in diagnostic_entries {
144                let start_point: Point = entry.range.start;
145
146                let severity = match entry.diagnostic.severity {
147                    DiagnosticSeverity::ERROR => "error",
148                    DiagnosticSeverity::WARNING => "warning",
149                    DiagnosticSeverity::INFORMATION => "info",
150                    DiagnosticSeverity::HINT => "hint",
151                    _ => continue,
152                };
153
154                diagnostic_count += 1;
155
156                writeln!(
157                    &mut diagnostic_content,
158                    "{}:{}:{}: {}: {}",
159                    full_path.display(),
160                    start_point.row + 1,
161                    start_point.column + 1,
162                    severity,
163                    entry.diagnostic.message
164                )?;
165            }
166
167            if !diagnostic_content.is_empty() {
168                retrieval_chunks.push(FileChunk {
169                    file_path: "diagnostics".to_string(),
170                    start_line: 1,
171                    end_line: diagnostic_count,
172                    content: diagnostic_content,
173                    timestamp: None,
174                });
175            }
176
177            let file_path_str = full_path.display().to_string();
178            let recent_user_actions = inputs
179                .user_actions
180                .iter()
181                .filter(|r| r.buffer_id == buffer_entity_id)
182                .map(|r| to_sweep_user_action(r, &file_path_str))
183                .collect();
184
185            let request_body = AutocompleteRequest {
186                debug_info,
187                repo_name,
188                file_path: full_path.clone(),
189                file_contents: text.clone(),
190                original_file_contents: text,
191                cursor_position: offset,
192                recent_changes: recent_changes.clone(),
193                changes_above_cursor: true,
194                multiple_suggestions: false,
195                branch: None,
196                file_chunks,
197                retrieval_chunks,
198                recent_user_actions,
199                use_bytes: true,
200                // TODO
201                privacy_mode_enabled: false,
202            };
203
204            let mut buf: Vec<u8> = Vec::new();
205            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 1, 22);
206            serde_json::to_writer(writer, &request_body)?;
207            let body: AsyncBody = buf.into();
208
209            let ep_inputs = zeta_prompt::ZetaPromptInput {
210                events: inputs.events,
211                related_files: inputs.related_files.clone(),
212                cursor_path: full_path.clone(),
213                cursor_excerpt: request_body.file_contents.clone().into(),
214                // we actually don't know
215                editable_range_in_excerpt: 0..inputs.snapshot.len(),
216                cursor_offset_in_excerpt: request_body.cursor_position,
217            };
218
219            send_started_event(
220                &debug_tx,
221                &buffer,
222                inputs.position,
223                serde_json::to_string(&request_body).unwrap_or_default(),
224            );
225
226            let request = http_client::Request::builder()
227                .uri(SWEEP_API_URL)
228                .header("Content-Type", "application/json")
229                .header("Authorization", format!("Bearer {}", api_token))
230                .header("Connection", "keep-alive")
231                .header("Content-Encoding", "br")
232                .method(Method::POST)
233                .body(body)?;
234
235            let mut response = http_client.send(request).await?;
236
237            let mut body = String::new();
238            response.body_mut().read_to_string(&mut body).await?;
239
240            let response_received_at = Instant::now();
241            if !response.status().is_success() {
242                let message = format!(
243                    "Request failed with status: {:?}\nBody: {}",
244                    response.status(),
245                    body,
246                );
247                send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
248                bail!(message);
249            };
250
251            let response: AutocompleteResponse = serde_json::from_str(&body)?;
252
253            send_finished_event(&debug_tx, &buffer, inputs.position, body);
254
255            let old_text = inputs
256                .snapshot
257                .text_for_range(response.start_index..response.end_index)
258                .collect::<String>();
259            let edits = language::text_diff(&old_text, &response.completion)
260                .into_iter()
261                .map(|(range, text)| {
262                    (
263                        inputs
264                            .snapshot
265                            .anchor_after(response.start_index + range.start)
266                            ..inputs
267                                .snapshot
268                                .anchor_before(response.start_index + range.end),
269                        text,
270                    )
271                })
272                .collect::<Vec<_>>();
273
274            anyhow::Ok((
275                response.autocomplete_id,
276                edits,
277                inputs.snapshot,
278                response_received_at,
279                ep_inputs,
280            ))
281        });
282
283        let buffer = inputs.buffer.clone();
284
285        cx.spawn(async move |cx| {
286            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
287            anyhow::Ok(Some(
288                EditPredictionResult::new(
289                    EditPredictionId(id.into()),
290                    &buffer,
291                    &old_snapshot,
292                    edits.into(),
293                    buffer_snapshotted_at,
294                    response_received_at,
295                    inputs,
296                    cx,
297                )
298                .await,
299            ))
300        })
301    }
302}
303
304fn send_started_event(
305    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
306    buffer: &Entity<Buffer>,
307    position: Anchor,
308    prompt: String,
309) {
310    if let Some(debug_tx) = debug_tx {
311        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
312            EditPredictionStartedDebugEvent {
313                buffer: buffer.downgrade(),
314                position,
315                prompt: Some(prompt),
316            },
317        ));
318    }
319}
320
321fn send_finished_event(
322    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
323    buffer: &Entity<Buffer>,
324    position: Anchor,
325    model_output: String,
326) {
327    if let Some(debug_tx) = debug_tx {
328        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
329            EditPredictionFinishedDebugEvent {
330                buffer: buffer.downgrade(),
331                position,
332                model_output: Some(model_output),
333            },
334        ));
335    }
336}
337
/// URL under which the Sweep API key is stored and looked up in `ApiKeyState`.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username associated with the stored Sweep credential (not referenced elsewhere in this file).
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// Handle for the `SWEEP_AI_TOKEN` environment variable, passed to `ApiKeyState::new`.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
342
/// App-global wrapper around the Sweep `ApiKeyState` entity, so that `sweep_api_token`
/// hands out the same entity on every call.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
346
347pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
348    if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
349        return global.0.clone();
350    }
351    let entity =
352        cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
353    cx.set_global(GlobalSweepApiKey(entity.clone()));
354    entity
355}
356
357pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
358    sweep_api_token(cx).update(cx, |key_state, cx| {
359        key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
360    })
361}
362
/// JSON body sent to `SWEEP_API_URL` when requesting an edit prediction.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description string produced by `debug_info`.
    pub debug_info: Arc<str>,
    /// Worktree root name of the file, or "untitled" when the file is not a project file.
    pub repo_name: String,
    /// Always `None` in this file.
    pub branch: Option<String>,
    /// Full path of the buffer's file, or "untitled" when the buffer has no file.
    pub file_path: Arc<Path>,
    /// Full text of the buffer snapshot.
    pub file_contents: String,
    /// Recent edit events rendered by `write_event`.
    pub recent_changes: String,
    /// Offset of the cursor within the buffer snapshot.
    pub cursor_position: usize,
    /// Set to the same snapshot text as `file_contents`.
    pub original_file_contents: String,
    /// Leading rows of other recently used open buffers.
    pub file_chunks: Vec<FileChunk>,
    /// Related-file excerpts, plus one "diagnostics" chunk when diagnostics were found.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Recorded user actions that belong to the requesting buffer.
    pub recent_user_actions: Vec<UserAction>,
    /// Always `false` in this file.
    pub multiple_suggestions: bool,
    /// Always `false` in this file (marked TODO at the construction site).
    pub privacy_mode_enabled: bool,
    /// Always `true` in this file.
    pub changes_above_cursor: bool,
    /// Always `true` in this file; presumably tells the server that offsets are byte
    /// offsets. TODO confirm against the Sweep API.
    pub use_bytes: bool,
}
381
/// A slice of file content sent to Sweep as additional context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the file the chunk comes from; "untitled" or "diagnostics" for synthetic chunks.
    pub file_path: String,
    /// First line covered by the chunk.
    pub start_line: usize,
    /// Last line (or line count, for the diagnostics chunk) covered by the chunk.
    pub end_line: usize,
    /// Text of the chunk.
    pub content: String,
    /// File modification time for recent-buffer chunks when available; `None` otherwise.
    pub timestamp: Option<u64>,
}
390
/// A recorded user action in the wire format sent to Sweep; built by `to_sweep_user_action`.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action performed.
    pub action_type: ActionType,
    /// Line number copied from the source record.
    pub line_number: usize,
    /// Offset copied from the source record.
    pub offset: usize,
    /// Path of the file the action happened in.
    pub file_path: String,
    /// Copied from the record's `timestamp_epoch_ms`.
    pub timestamp: u64,
}
399
/// Kind of user action reported to Sweep.
///
/// Serialized in SCREAMING_SNAKE_CASE (for example `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
409
410fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
411    UserAction {
412        action_type: match record.action_type {
413            UserActionType::InsertChar => ActionType::InsertChar,
414            UserActionType::InsertSelection => ActionType::InsertSelection,
415            UserActionType::DeleteChar => ActionType::DeleteChar,
416            UserActionType::DeleteSelection => ActionType::DeleteSelection,
417            UserActionType::CursorMovement => ActionType::CursorMovement,
418        },
419        line_number: record.line_number as usize,
420        offset: record.offset,
421        file_path: file_path.to_string(),
422        timestamp: record.timestamp_epoch_ms,
423    }
424}
425
/// JSON body returned by `SWEEP_API_URL`.
///
/// Only `autocomplete_id`, `start_index`, `end_index` and `completion` are read in this file;
/// the remaining fields are deserialized but unused, hence the `dead_code` allowances.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of the suggestion; becomes the `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start offset of the range of the sent text that `completion` replaces.
    pub start_index: usize,
    /// End offset of the range of the sent text that `completion` replaces.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Read from the JSON key `completions`; empty when the key is absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
444
/// One entry of the response's `completions` array.
///
/// Deserialized for completeness; nothing in this file reads these fields.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
456
457fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
458    match event {
459        zeta_prompt::Event::BufferChange {
460            old_path,
461            path,
462            diff,
463            ..
464        } => {
465            if old_path != path {
466                // TODO confirm how to do this for sweep
467                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
468            }
469
470            if !diff.is_empty() {
471                write!(f, "File: {}:\n{}\n", path.display(), diff)?
472            }
473
474            fmt::Result::Ok(())
475        }
476    }
477}
478
479fn debug_info(cx: &gpui::App) -> Arc<str> {
480    format!(
481        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
482        version = release_channel::AppVersion::global(cx),
483        sha = release_channel::AppCommitSha::try_global(cx)
484            .map_or("unknown".to_string(), |sha| sha.full()),
485        os = client::telemetry::os_name(),
486    )
487    .into()
488}
489
/// Kind of metrics event reported to `SWEEP_METRICS_URL`.
///
/// Serialized in snake_case (for example `autocomplete_suggestion_shown`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    /// Reported by `edit_prediction_shown`.
    AutocompleteSuggestionShown,
    /// Reported by `edit_prediction_accepted`.
    AutocompleteSuggestionAccepted,
}
496
/// How a suggestion was presented, as reported in metrics.
///
/// Serialized in SCREAMING_SNAKE_CASE (for example `GHOST_TEXT`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    /// Maps from `SuggestionDisplayType::GhostText`.
    GhostText,
    /// Maps from `SuggestionDisplayType::DiffPopover`.
    Popup,
    /// Maps from `SuggestionDisplayType::Jump`.
    JumpToEdit,
}
504
/// JSON body sent to `SWEEP_METRICS_URL` when a suggestion is shown or accepted.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    /// Whether the suggestion was shown or accepted.
    event_type: SweepEventType,
    /// How the suggestion was presented.
    suggestion_type: SweepSuggestionType,
    /// Added-line count from `compute_edit_metrics`.
    additions: u32,
    /// Deleted-line count from `compute_edit_metrics`.
    deletions: u32,
    /// Identifier of the suggestion the event refers to.
    autocomplete_id: String,
    /// Always empty in this file.
    edit_tracking: String,
    /// Always `None` in this file.
    edit_tracking_line: Option<u32>,
    /// Always `None` in this file.
    lifespan: Option<u64>,
    /// Client description string produced by `debug_info`.
    debug_info: Arc<str>,
    /// The client's user id for accepted events (empty if none); empty for shown events.
    device_id: String,
    /// Always `false` in this file.
    privacy_mode_enabled: bool,
}
519
520fn send_autocomplete_metrics_request(
521    cx: &App,
522    client: Arc<Client>,
523    api_token: Arc<str>,
524    request_body: AutocompleteMetricsRequest,
525) {
526    let http_client = client.http_client();
527    cx.background_spawn(async move {
528        let body: AsyncBody = serde_json::to_string(&request_body)?.into();
529
530        let request = http_client::Request::builder()
531            .uri(SWEEP_METRICS_URL)
532            .header("Content-Type", "application/json")
533            .header("Authorization", format!("Bearer {}", api_token))
534            .method(Method::POST)
535            .body(body)?;
536
537        let mut response = http_client.send(request).await?;
538
539        if !response.status().is_success() {
540            let mut body = String::new();
541            response.body_mut().read_to_string(&mut body).await?;
542            anyhow::bail!(
543                "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
544                response.status(),
545                body,
546            );
547        }
548
549        Ok(())
550    })
551    .detach_and_log_err(cx);
552}
553
554pub(crate) fn edit_prediction_accepted(
555    store: &EditPredictionStore,
556    current_prediction: CurrentEditPrediction,
557    cx: &App,
558) {
559    let Some(api_token) = store
560        .sweep_ai
561        .api_token
562        .read(cx)
563        .key(&SWEEP_CREDENTIALS_URL)
564    else {
565        return;
566    };
567    let debug_info = store.sweep_ai.debug_info.clone();
568
569    let prediction = current_prediction.prediction;
570
571    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
572    let autocomplete_id = prediction.id.to_string();
573
574    let device_id = store
575        .client
576        .user_id()
577        .as_ref()
578        .map(ToString::to_string)
579        .unwrap_or_default();
580
581    let suggestion_type = match current_prediction.shown_with {
582        Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
583        Some(SuggestionDisplayType::Jump) => return, // should'nt happen
584        Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
585    };
586
587    let request_body = AutocompleteMetricsRequest {
588        event_type: SweepEventType::AutocompleteSuggestionAccepted,
589        suggestion_type,
590        additions,
591        deletions,
592        autocomplete_id,
593        edit_tracking: String::new(),
594        edit_tracking_line: None,
595        lifespan: None,
596        debug_info,
597        device_id,
598        privacy_mode_enabled: false,
599    };
600
601    send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
602}
603
604pub fn edit_prediction_shown(
605    sweep_ai: &SweepAi,
606    client: Arc<Client>,
607    prediction: &EditPrediction,
608    display_type: SuggestionDisplayType,
609    cx: &App,
610) {
611    let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
612        return;
613    };
614    let debug_info = sweep_ai.debug_info.clone();
615
616    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
617    let autocomplete_id = prediction.id.to_string();
618
619    let suggestion_type = match display_type {
620        SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
621        SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
622        SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
623    };
624
625    let request_body = AutocompleteMetricsRequest {
626        event_type: SweepEventType::AutocompleteSuggestionShown,
627        suggestion_type,
628        additions,
629        deletions,
630        autocomplete_id,
631        edit_tracking: String::new(),
632        edit_tracking_line: None,
633        lifespan: None,
634        debug_info,
635        device_id: String::new(),
636        privacy_mode_enabled: false,
637    };
638
639    send_autocomplete_metrics_request(cx, client, api_token, request_body);
640}
641
642fn compute_edit_metrics(
643    edits: &[(Range<Anchor>, Arc<str>)],
644    snapshot: &BufferSnapshot,
645) -> (u32, u32) {
646    let mut additions = 0u32;
647    let mut deletions = 0u32;
648
649    for (range, new_text) in edits {
650        let old_text = snapshot.text_for_range(range.clone());
651        deletions += old_text
652            .map(|chunk| chunk.lines().count())
653            .sum::<usize>()
654            .max(1) as u32;
655        additions += new_text.lines().count().max(1) as u32;
656    }
657
658    (additions, deletions)
659}