sweep_ai.rs

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
  3    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  4    EditPredictionStore, prediction::EditPredictionResult,
  5};
  6use anyhow::{Result, bail};
  7use client::Client;
  8use edit_prediction_types::SuggestionDisplayType;
  9use futures::{AsyncReadExt as _, channel::mpsc};
 10use gpui::{
 11    App, AppContext as _, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, Method},
 13};
 14use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
 15use language_model::{ApiKeyState, EnvVar, env_var};
 16use lsp::DiagnosticSeverity;
 17use serde::{Deserialize, Serialize};
 18use std::{
 19    fmt::{self, Write as _},
 20    ops::Range,
 21    path::Path,
 22    sync::Arc,
 23    time::Instant,
 24};
 25
/// Endpoint that `SweepAi::request_prediction_with_sweep` POSTs autocomplete requests to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Endpoint that `send_autocomplete_metrics_request` POSTs shown/accepted metrics to.
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
 28
/// State needed to talk to the Sweep edit-prediction backend.
pub struct SweepAi {
    /// API key state; the key stored under `SWEEP_CREDENTIALS_URL` is sent as a bearer token.
    pub api_token: Entity<ApiKeyState>,
    /// Build/OS description string (see `debug_info`) attached to every request.
    pub debug_info: Arc<str>,
}
 33
 34impl SweepAi {
 35    pub fn new(cx: &mut App) -> Self {
 36        SweepAi {
 37            api_token: sweep_api_token(cx),
 38            debug_info: debug_info(cx),
 39        }
 40    }
 41
 42    pub fn request_prediction_with_sweep(
 43        &self,
 44        inputs: EditPredictionModelInput,
 45        cx: &mut App,
 46    ) -> Task<Result<Option<EditPredictionResult>>> {
 47        let debug_info = self.debug_info.clone();
 48        self.api_token.update(cx, |key_state, cx| {
 49            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 50        });
 51
 52        let buffer = inputs.buffer.clone();
 53        let debug_tx = inputs.debug_tx.clone();
 54
 55        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 56            return Task::ready(Ok(None));
 57        };
 58        let full_path: Arc<Path> = inputs
 59            .snapshot
 60            .file()
 61            .map(|file| file.full_path(cx))
 62            .unwrap_or_else(|| "untitled".into())
 63            .into();
 64
 65        let project_file = project::File::from_dyn(inputs.snapshot.file());
 66        let repo_name = project_file
 67            .map(|file| file.worktree.read(cx).root_name_str())
 68            .unwrap_or("untitled")
 69            .into();
 70        let offset = inputs.position.to_offset(&inputs.snapshot);
 71
 72        let recent_buffers = inputs.recent_paths.iter().cloned();
 73        let http_client = cx.http_client();
 74
 75        let recent_buffer_snapshots = recent_buffers
 76            .filter_map(|project_path| {
 77                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 78                if inputs.buffer == buffer {
 79                    None
 80                } else {
 81                    Some(buffer.read(cx).snapshot())
 82                }
 83            })
 84            .take(3)
 85            .collect::<Vec<_>>();
 86
 87        let buffer_snapshotted_at = Instant::now();
 88
 89        let result = cx.background_spawn(async move {
 90            let text = inputs.snapshot.text();
 91
 92            let mut recent_changes = String::new();
 93            for event in &inputs.events {
 94                write_event(event.as_ref(), &mut recent_changes).unwrap();
 95            }
 96
 97            let mut file_chunks = recent_buffer_snapshots
 98                .into_iter()
 99                .map(|snapshot| {
100                    let end_point = Point::new(30, 0).min(snapshot.max_point());
101                    FileChunk {
102                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
103                        file_path: snapshot
104                            .file()
105                            .map(|f| f.path().as_unix_str())
106                            .unwrap_or("untitled")
107                            .to_string(),
108                        start_line: 0,
109                        end_line: end_point.row as usize,
110                        timestamp: snapshot.file().and_then(|file| {
111                            Some(
112                                file.disk_state()
113                                    .mtime()?
114                                    .to_seconds_and_nanos_for_persistence()?
115                                    .0,
116                            )
117                        }),
118                    }
119                })
120                .collect::<Vec<_>>();
121
122            let retrieval_chunks = inputs
123                .related_files
124                .iter()
125                .flat_map(|related_file| {
126                    related_file.excerpts.iter().map(|excerpt| FileChunk {
127                        file_path: related_file.path.to_string_lossy().to_string(),
128                        start_line: excerpt.row_range.start as usize,
129                        end_line: excerpt.row_range.end as usize,
130                        content: excerpt.text.to_string(),
131                        timestamp: None,
132                    })
133                })
134                .collect();
135
136            let diagnostic_entries = inputs
137                .snapshot
138                .diagnostics_in_range(inputs.diagnostic_search_range, false);
139            let mut diagnostic_content = String::new();
140            let mut diagnostic_count = 0;
141
142            for entry in diagnostic_entries {
143                let start_point: Point = entry.range.start;
144
145                let severity = match entry.diagnostic.severity {
146                    DiagnosticSeverity::ERROR => "error",
147                    DiagnosticSeverity::WARNING => "warning",
148                    DiagnosticSeverity::INFORMATION => "info",
149                    DiagnosticSeverity::HINT => "hint",
150                    _ => continue,
151                };
152
153                diagnostic_count += 1;
154
155                writeln!(
156                    &mut diagnostic_content,
157                    "{} at line {}: {}",
158                    severity,
159                    start_point.row + 1,
160                    entry.diagnostic.message
161                )?;
162            }
163
164            if !diagnostic_content.is_empty() {
165                file_chunks.push(FileChunk {
166                    file_path: format!("Diagnostics for {}", full_path.display()),
167                    start_line: 0,
168                    end_line: diagnostic_count,
169                    content: diagnostic_content,
170                    timestamp: None,
171                });
172            }
173
174            let request_body = AutocompleteRequest {
175                debug_info,
176                repo_name,
177                file_path: full_path.clone(),
178                file_contents: text.clone(),
179                original_file_contents: text,
180                cursor_position: offset,
181                recent_changes: recent_changes.clone(),
182                changes_above_cursor: true,
183                multiple_suggestions: false,
184                branch: None,
185                file_chunks,
186                retrieval_chunks,
187                recent_user_actions: vec![],
188                use_bytes: true,
189                // TODO
190                privacy_mode_enabled: false,
191            };
192
193            let mut buf: Vec<u8> = Vec::new();
194            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
195            serde_json::to_writer(writer, &request_body)?;
196            let body: AsyncBody = buf.into();
197
198            let ep_inputs = zeta_prompt::ZetaPromptInput {
199                events: inputs.events,
200                related_files: inputs.related_files.clone(),
201                cursor_path: full_path.clone(),
202                cursor_excerpt: request_body.file_contents.clone().into(),
203                // we actually don't know
204                editable_range_in_excerpt: 0..inputs.snapshot.len(),
205                cursor_offset_in_excerpt: request_body.cursor_position,
206            };
207
208            send_started_event(
209                &debug_tx,
210                &buffer,
211                inputs.position,
212                serde_json::to_string(&request_body).unwrap_or_default(),
213            );
214
215            let request = http_client::Request::builder()
216                .uri(SWEEP_API_URL)
217                .header("Content-Type", "application/json")
218                .header("Authorization", format!("Bearer {}", api_token))
219                .header("Connection", "keep-alive")
220                .header("Content-Encoding", "br")
221                .method(Method::POST)
222                .body(body)?;
223
224            let mut response = http_client.send(request).await?;
225
226            let mut body = String::new();
227            response.body_mut().read_to_string(&mut body).await?;
228
229            let response_received_at = Instant::now();
230            if !response.status().is_success() {
231                let message = format!(
232                    "Request failed with status: {:?}\nBody: {}",
233                    response.status(),
234                    body,
235                );
236                send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
237                bail!(message);
238            };
239
240            let response: AutocompleteResponse = serde_json::from_str(&body)?;
241
242            send_finished_event(&debug_tx, &buffer, inputs.position, body);
243
244            let old_text = inputs
245                .snapshot
246                .text_for_range(response.start_index..response.end_index)
247                .collect::<String>();
248            let edits = language::text_diff(&old_text, &response.completion)
249                .into_iter()
250                .map(|(range, text)| {
251                    (
252                        inputs
253                            .snapshot
254                            .anchor_after(response.start_index + range.start)
255                            ..inputs
256                                .snapshot
257                                .anchor_before(response.start_index + range.end),
258                        text,
259                    )
260                })
261                .collect::<Vec<_>>();
262
263            anyhow::Ok((
264                response.autocomplete_id,
265                edits,
266                inputs.snapshot,
267                response_received_at,
268                ep_inputs,
269            ))
270        });
271
272        let buffer = inputs.buffer.clone();
273
274        cx.spawn(async move |cx| {
275            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
276            anyhow::Ok(Some(
277                EditPredictionResult::new(
278                    EditPredictionId(id.into()),
279                    &buffer,
280                    &old_snapshot,
281                    edits.into(),
282                    buffer_snapshotted_at,
283                    response_received_at,
284                    inputs,
285                    cx,
286                )
287                .await,
288            ))
289        })
290    }
291}
292
293fn send_started_event(
294    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
295    buffer: &Entity<Buffer>,
296    position: Anchor,
297    prompt: String,
298) {
299    if let Some(debug_tx) = debug_tx {
300        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
301            EditPredictionStartedDebugEvent {
302                buffer: buffer.downgrade(),
303                position,
304                prompt: Some(prompt),
305            },
306        ));
307    }
308}
309
310fn send_finished_event(
311    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
312    buffer: &Entity<Buffer>,
313    position: Anchor,
314    model_output: String,
315) {
316    if let Some(debug_tx) = debug_tx {
317        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
318            EditPredictionFinishedDebugEvent {
319                buffer: buffer.downgrade(),
320                position,
321                model_output: Some(model_output),
322            },
323        ));
324    }
325}
326
/// URL under which the Sweep API key is stored in and looked up from `ApiKeyState`.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username associated with the Sweep credentials (not referenced elsewhere in this file).
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// Environment variable handle for `SWEEP_AI_TOKEN`, passed to `ApiKeyState::new`.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
331
/// App-global wrapper holding the single shared Sweep `ApiKeyState` entity.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

// Marker impl so the wrapper can be stored via `set_global` / read via `try_global`.
impl Global for GlobalSweepApiKey {}
335
336pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
337    if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
338        return global.0.clone();
339    }
340    let entity =
341        cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
342    cx.set_global(GlobalSweepApiKey(entity.clone()));
343    entity
344}
345
346pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
347    sweep_api_token(cx).update(cx, |key_state, cx| {
348        key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
349    })
350}
351
/// JSON payload sent to `SWEEP_API_URL` to request an edit prediction.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Build/OS description produced by `debug_info`.
    pub debug_info: Arc<str>,
    /// Worktree root name of the edited file, or "untitled" when unavailable.
    pub repo_name: String,
    /// Always `None` in this file.
    pub branch: Option<String>,
    /// Full path of the edited file, or "untitled".
    pub file_path: Arc<Path>,
    /// Full text of the buffer snapshot.
    pub file_contents: String,
    /// Recent buffer-change events rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset within `file_contents`.
    pub cursor_position: usize,
    /// Set to the same text as `file_contents` in this file.
    pub original_file_contents: String,
    /// Leading rows of other recent buffers, plus an optional diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts taken from the related files of the prediction input.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Always empty in this file.
    pub recent_user_actions: Vec<UserAction>,
    /// Always `false` in this file.
    pub multiple_suggestions: bool,
    /// Always `false` in this file (marked TODO at the construction site).
    pub privacy_mode_enabled: bool,
    /// Always `true` in this file.
    pub changes_above_cursor: bool,
    /// Always `true` in this file; presumably tells the server that offsets are byte
    /// offsets — TODO confirm against the Sweep API.
    pub use_bytes: bool,
}
370
/// A piece of file content sent to Sweep as additional context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path (or descriptive label, e.g. for the diagnostics chunk) of the content.
    pub file_path: String,
    /// First row covered by `content`.
    pub start_line: usize,
    /// Last row covered by `content`; whether the server treats it as inclusive is
    /// not visible here — TODO confirm.
    pub end_line: usize,
    /// The chunk text itself.
    pub content: String,
    /// Seconds component of the file's on-disk mtime, when known.
    pub timestamp: Option<u64>,
}
379
/// A single user action reported to Sweep. No instances are constructed in this
/// file; `recent_user_actions` is always sent empty.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action performed.
    pub action_type: ActionType,
    /// Line the action happened on (indexing base not visible here — TODO confirm).
    pub line_number: usize,
    /// Offset the action happened at.
    pub offset: usize,
    /// File the action happened in.
    pub file_path: String,
    /// When the action happened (unit not visible here — TODO confirm).
    pub timestamp: u64,
}
388
/// Kinds of `UserAction`, serialized in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
// No variant is constructed in this file, hence the dead-code allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
399
/// Response body returned by `SWEEP_API_URL`.
///
/// Only `autocomplete_id`, `start_index`, `end_index` and `completion` are read in
/// this file; the remaining fields are deserialized but unused.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of the prediction; reused as the `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start offset of the text range that `completion` replaces.
    pub start_index: usize,
    /// End offset of the text range that `completion` replaces.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Deserialized from the JSON key `completions`; defaults to empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
418
/// One entry of the `completions` array in an `AutocompleteResponse`.
// Deserialized but never read in this file, hence the dead-code allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
430
431fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
432    match event {
433        zeta_prompt::Event::BufferChange {
434            old_path,
435            path,
436            diff,
437            ..
438        } => {
439            if old_path != path {
440                // TODO confirm how to do this for sweep
441                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
442            }
443
444            if !diff.is_empty() {
445                write!(f, "File: {}:\n{}\n", path.display(), diff)?
446            }
447
448            fmt::Result::Ok(())
449        }
450    }
451}
452
453fn debug_info(cx: &gpui::App) -> Arc<str> {
454    format!(
455        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
456        version = release_channel::AppVersion::global(cx),
457        sha = release_channel::AppCommitSha::try_global(cx)
458            .map_or("unknown".to_string(), |sha| sha.full()),
459        os = client::telemetry::os_name(),
460    )
461    .into()
462}
463
/// Metric event kinds reported to `SWEEP_METRICS_URL`, serialized in snake_case
/// (e.g. `autocomplete_suggestion_shown`).
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SweepEventType {
    /// Sent by `edit_prediction_shown`.
    AutocompleteSuggestionShown,
    /// Sent by `edit_prediction_accepted`.
    AutocompleteSuggestionAccepted,
}
470
/// How a suggestion was presented, serialized in SCREAMING_SNAKE_CASE
/// (e.g. `GHOST_TEXT`). Mapped from `SuggestionDisplayType` by the metric senders.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SweepSuggestionType {
    GhostText,
    Popup,
    JumpToEdit,
}
478
/// JSON payload sent to `SWEEP_METRICS_URL` when a suggestion is shown or accepted.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteMetricsRequest {
    /// Whether the suggestion was shown or accepted.
    event_type: SweepEventType,
    /// How the suggestion was displayed.
    suggestion_type: SweepSuggestionType,
    /// Added-line count from `compute_edit_metrics`.
    additions: u32,
    /// Deleted-line count from `compute_edit_metrics`.
    deletions: u32,
    /// Identifier of the prediction this metric refers to.
    autocomplete_id: String,
    /// Always empty in this file.
    edit_tracking: String,
    /// Always `None` in this file.
    edit_tracking_line: Option<u32>,
    /// Always `None` in this file.
    lifespan: Option<u64>,
    /// Build/OS description produced by `debug_info`.
    debug_info: Arc<str>,
    /// The client's user id as a string for accepted events; empty for shown events.
    device_id: String,
    /// Always `false` in this file.
    privacy_mode_enabled: bool,
}
493
494fn send_autocomplete_metrics_request(
495    cx: &App,
496    client: Arc<Client>,
497    api_token: Arc<str>,
498    request_body: AutocompleteMetricsRequest,
499) {
500    let http_client = client.http_client();
501    cx.background_spawn(async move {
502        let body: AsyncBody = serde_json::to_string(&request_body)?.into();
503
504        let request = http_client::Request::builder()
505            .uri(SWEEP_METRICS_URL)
506            .header("Content-Type", "application/json")
507            .header("Authorization", format!("Bearer {}", api_token))
508            .method(Method::POST)
509            .body(body)?;
510
511        let mut response = http_client.send(request).await?;
512
513        if !response.status().is_success() {
514            let mut body = String::new();
515            response.body_mut().read_to_string(&mut body).await?;
516            anyhow::bail!(
517                "Failed to send autocomplete metrics for sweep_ai: {:?}\nBody: {}",
518                response.status(),
519                body,
520            );
521        }
522
523        Ok(())
524    })
525    .detach_and_log_err(cx);
526}
527
528pub(crate) fn edit_prediction_accepted(
529    store: &EditPredictionStore,
530    current_prediction: CurrentEditPrediction,
531    cx: &App,
532) {
533    let Some(api_token) = store
534        .sweep_ai
535        .api_token
536        .read(cx)
537        .key(&SWEEP_CREDENTIALS_URL)
538    else {
539        return;
540    };
541    let debug_info = store.sweep_ai.debug_info.clone();
542
543    let prediction = current_prediction.prediction;
544
545    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
546    let autocomplete_id = prediction.id.to_string();
547
548    let device_id = store
549        .client
550        .user_id()
551        .as_ref()
552        .map(ToString::to_string)
553        .unwrap_or_default();
554
555    let suggestion_type = match current_prediction.shown_with {
556        Some(SuggestionDisplayType::DiffPopover) => SweepSuggestionType::Popup,
557        Some(SuggestionDisplayType::Jump) => return, // should'nt happen
558        Some(SuggestionDisplayType::GhostText) | None => SweepSuggestionType::GhostText,
559    };
560
561    let request_body = AutocompleteMetricsRequest {
562        event_type: SweepEventType::AutocompleteSuggestionAccepted,
563        suggestion_type,
564        additions,
565        deletions,
566        autocomplete_id,
567        edit_tracking: String::new(),
568        edit_tracking_line: None,
569        lifespan: None,
570        debug_info,
571        device_id,
572        privacy_mode_enabled: false,
573    };
574
575    send_autocomplete_metrics_request(cx, store.client.clone(), api_token, request_body);
576}
577
578pub fn edit_prediction_shown(
579    sweep_ai: &SweepAi,
580    client: Arc<Client>,
581    prediction: &EditPrediction,
582    display_type: SuggestionDisplayType,
583    cx: &App,
584) {
585    let Some(api_token) = sweep_ai.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
586        return;
587    };
588    let debug_info = sweep_ai.debug_info.clone();
589
590    let (additions, deletions) = compute_edit_metrics(&prediction.edits, &prediction.snapshot);
591    let autocomplete_id = prediction.id.to_string();
592
593    let suggestion_type = match display_type {
594        SuggestionDisplayType::GhostText => SweepSuggestionType::GhostText,
595        SuggestionDisplayType::DiffPopover => SweepSuggestionType::Popup,
596        SuggestionDisplayType::Jump => SweepSuggestionType::JumpToEdit,
597    };
598
599    let request_body = AutocompleteMetricsRequest {
600        event_type: SweepEventType::AutocompleteSuggestionShown,
601        suggestion_type,
602        additions,
603        deletions,
604        autocomplete_id,
605        edit_tracking: String::new(),
606        edit_tracking_line: None,
607        lifespan: None,
608        debug_info,
609        device_id: String::new(),
610        privacy_mode_enabled: false,
611    };
612
613    send_autocomplete_metrics_request(cx, client, api_token, request_body);
614}
615
616fn compute_edit_metrics(
617    edits: &[(Range<Anchor>, Arc<str>)],
618    snapshot: &BufferSnapshot,
619) -> (u32, u32) {
620    let mut additions = 0u32;
621    let mut deletions = 0u32;
622
623    for (range, new_text) in edits {
624        let old_text = snapshot.text_for_range(range.clone());
625        deletions += old_text
626            .map(|chunk| chunk.lines().count())
627            .sum::<usize>()
628            .max(1) as u32;
629        additions += new_text.lines().count().max(1) as u32;
630    }
631
632    (additions, deletions)
633}