sweep_ai.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::Event;
  3use credentials_provider::CredentialsProvider;
  4use futures::{AsyncReadExt as _, FutureExt, future::Shared};
  5use gpui::{
  6    App, AppContext as _, Entity, Task,
  7    http_client::{self, AsyncBody, Method},
  8};
  9use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
 10use lsp::DiagnosticSeverity;
 11use project::{Project, ProjectPath};
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    collections::VecDeque,
 15    fmt::{self, Write as _},
 16    ops::Range,
 17    path::Path,
 18    sync::Arc,
 19    time::Instant,
 20};
 21
 22use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
 23
 24const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 25
 26pub struct SweepAi {
 27    pub api_token: Shared<Task<Option<String>>>,
 28    pub debug_info: Arc<str>,
 29}
 30
 31impl SweepAi {
 32    pub fn new(cx: &App) -> Self {
 33        SweepAi {
 34            api_token: load_api_token(cx).shared(),
 35            debug_info: debug_info(cx),
 36        }
 37    }
 38
 39    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
 40        self.api_token = Task::ready(api_token.clone()).shared();
 41        store_api_token_in_keychain(api_token, cx)
 42    }
 43
 44    pub fn request_prediction_with_sweep(
 45        &self,
 46        project: &Entity<Project>,
 47        active_buffer: &Entity<Buffer>,
 48        snapshot: BufferSnapshot,
 49        position: language::Anchor,
 50        events: Vec<Arc<Event>>,
 51        recent_paths: &VecDeque<ProjectPath>,
 52        diagnostic_search_range: Range<Point>,
 53        cx: &mut App,
 54    ) -> Task<Result<Option<EditPredictionResult>>> {
 55        let debug_info = self.debug_info.clone();
 56        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
 57            return Task::ready(Ok(None));
 58        };
 59        let full_path: Arc<Path> = snapshot
 60            .file()
 61            .map(|file| file.full_path(cx))
 62            .unwrap_or_else(|| "untitled".into())
 63            .into();
 64
 65        let project_file = project::File::from_dyn(snapshot.file());
 66        let repo_name = project_file
 67            .map(|file| file.worktree.read(cx).root_name_str())
 68            .unwrap_or("untitled")
 69            .into();
 70        let offset = position.to_offset(&snapshot);
 71
 72        let recent_buffers = recent_paths.iter().cloned();
 73        let http_client = cx.http_client();
 74
 75        let recent_buffer_snapshots = recent_buffers
 76            .filter_map(|project_path| {
 77                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 78                if active_buffer == &buffer {
 79                    None
 80                } else {
 81                    Some(buffer.read(cx).snapshot())
 82                }
 83            })
 84            .take(3)
 85            .collect::<Vec<_>>();
 86
 87        let cursor_point = position.to_point(&snapshot);
 88        let buffer_snapshotted_at = Instant::now();
 89
 90        let result = cx.background_spawn(async move {
 91            let text = snapshot.text();
 92
 93            let mut recent_changes = String::new();
 94            for event in &events {
 95                write_event(event.as_ref(), &mut recent_changes).unwrap();
 96            }
 97
 98            let mut file_chunks = recent_buffer_snapshots
 99                .into_iter()
100                .map(|snapshot| {
101                    let end_point = Point::new(30, 0).min(snapshot.max_point());
102                    FileChunk {
103                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
104                        file_path: snapshot
105                            .file()
106                            .map(|f| f.path().as_unix_str())
107                            .unwrap_or("untitled")
108                            .to_string(),
109                        start_line: 0,
110                        end_line: end_point.row as usize,
111                        timestamp: snapshot.file().and_then(|file| {
112                            Some(
113                                file.disk_state()
114                                    .mtime()?
115                                    .to_seconds_and_nanos_for_persistence()?
116                                    .0,
117                            )
118                        }),
119                    }
120                })
121                .collect::<Vec<_>>();
122
123            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
124            let mut diagnostic_content = String::new();
125            let mut diagnostic_count = 0;
126
127            for entry in diagnostic_entries {
128                let start_point: Point = entry.range.start;
129
130                let severity = match entry.diagnostic.severity {
131                    DiagnosticSeverity::ERROR => "error",
132                    DiagnosticSeverity::WARNING => "warning",
133                    DiagnosticSeverity::INFORMATION => "info",
134                    DiagnosticSeverity::HINT => "hint",
135                    _ => continue,
136                };
137
138                diagnostic_count += 1;
139
140                writeln!(
141                    &mut diagnostic_content,
142                    "{} at line {}: {}",
143                    severity,
144                    start_point.row + 1,
145                    entry.diagnostic.message
146                )?;
147            }
148
149            if !diagnostic_content.is_empty() {
150                file_chunks.push(FileChunk {
151                    file_path: format!("Diagnostics for {}", full_path.display()),
152                    start_line: 0,
153                    end_line: diagnostic_count,
154                    content: diagnostic_content,
155                    timestamp: None,
156                });
157            }
158
159            let request_body = AutocompleteRequest {
160                debug_info,
161                repo_name,
162                file_path: full_path.clone(),
163                file_contents: text.clone(),
164                original_file_contents: text,
165                cursor_position: offset,
166                recent_changes: recent_changes.clone(),
167                changes_above_cursor: true,
168                multiple_suggestions: false,
169                branch: None,
170                file_chunks,
171                retrieval_chunks: vec![],
172                recent_user_actions: vec![],
173                use_bytes: true,
174                // TODO
175                privacy_mode_enabled: false,
176            };
177
178            let mut buf: Vec<u8> = Vec::new();
179            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
180            serde_json::to_writer(writer, &request_body)?;
181            let body: AsyncBody = buf.into();
182
183            let inputs = EditPredictionInputs {
184                events,
185                included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
186                    path: full_path.clone(),
187                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
188                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
189                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
190                        text: request_body.file_contents.into(),
191                    }],
192                }],
193                cursor_point: cloud_llm_client::predict_edits_v3::Point {
194                    column: cursor_point.column,
195                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
196                },
197                cursor_path: full_path.clone(),
198            };
199
200            let request = http_client::Request::builder()
201                .uri(SWEEP_API_URL)
202                .header("Content-Type", "application/json")
203                .header("Authorization", format!("Bearer {}", api_token))
204                .header("Connection", "keep-alive")
205                .header("Content-Encoding", "br")
206                .method(Method::POST)
207                .body(body)?;
208
209            let mut response = http_client.send(request).await?;
210
211            let mut body: Vec<u8> = Vec::new();
212            response.body_mut().read_to_end(&mut body).await?;
213
214            let response_received_at = Instant::now();
215            if !response.status().is_success() {
216                anyhow::bail!(
217                    "Request failed with status: {:?}\nBody: {}",
218                    response.status(),
219                    String::from_utf8_lossy(&body),
220                );
221            };
222
223            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
224
225            let old_text = snapshot
226                .text_for_range(response.start_index..response.end_index)
227                .collect::<String>();
228            let edits = language::text_diff(&old_text, &response.completion)
229                .into_iter()
230                .map(|(range, text)| {
231                    (
232                        snapshot.anchor_after(response.start_index + range.start)
233                            ..snapshot.anchor_before(response.start_index + range.end),
234                        text,
235                    )
236                })
237                .collect::<Vec<_>>();
238
239            anyhow::Ok((
240                response.autocomplete_id,
241                edits,
242                snapshot,
243                response_received_at,
244                inputs,
245            ))
246        });
247
248        let buffer = active_buffer.clone();
249
250        cx.spawn(async move |cx| {
251            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
252            anyhow::Ok(Some(
253                EditPredictionResult::new(
254                    EditPredictionId(id.into()),
255                    &buffer,
256                    &old_snapshot,
257                    edits.into(),
258                    buffer_snapshotted_at,
259                    response_received_at,
260                    inputs,
261                    cx,
262                )
263                .await,
264            ))
265        })
266    }
267}
268
/// Key under which the Sweep API token is stored in, read from and deleted
/// from the credentials provider.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username recorded alongside the token when it is written to the
/// credentials provider.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
271
272pub fn load_api_token(cx: &App) -> Task<Option<String>> {
273    if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
274        .ok()
275        .filter(|value| !value.is_empty())
276    {
277        return Task::ready(Some(api_token));
278    }
279    let credentials_provider = <dyn CredentialsProvider>::global(cx);
280    cx.spawn(async move |cx| {
281        let (_, credentials) = credentials_provider
282            .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
283            .await
284            .ok()??;
285        String::from_utf8(credentials).ok()
286    })
287}
288
289fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
290    let credentials_provider = <dyn CredentialsProvider>::global(cx);
291
292    cx.spawn(async move |cx| {
293        if let Some(api_token) = api_token {
294            credentials_provider
295                .write_credentials(
296                    SWEEP_CREDENTIALS_URL,
297                    SWEEP_CREDENTIALS_USERNAME,
298                    api_token.as_bytes(),
299                    cx,
300                )
301                .await
302                .context("Failed to save Sweep API token to system keychain")
303        } else {
304            credentials_provider
305                .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
306                .await
307                .context("Failed to delete Sweep API token from system keychain")
308        }
309    })
310}
311
312#[derive(Debug, Clone, Serialize)]
313struct AutocompleteRequest {
314    pub debug_info: Arc<str>,
315    pub repo_name: String,
316    pub branch: Option<String>,
317    pub file_path: Arc<Path>,
318    pub file_contents: String,
319    pub recent_changes: String,
320    pub cursor_position: usize,
321    pub original_file_contents: String,
322    pub file_chunks: Vec<FileChunk>,
323    pub retrieval_chunks: Vec<RetrievalChunk>,
324    pub recent_user_actions: Vec<UserAction>,
325    pub multiple_suggestions: bool,
326    pub privacy_mode_enabled: bool,
327    pub changes_above_cursor: bool,
328    pub use_bytes: bool,
329}
330
331#[derive(Debug, Clone, Serialize)]
332struct FileChunk {
333    pub file_path: String,
334    pub start_line: usize,
335    pub end_line: usize,
336    pub content: String,
337    pub timestamp: Option<u64>,
338}
339
340#[derive(Debug, Clone, Serialize)]
341struct RetrievalChunk {
342    pub file_path: String,
343    pub start_line: usize,
344    pub end_line: usize,
345    pub content: String,
346    pub timestamp: u64,
347}
348
349#[derive(Debug, Clone, Serialize)]
350struct UserAction {
351    pub action_type: ActionType,
352    pub line_number: usize,
353    pub offset: usize,
354    pub file_path: String,
355    pub timestamp: u64,
356}
357
358#[allow(dead_code)]
359#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
360#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
361enum ActionType {
362    CursorMovement,
363    InsertChar,
364    DeleteChar,
365    InsertSelection,
366    DeleteSelection,
367}
368
369#[derive(Debug, Clone, Deserialize)]
370struct AutocompleteResponse {
371    pub autocomplete_id: String,
372    pub start_index: usize,
373    pub end_index: usize,
374    pub completion: String,
375    #[allow(dead_code)]
376    pub confidence: f64,
377    #[allow(dead_code)]
378    pub logprobs: Option<serde_json::Value>,
379    #[allow(dead_code)]
380    pub finish_reason: Option<String>,
381    #[allow(dead_code)]
382    pub elapsed_time_ms: u64,
383    #[allow(dead_code)]
384    #[serde(default, rename = "completions")]
385    pub additional_completions: Vec<AdditionalCompletion>,
386}
387
388#[allow(dead_code)]
389#[derive(Debug, Clone, Deserialize)]
390struct AdditionalCompletion {
391    pub start_index: usize,
392    pub end_index: usize,
393    pub completion: String,
394    pub confidence: f64,
395    pub autocomplete_id: String,
396    pub logprobs: Option<serde_json::Value>,
397    pub finish_reason: Option<String>,
398}
399
400fn write_event(
401    event: &cloud_llm_client::predict_edits_v3::Event,
402    f: &mut impl fmt::Write,
403) -> fmt::Result {
404    match event {
405        cloud_llm_client::predict_edits_v3::Event::BufferChange {
406            old_path,
407            path,
408            diff,
409            ..
410        } => {
411            if old_path != path {
412                // TODO confirm how to do this for sweep
413                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
414            }
415
416            if !diff.is_empty() {
417                write!(f, "File: {}:\n{}\n", path.display(), diff)?
418            }
419
420            fmt::Result::Ok(())
421        }
422    }
423}
424
425fn debug_info(cx: &gpui::App) -> Arc<str> {
426    format!(
427        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
428        version = release_channel::AppVersion::global(cx),
429        sha = release_channel::AppCommitSha::try_global(cx)
430            .map_or("unknown".to_string(), |sha| sha.full()),
431        os = client::telemetry::os_name(),
432    )
433    .into()
434}