sweep_ai.rs

  1use anyhow::Result;
  2use cloud_llm_client::predict_edits_v3::Event;
  3use futures::AsyncReadExt as _;
  4use gpui::{
  5    App, AppContext as _, Entity, Task,
  6    http_client::{self, AsyncBody, Method},
  7};
  8use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
  9use lsp::DiagnosticSeverity;
 10use project::{Project, ProjectPath};
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    collections::VecDeque,
 14    fmt::{self, Write as _},
 15    ops::Range,
 16    path::Path,
 17    sync::Arc,
 18    time::Instant,
 19};
 20
 21use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
 22
 23const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 24
 25pub struct SweepAi {
 26    pub api_token: Option<String>,
 27    pub debug_info: Arc<str>,
 28}
 29
 30impl SweepAi {
 31    pub fn new(cx: &App) -> Self {
 32        SweepAi {
 33            api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
 34            debug_info: debug_info(cx),
 35        }
 36    }
 37
 38    pub fn request_prediction_with_sweep(
 39        &self,
 40        project: &Entity<Project>,
 41        active_buffer: &Entity<Buffer>,
 42        snapshot: BufferSnapshot,
 43        position: language::Anchor,
 44        events: Vec<Arc<Event>>,
 45        recent_paths: &VecDeque<ProjectPath>,
 46        diagnostic_search_range: Range<Point>,
 47        cx: &mut App,
 48    ) -> Task<Result<Option<EditPredictionResult>>> {
 49        let debug_info = self.debug_info.clone();
 50        let Some(api_token) = self.api_token.clone() else {
 51            return Task::ready(Ok(None));
 52        };
 53        let full_path: Arc<Path> = snapshot
 54            .file()
 55            .map(|file| file.full_path(cx))
 56            .unwrap_or_else(|| "untitled".into())
 57            .into();
 58
 59        let project_file = project::File::from_dyn(snapshot.file());
 60        let repo_name = project_file
 61            .map(|file| file.worktree.read(cx).root_name_str())
 62            .unwrap_or("untitled")
 63            .into();
 64        let offset = position.to_offset(&snapshot);
 65
 66        let recent_buffers = recent_paths.iter().cloned();
 67        let http_client = cx.http_client();
 68
 69        let recent_buffer_snapshots = recent_buffers
 70            .filter_map(|project_path| {
 71                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 72                if active_buffer == &buffer {
 73                    None
 74                } else {
 75                    Some(buffer.read(cx).snapshot())
 76                }
 77            })
 78            .take(3)
 79            .collect::<Vec<_>>();
 80
 81        let cursor_point = position.to_point(&snapshot);
 82        let buffer_snapshotted_at = Instant::now();
 83
 84        let result = cx.background_spawn(async move {
 85            let text = snapshot.text();
 86
 87            let mut recent_changes = String::new();
 88            for event in &events {
 89                write_event(event.as_ref(), &mut recent_changes).unwrap();
 90            }
 91
 92            let mut file_chunks = recent_buffer_snapshots
 93                .into_iter()
 94                .map(|snapshot| {
 95                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 96                    FileChunk {
 97                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
 98                        file_path: snapshot
 99                            .file()
100                            .map(|f| f.path().as_unix_str())
101                            .unwrap_or("untitled")
102                            .to_string(),
103                        start_line: 0,
104                        end_line: end_point.row as usize,
105                        timestamp: snapshot.file().and_then(|file| {
106                            Some(
107                                file.disk_state()
108                                    .mtime()?
109                                    .to_seconds_and_nanos_for_persistence()?
110                                    .0,
111                            )
112                        }),
113                    }
114                })
115                .collect::<Vec<_>>();
116
117            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
118            let mut diagnostic_content = String::new();
119            let mut diagnostic_count = 0;
120
121            for entry in diagnostic_entries {
122                let start_point: Point = entry.range.start;
123
124                let severity = match entry.diagnostic.severity {
125                    DiagnosticSeverity::ERROR => "error",
126                    DiagnosticSeverity::WARNING => "warning",
127                    DiagnosticSeverity::INFORMATION => "info",
128                    DiagnosticSeverity::HINT => "hint",
129                    _ => continue,
130                };
131
132                diagnostic_count += 1;
133
134                writeln!(
135                    &mut diagnostic_content,
136                    "{} at line {}: {}",
137                    severity,
138                    start_point.row + 1,
139                    entry.diagnostic.message
140                )?;
141            }
142
143            if !diagnostic_content.is_empty() {
144                file_chunks.push(FileChunk {
145                    file_path: format!("Diagnostics for {}", full_path.display()),
146                    start_line: 0,
147                    end_line: diagnostic_count,
148                    content: diagnostic_content,
149                    timestamp: None,
150                });
151            }
152
153            let request_body = AutocompleteRequest {
154                debug_info,
155                repo_name,
156                file_path: full_path.clone(),
157                file_contents: text.clone(),
158                original_file_contents: text,
159                cursor_position: offset,
160                recent_changes: recent_changes.clone(),
161                changes_above_cursor: true,
162                multiple_suggestions: false,
163                branch: None,
164                file_chunks,
165                retrieval_chunks: vec![],
166                recent_user_actions: vec![],
167                use_bytes: true,
168                // TODO
169                privacy_mode_enabled: false,
170            };
171
172            let mut buf: Vec<u8> = Vec::new();
173            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
174            serde_json::to_writer(writer, &request_body)?;
175            let body: AsyncBody = buf.into();
176
177            let inputs = EditPredictionInputs {
178                events,
179                included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
180                    path: full_path.clone(),
181                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
182                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
183                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
184                        text: request_body.file_contents.into(),
185                    }],
186                }],
187                cursor_point: cloud_llm_client::predict_edits_v3::Point {
188                    column: cursor_point.column,
189                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
190                },
191                cursor_path: full_path.clone(),
192            };
193
194            let request = http_client::Request::builder()
195                .uri(SWEEP_API_URL)
196                .header("Content-Type", "application/json")
197                .header("Authorization", format!("Bearer {}", api_token))
198                .header("Connection", "keep-alive")
199                .header("Content-Encoding", "br")
200                .method(Method::POST)
201                .body(body)?;
202
203            let mut response = http_client.send(request).await?;
204
205            let mut body: Vec<u8> = Vec::new();
206            response.body_mut().read_to_end(&mut body).await?;
207
208            let response_received_at = Instant::now();
209            if !response.status().is_success() {
210                anyhow::bail!(
211                    "Request failed with status: {:?}\nBody: {}",
212                    response.status(),
213                    String::from_utf8_lossy(&body),
214                );
215            };
216
217            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
218
219            let old_text = snapshot
220                .text_for_range(response.start_index..response.end_index)
221                .collect::<String>();
222            let edits = language::text_diff(&old_text, &response.completion)
223                .into_iter()
224                .map(|(range, text)| {
225                    (
226                        snapshot.anchor_after(response.start_index + range.start)
227                            ..snapshot.anchor_before(response.start_index + range.end),
228                        text,
229                    )
230                })
231                .collect::<Vec<_>>();
232
233            anyhow::Ok((
234                response.autocomplete_id,
235                edits,
236                snapshot,
237                response_received_at,
238                inputs,
239            ))
240        });
241
242        let buffer = active_buffer.clone();
243
244        cx.spawn(async move |cx| {
245            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
246            anyhow::Ok(Some(
247                EditPredictionResult::new(
248                    EditPredictionId(id.into()),
249                    &buffer,
250                    &old_snapshot,
251                    edits.into(),
252                    buffer_snapshotted_at,
253                    response_received_at,
254                    inputs,
255                    cx,
256                )
257                .await,
258            ))
259        })
260    }
261}
262
263#[derive(Debug, Clone, Serialize)]
264struct AutocompleteRequest {
265    pub debug_info: Arc<str>,
266    pub repo_name: String,
267    pub branch: Option<String>,
268    pub file_path: Arc<Path>,
269    pub file_contents: String,
270    pub recent_changes: String,
271    pub cursor_position: usize,
272    pub original_file_contents: String,
273    pub file_chunks: Vec<FileChunk>,
274    pub retrieval_chunks: Vec<RetrievalChunk>,
275    pub recent_user_actions: Vec<UserAction>,
276    pub multiple_suggestions: bool,
277    pub privacy_mode_enabled: bool,
278    pub changes_above_cursor: bool,
279    pub use_bytes: bool,
280}
281
282#[derive(Debug, Clone, Serialize)]
283struct FileChunk {
284    pub file_path: String,
285    pub start_line: usize,
286    pub end_line: usize,
287    pub content: String,
288    pub timestamp: Option<u64>,
289}
290
291#[derive(Debug, Clone, Serialize)]
292struct RetrievalChunk {
293    pub file_path: String,
294    pub start_line: usize,
295    pub end_line: usize,
296    pub content: String,
297    pub timestamp: u64,
298}
299
300#[derive(Debug, Clone, Serialize)]
301struct UserAction {
302    pub action_type: ActionType,
303    pub line_number: usize,
304    pub offset: usize,
305    pub file_path: String,
306    pub timestamp: u64,
307}
308
309#[allow(dead_code)]
310#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
311#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
312enum ActionType {
313    CursorMovement,
314    InsertChar,
315    DeleteChar,
316    InsertSelection,
317    DeleteSelection,
318}
319
320#[derive(Debug, Clone, Deserialize)]
321struct AutocompleteResponse {
322    pub autocomplete_id: String,
323    pub start_index: usize,
324    pub end_index: usize,
325    pub completion: String,
326    #[allow(dead_code)]
327    pub confidence: f64,
328    #[allow(dead_code)]
329    pub logprobs: Option<serde_json::Value>,
330    #[allow(dead_code)]
331    pub finish_reason: Option<String>,
332    #[allow(dead_code)]
333    pub elapsed_time_ms: u64,
334    #[allow(dead_code)]
335    #[serde(default, rename = "completions")]
336    pub additional_completions: Vec<AdditionalCompletion>,
337}
338
339#[allow(dead_code)]
340#[derive(Debug, Clone, Deserialize)]
341struct AdditionalCompletion {
342    pub start_index: usize,
343    pub end_index: usize,
344    pub completion: String,
345    pub confidence: f64,
346    pub autocomplete_id: String,
347    pub logprobs: Option<serde_json::Value>,
348    pub finish_reason: Option<String>,
349}
350
351fn write_event(
352    event: &cloud_llm_client::predict_edits_v3::Event,
353    f: &mut impl fmt::Write,
354) -> fmt::Result {
355    match event {
356        cloud_llm_client::predict_edits_v3::Event::BufferChange {
357            old_path,
358            path,
359            diff,
360            ..
361        } => {
362            if old_path != path {
363                // TODO confirm how to do this for sweep
364                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
365            }
366
367            if !diff.is_empty() {
368                write!(f, "File: {}:\n{}\n", path.display(), diff)?
369            }
370
371            fmt::Result::Ok(())
372        }
373    }
374}
375
376fn debug_info(cx: &gpui::App) -> Arc<str> {
377    format!(
378        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
379        version = release_channel::AppVersion::global(cx),
380        sha = release_channel::AppCommitSha::try_global(cx)
381            .map_or("unknown".to_string(), |sha| sha.full()),
382        os = client::telemetry::os_name(),
383    )
384    .into()
385}