sweep_ai.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::Event;
  3use credentials_provider::CredentialsProvider;
  4use edit_prediction_context::RelatedFile;
  5use futures::{AsyncReadExt as _, FutureExt, future::Shared};
  6use gpui::{
  7    App, AppContext as _, Entity, Task,
  8    http_client::{self, AsyncBody, Method},
  9};
 10use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
 11use lsp::DiagnosticSeverity;
 12use project::{Project, ProjectPath};
 13use serde::{Deserialize, Serialize};
 14use std::{
 15    collections::VecDeque,
 16    fmt::{self, Write as _},
 17    ops::Range,
 18    path::Path,
 19    sync::Arc,
 20    time::Instant,
 21};
 22
 23use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
 24
/// Sweep backend endpoint that receives next-edit autocomplete requests (POSTed as
/// brotli-compressed JSON by `SweepAi::request_prediction_with_sweep`).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 26
/// Client state for requesting edit predictions from the Sweep backend.
pub struct SweepAi {
    /// Shared task resolving to the Sweep API token, or `None` when no token is configured.
    /// Replaced wholesale by `set_api_token`.
    pub api_token: Shared<Task<Option<String>>>,
    /// Debug string (Zed version, commit SHA, OS) sent with every request.
    pub debug_info: Arc<str>,
}
 31
 32impl SweepAi {
 33    pub fn new(cx: &App) -> Self {
 34        SweepAi {
 35            api_token: load_api_token(cx).shared(),
 36            debug_info: debug_info(cx),
 37        }
 38    }
 39
 40    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
 41        self.api_token = Task::ready(api_token.clone()).shared();
 42        store_api_token_in_keychain(api_token, cx)
 43    }
 44
 45    pub fn request_prediction_with_sweep(
 46        &self,
 47        project: &Entity<Project>,
 48        active_buffer: &Entity<Buffer>,
 49        snapshot: BufferSnapshot,
 50        position: language::Anchor,
 51        events: Vec<Arc<Event>>,
 52        recent_paths: &VecDeque<ProjectPath>,
 53        related_files: Vec<RelatedFile>,
 54        diagnostic_search_range: Range<Point>,
 55        cx: &mut App,
 56    ) -> Task<Result<Option<EditPredictionResult>>> {
 57        let debug_info = self.debug_info.clone();
 58        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
 59            return Task::ready(Ok(None));
 60        };
 61        let full_path: Arc<Path> = snapshot
 62            .file()
 63            .map(|file| file.full_path(cx))
 64            .unwrap_or_else(|| "untitled".into())
 65            .into();
 66
 67        let project_file = project::File::from_dyn(snapshot.file());
 68        let repo_name = project_file
 69            .map(|file| file.worktree.read(cx).root_name_str())
 70            .unwrap_or("untitled")
 71            .into();
 72        let offset = position.to_offset(&snapshot);
 73
 74        let recent_buffers = recent_paths.iter().cloned();
 75        let http_client = cx.http_client();
 76
 77        let recent_buffer_snapshots = recent_buffers
 78            .filter_map(|project_path| {
 79                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 80                if active_buffer == &buffer {
 81                    None
 82                } else {
 83                    Some(buffer.read(cx).snapshot())
 84                }
 85            })
 86            .take(3)
 87            .collect::<Vec<_>>();
 88
 89        let cursor_point = position.to_point(&snapshot);
 90        let buffer_snapshotted_at = Instant::now();
 91
 92        let result = cx.background_spawn(async move {
 93            let text = snapshot.text();
 94
 95            let mut recent_changes = String::new();
 96            for event in &events {
 97                write_event(event.as_ref(), &mut recent_changes).unwrap();
 98            }
 99
100            let mut file_chunks = recent_buffer_snapshots
101                .into_iter()
102                .map(|snapshot| {
103                    let end_point = Point::new(30, 0).min(snapshot.max_point());
104                    FileChunk {
105                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
106                        file_path: snapshot
107                            .file()
108                            .map(|f| f.path().as_unix_str())
109                            .unwrap_or("untitled")
110                            .to_string(),
111                        start_line: 0,
112                        end_line: end_point.row as usize,
113                        timestamp: snapshot.file().and_then(|file| {
114                            Some(
115                                file.disk_state()
116                                    .mtime()?
117                                    .to_seconds_and_nanos_for_persistence()?
118                                    .0,
119                            )
120                        }),
121                    }
122                })
123                .collect::<Vec<_>>();
124
125            let retrieval_chunks = related_files
126                .iter()
127                .flat_map(|related_file| {
128                    related_file.excerpts.iter().map(|excerpt| FileChunk {
129                        file_path: related_file.path.path.as_unix_str().to_string(),
130                        start_line: excerpt.point_range.start.row as usize,
131                        end_line: excerpt.point_range.end.row as usize,
132                        content: excerpt.text.to_string(),
133                        timestamp: None,
134                    })
135                })
136                .collect();
137
138            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
139            let mut diagnostic_content = String::new();
140            let mut diagnostic_count = 0;
141
142            for entry in diagnostic_entries {
143                let start_point: Point = entry.range.start;
144
145                let severity = match entry.diagnostic.severity {
146                    DiagnosticSeverity::ERROR => "error",
147                    DiagnosticSeverity::WARNING => "warning",
148                    DiagnosticSeverity::INFORMATION => "info",
149                    DiagnosticSeverity::HINT => "hint",
150                    _ => continue,
151                };
152
153                diagnostic_count += 1;
154
155                writeln!(
156                    &mut diagnostic_content,
157                    "{} at line {}: {}",
158                    severity,
159                    start_point.row + 1,
160                    entry.diagnostic.message
161                )?;
162            }
163
164            if !diagnostic_content.is_empty() {
165                file_chunks.push(FileChunk {
166                    file_path: format!("Diagnostics for {}", full_path.display()),
167                    start_line: 0,
168                    end_line: diagnostic_count,
169                    content: diagnostic_content,
170                    timestamp: None,
171                });
172            }
173
174            let request_body = AutocompleteRequest {
175                debug_info,
176                repo_name,
177                file_path: full_path.clone(),
178                file_contents: text.clone(),
179                original_file_contents: text,
180                cursor_position: offset,
181                recent_changes: recent_changes.clone(),
182                changes_above_cursor: true,
183                multiple_suggestions: false,
184                branch: None,
185                file_chunks,
186                retrieval_chunks,
187                recent_user_actions: vec![],
188                use_bytes: true,
189                // TODO
190                privacy_mode_enabled: false,
191            };
192
193            let mut buf: Vec<u8> = Vec::new();
194            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
195            serde_json::to_writer(writer, &request_body)?;
196            let body: AsyncBody = buf.into();
197
198            let inputs = EditPredictionInputs {
199                events,
200                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
201                    path: full_path.clone(),
202                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
203                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
204                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
205                        text: request_body.file_contents.into(),
206                    }],
207                }],
208                cursor_point: cloud_llm_client::predict_edits_v3::Point {
209                    column: cursor_point.column,
210                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
211                },
212                cursor_path: full_path.clone(),
213            };
214
215            let request = http_client::Request::builder()
216                .uri(SWEEP_API_URL)
217                .header("Content-Type", "application/json")
218                .header("Authorization", format!("Bearer {}", api_token))
219                .header("Connection", "keep-alive")
220                .header("Content-Encoding", "br")
221                .method(Method::POST)
222                .body(body)?;
223
224            let mut response = http_client.send(request).await?;
225
226            let mut body: Vec<u8> = Vec::new();
227            response.body_mut().read_to_end(&mut body).await?;
228
229            let response_received_at = Instant::now();
230            if !response.status().is_success() {
231                anyhow::bail!(
232                    "Request failed with status: {:?}\nBody: {}",
233                    response.status(),
234                    String::from_utf8_lossy(&body),
235                );
236            };
237
238            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
239
240            let old_text = snapshot
241                .text_for_range(response.start_index..response.end_index)
242                .collect::<String>();
243            let edits = language::text_diff(&old_text, &response.completion)
244                .into_iter()
245                .map(|(range, text)| {
246                    (
247                        snapshot.anchor_after(response.start_index + range.start)
248                            ..snapshot.anchor_before(response.start_index + range.end),
249                        text,
250                    )
251                })
252                .collect::<Vec<_>>();
253
254            anyhow::Ok((
255                response.autocomplete_id,
256                edits,
257                snapshot,
258                response_received_at,
259                inputs,
260            ))
261        });
262
263        let buffer = active_buffer.clone();
264
265        cx.spawn(async move |cx| {
266            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
267            anyhow::Ok(Some(
268                EditPredictionResult::new(
269                    EditPredictionId(id.into()),
270                    &buffer,
271                    &old_snapshot,
272                    edits.into(),
273                    buffer_snapshotted_at,
274                    response_received_at,
275                    inputs,
276                    cx,
277                )
278                .await,
279            ))
280        })
281    }
282}
283
/// Key under which the Sweep API token is read from, written to, and deleted from the
/// system keychain via the global `CredentialsProvider`.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username stored alongside the Sweep API token when it is written to the keychain.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
286
287pub fn load_api_token(cx: &App) -> Task<Option<String>> {
288    if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
289        .ok()
290        .filter(|value| !value.is_empty())
291    {
292        return Task::ready(Some(api_token));
293    }
294    let credentials_provider = <dyn CredentialsProvider>::global(cx);
295    cx.spawn(async move |cx| {
296        let (_, credentials) = credentials_provider
297            .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
298            .await
299            .ok()??;
300        String::from_utf8(credentials).ok()
301    })
302}
303
304fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
305    let credentials_provider = <dyn CredentialsProvider>::global(cx);
306
307    cx.spawn(async move |cx| {
308        if let Some(api_token) = api_token {
309            credentials_provider
310                .write_credentials(
311                    SWEEP_CREDENTIALS_URL,
312                    SWEEP_CREDENTIALS_USERNAME,
313                    api_token.as_bytes(),
314                    cx,
315                )
316                .await
317                .context("Failed to save Sweep API token to system keychain")
318        } else {
319            credentials_provider
320                .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
321                .await
322                .context("Failed to delete Sweep API token from system keychain")
323        }
324    })
325}
326
/// JSON body POSTed (brotli-compressed) to `SWEEP_API_URL`.
///
/// Field declaration order determines the JSON key order produced by `serde_json`.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Zed version / commit / OS string built by `debug_info`.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the active file, or `"untitled"`.
    pub repo_name: String,
    /// Currently always sent as `None`.
    pub branch: Option<String>,
    /// Full path of the active file, or `"untitled"` when the buffer has no file.
    pub file_path: Arc<Path>,
    /// Full text of the active buffer at request time.
    pub file_contents: String,
    /// Recent edit events, each formatted by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents` (a Zed buffer offset; `use_bytes` is set).
    pub cursor_position: usize,
    /// Currently identical to `file_contents`.
    pub original_file_contents: String,
    /// Leading lines of recently used open buffers, plus an optional diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always `false`.
    pub multiple_suggestions: bool,
    /// Currently hard-coded to `false` (see the TODO at the construction site).
    pub privacy_mode_enabled: bool,
    /// Currently always `true`; exact server-side meaning not visible here — confirm with Sweep.
    pub changes_above_cursor: bool,
    /// Currently always `true`; presumably tells the server that offsets are byte offsets —
    /// confirm with Sweep.
    pub use_bytes: bool,
}
345
/// A span of text from some file, sent as context in `file_chunks` or `retrieval_chunks`.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the source file, or a synthetic label such as `"Diagnostics for <path>"`.
    pub file_path: String,
    /// Zero-based first row covered by `content`.
    pub start_line: usize,
    /// Row at which the chunk ends. For the diagnostics chunk this holds the diagnostic count.
    pub end_line: usize,
    /// The chunk text.
    pub content: String,
    /// Seconds component of the file's mtime for recent-buffer chunks; `None` otherwise.
    pub timestamp: Option<u64>,
}
354
/// A single user action entry for `AutocompleteRequest::recent_user_actions`.
///
/// Not populated yet: requests currently send an empty list.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    // Unit (seconds vs. milliseconds) is not established here — confirm against Sweep's API.
    pub timestamp: u64,
}
363
/// Kind of a `UserAction`, serialized in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
// `dead_code` is allowed because no `UserAction` is constructed yet, so no variant is used.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
374
/// Successful response body from `SWEEP_API_URL`.
///
/// The prediction replaces `start_index..end_index` of the request's file contents with
/// `completion`. Non-`Option` fields without `#[serde(default)]` are required: parsing fails
/// if they are missing.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Server-side id of this prediction, used as the `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start offset of the replaced range (used directly as a buffer offset).
    pub start_index: usize,
    /// End offset (exclusive) of the replaced range.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The fields below are deserialized but never read, hence `allow(dead_code)`.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative completions from the JSON key `completions`; empty when the key is absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
393
/// An alternative completion returned in `AutocompleteResponse::additional_completions`.
///
/// Same shape as the primary completion fields.
// `dead_code` is allowed because these fields are deserialized but never read.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
405
406fn write_event(
407    event: &cloud_llm_client::predict_edits_v3::Event,
408    f: &mut impl fmt::Write,
409) -> fmt::Result {
410    match event {
411        cloud_llm_client::predict_edits_v3::Event::BufferChange {
412            old_path,
413            path,
414            diff,
415            ..
416        } => {
417            if old_path != path {
418                // TODO confirm how to do this for sweep
419                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
420            }
421
422            if !diff.is_empty() {
423                write!(f, "File: {}:\n{}\n", path.display(), diff)?
424            }
425
426            fmt::Result::Ok(())
427        }
428    }
429}
430
431fn debug_info(cx: &gpui::App) -> Arc<str> {
432    format!(
433        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
434        version = release_channel::AppVersion::global(cx),
435        sha = release_channel::AppCommitSha::try_global(cx)
436            .map_or("unknown".to_string(), |sha| sha.full()),
437        os = client::telemetry::os_name(),
438    )
439    .into()
440}