sweep_ai.rs

  1use anyhow::Result;
  2use cloud_llm_client::predict_edits_v3::Event;
  3use futures::AsyncReadExt as _;
  4use gpui::{
  5    App, AppContext as _, Entity, Task,
  6    http_client::{self, AsyncBody, Method},
  7};
  8use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
  9use lsp::DiagnosticSeverity;
 10use project::{Project, ProjectPath};
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    collections::VecDeque,
 14    fmt::{self, Write as _},
 15    ops::Range,
 16    path::Path,
 17    sync::Arc,
 18    time::Instant,
 19};
 20
 21use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
 22
/// Endpoint that receives Sweep next-edit autocomplete requests (HTTP POST, brotli-compressed JSON).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 24
/// Client state for requesting edit predictions from the Sweep autocomplete service.
pub struct SweepAi {
    /// Bearer token for the Sweep API. When `None`, prediction requests resolve to `Ok(None)`
    /// without contacting the service.
    pub api_token: Option<String>,
    /// Description of this client (Zed version, commit sha, OS) sent as `debug_info` with
    /// every request.
    pub debug_info: Arc<str>,
}
 29
 30impl SweepAi {
 31    pub fn new(cx: &App) -> Self {
 32        SweepAi {
 33            api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
 34            debug_info: debug_info(cx),
 35        }
 36    }
 37
 38    pub fn request_prediction_with_sweep(
 39        &self,
 40        project: &Entity<Project>,
 41        active_buffer: &Entity<Buffer>,
 42        snapshot: BufferSnapshot,
 43        position: language::Anchor,
 44        events: Vec<Arc<Event>>,
 45        recent_paths: &VecDeque<ProjectPath>,
 46        diagnostic_search_range: Range<Point>,
 47        cx: &mut App,
 48    ) -> Task<Result<Option<EditPredictionResult>>> {
 49        let debug_info = self.debug_info.clone();
 50        let Some(api_token) = self.api_token.clone() else {
 51            return Task::ready(Ok(None));
 52        };
 53        let full_path: Arc<Path> = snapshot
 54            .file()
 55            .map(|file| file.full_path(cx))
 56            .unwrap_or_else(|| "untitled".into())
 57            .into();
 58
 59        let project_file = project::File::from_dyn(snapshot.file());
 60        let repo_name = project_file
 61            .map(|file| file.worktree.read(cx).root_name_str())
 62            .unwrap_or("untitled")
 63            .into();
 64        let offset = position.to_offset(&snapshot);
 65
 66        let recent_buffers = recent_paths.iter().cloned();
 67        let http_client = cx.http_client();
 68
 69        let recent_buffer_snapshots = recent_buffers
 70            .filter_map(|project_path| {
 71                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 72                if active_buffer == &buffer {
 73                    None
 74                } else {
 75                    Some(buffer.read(cx).snapshot())
 76                }
 77            })
 78            .take(3)
 79            .collect::<Vec<_>>();
 80
 81        let cursor_point = position.to_point(&snapshot);
 82        let buffer_snapshotted_at = Instant::now();
 83
 84        let result = cx.background_spawn(async move {
 85            let text = snapshot.text();
 86
 87            let mut recent_changes = String::new();
 88            for event in &events {
 89                write_event(event.as_ref(), &mut recent_changes).unwrap();
 90            }
 91
 92            let mut file_chunks = recent_buffer_snapshots
 93                .into_iter()
 94                .map(|snapshot| {
 95                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 96                    FileChunk {
 97                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
 98                        file_path: snapshot
 99                            .file()
100                            .map(|f| f.path().as_unix_str())
101                            .unwrap_or("untitled")
102                            .to_string(),
103                        start_line: 0,
104                        end_line: end_point.row as usize,
105                        timestamp: snapshot.file().and_then(|file| {
106                            Some(
107                                file.disk_state()
108                                    .mtime()?
109                                    .to_seconds_and_nanos_for_persistence()?
110                                    .0,
111                            )
112                        }),
113                    }
114                })
115                .collect::<Vec<_>>();
116
117            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
118            let mut diagnostic_content = String::new();
119            let mut diagnostic_count = 0;
120
121            for entry in diagnostic_entries {
122                let start_point: Point = entry.range.start;
123
124                let severity = match entry.diagnostic.severity {
125                    DiagnosticSeverity::ERROR => "error",
126                    DiagnosticSeverity::WARNING => "warning",
127                    DiagnosticSeverity::INFORMATION => "info",
128                    DiagnosticSeverity::HINT => "hint",
129                    _ => continue,
130                };
131
132                diagnostic_count += 1;
133
134                writeln!(
135                    &mut diagnostic_content,
136                    "{} at line {}: {}",
137                    severity,
138                    start_point.row + 1,
139                    entry.diagnostic.message
140                )?;
141            }
142
143            if !diagnostic_content.is_empty() {
144                file_chunks.push(FileChunk {
145                    file_path: format!("Diagnostics for {}", full_path.display()),
146                    start_line: 0,
147                    end_line: diagnostic_count,
148                    content: diagnostic_content,
149                    timestamp: None,
150                });
151            }
152
153            let request_body = AutocompleteRequest {
154                debug_info,
155                repo_name,
156                file_path: full_path.clone(),
157                file_contents: text.clone(),
158                original_file_contents: text,
159                cursor_position: offset,
160                recent_changes: recent_changes.clone(),
161                changes_above_cursor: true,
162                multiple_suggestions: false,
163                branch: None,
164                file_chunks,
165                retrieval_chunks: vec![],
166                recent_user_actions: vec![],
167                // TODO
168                privacy_mode_enabled: false,
169            };
170
171            let mut buf: Vec<u8> = Vec::new();
172            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
173            serde_json::to_writer(writer, &request_body)?;
174            let body: AsyncBody = buf.into();
175
176            let inputs = EditPredictionInputs {
177                events,
178                included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
179                    path: full_path.clone(),
180                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
181                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
182                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
183                        text: request_body.file_contents.into(),
184                    }],
185                }],
186                cursor_point: cloud_llm_client::predict_edits_v3::Point {
187                    column: cursor_point.column,
188                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
189                },
190                cursor_path: full_path.clone(),
191            };
192
193            let request = http_client::Request::builder()
194                .uri(SWEEP_API_URL)
195                .header("Content-Type", "application/json")
196                .header("Authorization", format!("Bearer {}", api_token))
197                .header("Connection", "keep-alive")
198                .header("Content-Encoding", "br")
199                .method(Method::POST)
200                .body(body)?;
201
202            let mut response = http_client.send(request).await?;
203
204            let mut body: Vec<u8> = Vec::new();
205            response.body_mut().read_to_end(&mut body).await?;
206
207            let response_received_at = Instant::now();
208            if !response.status().is_success() {
209                anyhow::bail!(
210                    "Request failed with status: {:?}\nBody: {}",
211                    response.status(),
212                    String::from_utf8_lossy(&body),
213                );
214            };
215
216            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
217
218            let old_text = snapshot
219                .text_for_range(response.start_index..response.end_index)
220                .collect::<String>();
221            let edits = language::text_diff(&old_text, &response.completion)
222                .into_iter()
223                .map(|(range, text)| {
224                    (
225                        snapshot.anchor_after(response.start_index + range.start)
226                            ..snapshot.anchor_before(response.start_index + range.end),
227                        text,
228                    )
229                })
230                .collect::<Vec<_>>();
231
232            anyhow::Ok((
233                response.autocomplete_id,
234                edits,
235                snapshot,
236                response_received_at,
237                inputs,
238            ))
239        });
240
241        let buffer = active_buffer.clone();
242
243        cx.spawn(async move |cx| {
244            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
245            anyhow::Ok(Some(
246                EditPredictionResult::new(
247                    EditPredictionId(id.into()),
248                    &buffer,
249                    &old_snapshot,
250                    edits.into(),
251                    buffer_snapshotted_at,
252                    response_received_at,
253                    inputs,
254                    cx,
255                )
256                .await,
257            ))
258        })
259    }
260}
261
/// JSON body of a Sweep next-edit autocomplete request.
///
/// Serde serializes the fields in declaration order under these names, so this struct defines
/// the wire format.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description (see `SweepAi::debug_info`).
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the active file, or `"untitled"`.
    pub repo_name: String,
    /// Git branch; this client currently always sends `None`.
    pub branch: Option<String>,
    /// Full path of the active file, or `"untitled"` when the buffer has no file.
    pub file_path: Arc<Path>,
    /// Full text of the active buffer.
    pub file_contents: String,
    /// Recent edit events rendered as text by `write_event`.
    pub recent_changes: String,
    /// Buffer offset of the cursor within `file_contents`.
    pub cursor_position: usize,
    /// This client currently sends the same text as `file_contents`.
    pub original_file_contents: String,
    /// Excerpts from other open buffers, plus an optional diagnostics pseudo-file.
    pub file_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub retrieval_chunks: Vec<RetrievalChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always `false`.
    pub multiple_suggestions: bool,
    /// Currently always `false`; privacy mode is not wired up yet.
    pub privacy_mode_enabled: bool,
    /// Currently always `true`. NOTE(review): the exact server-side meaning is not visible
    /// here; confirm against Sweep's API.
    pub changes_above_cursor: bool,
}
279
/// A slice of context text sent alongside the active file.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Worktree-relative path of the source file. For the diagnostics chunk this is a label
    /// such as `Diagnostics for <path>`.
    pub file_path: String,
    /// First line covered by `content` (this client always sends 0).
    pub start_line: usize,
    /// End line of the excerpt. For the diagnostics chunk this holds the number of
    /// diagnostics instead.
    pub end_line: usize,
    /// The excerpt text.
    pub content: String,
    /// Seconds component of the file's modification time, when known.
    pub timestamp: Option<u64>,
}
288
/// A retrieved context snippet for the request's `retrieval_chunks` field.
///
/// This client never constructs one; the type only gives the (always empty) array its shape.
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
    pub file_path: String,
    pub start_line: usize,
    pub end_line: usize,
    pub content: String,
    pub timestamp: u64,
}
297
/// A recorded user action for the request's `recent_user_actions` field.
///
/// This client never constructs one; the type only gives the (always empty) array its shape.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
306
// No variant is constructed in this client yet (only the empty `recent_user_actions` list
// is sent), hence the dead-code allowance.
#[allow(dead_code)]
/// Kind of a `UserAction`, serialized in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
317
/// JSON body returned by the Sweep next-edit endpoint.
///
/// Fields marked `#[allow(dead_code)]` must be present (or defaulted) for deserialization to
/// succeed, but this client does not read them.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of this suggestion, used as the edit prediction id.
    pub autocomplete_id: String,
    /// Start offset, in the submitted file contents, of the text being replaced.
    pub start_index: usize,
    /// End offset, in the submitted file contents, of the text being replaced.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    // Read from the `completions` key; an absent key yields an empty list.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
336
// Deserialized as part of `AutocompleteResponse`, but no field is read by this client.
#[allow(dead_code)]
/// An alternative suggestion listed under the response's `completions` key.
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
348
349fn write_event(
350    event: &cloud_llm_client::predict_edits_v3::Event,
351    f: &mut impl fmt::Write,
352) -> fmt::Result {
353    match event {
354        cloud_llm_client::predict_edits_v3::Event::BufferChange {
355            old_path,
356            path,
357            diff,
358            ..
359        } => {
360            if old_path != path {
361                // TODO confirm how to do this for sweep
362                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
363            }
364
365            if !diff.is_empty() {
366                write!(f, "File: {}:\n{}\n", path.display(), diff)?
367            }
368
369            fmt::Result::Ok(())
370        }
371    }
372}
373
374fn debug_info(cx: &gpui::App) -> Arc<str> {
375    format!(
376        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
377        version = release_channel::AppVersion::global(cx),
378        sha = release_channel::AppCommitSha::try_global(cx)
379            .map_or("unknown".to_string(), |sha| sha.full()),
380        os = client::telemetry::os_name(),
381    )
382    .into()
383}