sweep_ai.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::Event;
  3use futures::AsyncReadExt as _;
  4use gpui::{
  5    App, AppContext as _, Entity, Task,
  6    http_client::{self, AsyncBody, Method},
  7};
  8use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
  9use lsp::DiagnosticSeverity;
 10use project::{Project, ProjectPath};
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    collections::VecDeque,
 14    fmt::{self, Write as _},
 15    ops::Range,
 16    path::Path,
 17    sync::Arc,
 18    time::Instant,
 19};
 20use util::ResultExt as _;
 21
 22use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
 23
/// Endpoint that receives the brotli-compressed JSON autocomplete requests built by
/// `SweepAi::request_prediction_with_sweep` (sent as an authenticated POST).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 25
/// Client state for requesting next-edit predictions from the Sweep autocomplete API.
pub struct SweepAi {
    /// Bearer token sent in the `Authorization` header. When `None`,
    /// `request_prediction_with_sweep` makes no request and resolves to `Ok(None)`.
    pub api_token: Option<String>,
    /// Client identification string sent as `debug_info` in every request
    /// (built by the `debug_info` helper in `SweepAi::new`).
    pub debug_info: Arc<str>,
}
 30
 31impl SweepAi {
 32    pub fn new(cx: &App) -> Self {
 33        SweepAi {
 34            api_token: std::env::var("SWEEP_AI_TOKEN")
 35                .context("No SWEEP_AI_TOKEN environment variable set")
 36                .log_err(),
 37            debug_info: debug_info(cx),
 38        }
 39    }
 40
 41    pub fn request_prediction_with_sweep(
 42        &self,
 43        project: &Entity<Project>,
 44        active_buffer: &Entity<Buffer>,
 45        snapshot: BufferSnapshot,
 46        position: language::Anchor,
 47        events: Vec<Arc<Event>>,
 48        recent_paths: &VecDeque<ProjectPath>,
 49        diagnostic_search_range: Range<Point>,
 50        cx: &mut App,
 51    ) -> Task<Result<Option<EditPrediction>>> {
 52        let debug_info = self.debug_info.clone();
 53        let Some(api_token) = self.api_token.clone() else {
 54            return Task::ready(Ok(None));
 55        };
 56        let full_path: Arc<Path> = snapshot
 57            .file()
 58            .map(|file| file.full_path(cx))
 59            .unwrap_or_else(|| "untitled".into())
 60            .into();
 61
 62        let project_file = project::File::from_dyn(snapshot.file());
 63        let repo_name = project_file
 64            .map(|file| file.worktree.read(cx).root_name_str())
 65            .unwrap_or("untitled")
 66            .into();
 67        let offset = position.to_offset(&snapshot);
 68
 69        let recent_buffers = recent_paths.iter().cloned();
 70        let http_client = cx.http_client();
 71
 72        let recent_buffer_snapshots = recent_buffers
 73            .filter_map(|project_path| {
 74                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 75                if active_buffer == &buffer {
 76                    None
 77                } else {
 78                    Some(buffer.read(cx).snapshot())
 79                }
 80            })
 81            .take(3)
 82            .collect::<Vec<_>>();
 83
 84        let cursor_point = position.to_point(&snapshot);
 85        let buffer_snapshotted_at = Instant::now();
 86
 87        let result = cx.background_spawn(async move {
 88            let text = snapshot.text();
 89
 90            let mut recent_changes = String::new();
 91            for event in &events {
 92                write_event(event.as_ref(), &mut recent_changes).unwrap();
 93            }
 94
 95            let mut file_chunks = recent_buffer_snapshots
 96                .into_iter()
 97                .map(|snapshot| {
 98                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 99                    FileChunk {
100                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
101                        file_path: snapshot
102                            .file()
103                            .map(|f| f.path().as_unix_str())
104                            .unwrap_or("untitled")
105                            .to_string(),
106                        start_line: 0,
107                        end_line: end_point.row as usize,
108                        timestamp: snapshot.file().and_then(|file| {
109                            Some(
110                                file.disk_state()
111                                    .mtime()?
112                                    .to_seconds_and_nanos_for_persistence()?
113                                    .0,
114                            )
115                        }),
116                    }
117                })
118                .collect::<Vec<_>>();
119
120            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
121            let mut diagnostic_content = String::new();
122            let mut diagnostic_count = 0;
123
124            for entry in diagnostic_entries {
125                let start_point: Point = entry.range.start;
126
127                let severity = match entry.diagnostic.severity {
128                    DiagnosticSeverity::ERROR => "error",
129                    DiagnosticSeverity::WARNING => "warning",
130                    DiagnosticSeverity::INFORMATION => "info",
131                    DiagnosticSeverity::HINT => "hint",
132                    _ => continue,
133                };
134
135                diagnostic_count += 1;
136
137                writeln!(
138                    &mut diagnostic_content,
139                    "{} at line {}: {}",
140                    severity,
141                    start_point.row + 1,
142                    entry.diagnostic.message
143                )?;
144            }
145
146            if !diagnostic_content.is_empty() {
147                file_chunks.push(FileChunk {
148                    file_path: format!("Diagnostics for {}", full_path.display()),
149                    start_line: 0,
150                    end_line: diagnostic_count,
151                    content: diagnostic_content,
152                    timestamp: None,
153                });
154            }
155
156            let request_body = AutocompleteRequest {
157                debug_info,
158                repo_name,
159                file_path: full_path.clone(),
160                file_contents: text.clone(),
161                original_file_contents: text,
162                cursor_position: offset,
163                recent_changes: recent_changes.clone(),
164                changes_above_cursor: true,
165                multiple_suggestions: false,
166                branch: None,
167                file_chunks,
168                retrieval_chunks: vec![],
169                recent_user_actions: vec![],
170                // TODO
171                privacy_mode_enabled: false,
172            };
173
174            let mut buf: Vec<u8> = Vec::new();
175            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
176            serde_json::to_writer(writer, &request_body)?;
177            let body: AsyncBody = buf.into();
178
179            let inputs = EditPredictionInputs {
180                events,
181                included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
182                    path: full_path.clone(),
183                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
184                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
185                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
186                        text: request_body.file_contents.into(),
187                    }],
188                }],
189                cursor_point: cloud_llm_client::predict_edits_v3::Point {
190                    column: cursor_point.column,
191                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
192                },
193                cursor_path: full_path.clone(),
194            };
195
196            let request = http_client::Request::builder()
197                .uri(SWEEP_API_URL)
198                .header("Content-Type", "application/json")
199                .header("Authorization", format!("Bearer {}", api_token))
200                .header("Connection", "keep-alive")
201                .header("Content-Encoding", "br")
202                .method(Method::POST)
203                .body(body)?;
204
205            let mut response = http_client.send(request).await?;
206
207            let mut body: Vec<u8> = Vec::new();
208            response.body_mut().read_to_end(&mut body).await?;
209
210            let response_received_at = Instant::now();
211            if !response.status().is_success() {
212                anyhow::bail!(
213                    "Request failed with status: {:?}\nBody: {}",
214                    response.status(),
215                    String::from_utf8_lossy(&body),
216                );
217            };
218
219            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
220
221            let old_text = snapshot
222                .text_for_range(response.start_index..response.end_index)
223                .collect::<String>();
224            let edits = language::text_diff(&old_text, &response.completion)
225                .into_iter()
226                .map(|(range, text)| {
227                    (
228                        snapshot.anchor_after(response.start_index + range.start)
229                            ..snapshot.anchor_before(response.start_index + range.end),
230                        text,
231                    )
232                })
233                .collect::<Vec<_>>();
234
235            anyhow::Ok((
236                response.autocomplete_id,
237                edits,
238                snapshot,
239                response_received_at,
240                inputs,
241            ))
242        });
243
244        let buffer = active_buffer.clone();
245
246        cx.spawn(async move |cx| {
247            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
248            anyhow::Ok(
249                EditPrediction::new(
250                    EditPredictionId(id.into()),
251                    &buffer,
252                    &old_snapshot,
253                    edits.into(),
254                    buffer_snapshotted_at,
255                    response_received_at,
256                    inputs,
257                    cx,
258                )
259                .await,
260            )
261        })
262    }
263}
264
/// JSON request body for the Sweep autocomplete endpoint.
///
/// It is serialized with serde's derive and brotli-compressed before being POSTed to
/// `SWEEP_API_URL`. Field names and declaration order determine the JSON keys and
/// their order, so edit them with care.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client identification string (`SweepAi::debug_info`).
    pub debug_info: Arc<str>,
    /// Worktree root name of the active file, or `"untitled"`.
    pub repo_name: String,
    /// Currently always `None` where this request is built.
    pub branch: Option<String>,
    /// Full path of the active buffer's file, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Current text of the active buffer.
    pub file_contents: String,
    /// Edit history rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`, as produced by `Anchor::to_offset`.
    pub cursor_position: usize,
    /// Currently set to the same text as `file_contents`.
    pub original_file_contents: String,
    /// Extra context: leading rows of recent buffers plus an optional diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub retrieval_chunks: Vec<RetrievalChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always `false` where this request is built.
    pub multiple_suggestions: bool,
    /// Currently always `false` (marked TODO where this request is built).
    pub privacy_mode_enabled: bool,
    /// Currently always `true` where this request is built.
    pub changes_above_cursor: bool,
}
282
/// A chunk of text sent as additional context in `AutocompleteRequest::file_chunks`.
///
/// It carries either an excerpt of a recently used buffer, or a synthetic
/// `"Diagnostics for <path>"` chunk whose content has one line per diagnostic.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Unix-style path of the buffer's file (or `"untitled"`), or the diagnostics label.
    pub file_path: String,
    /// Always `0` where chunks are built in this file.
    pub start_line: usize,
    /// For buffer excerpts: the row of the excerpt's end point. For the diagnostics
    /// chunk: the number of diagnostics listed.
    pub end_line: usize,
    pub content: String,
    /// Seconds part of the file's persisted on-disk mtime, when available.
    pub timestamp: Option<u64>,
}
291
/// A retrieved code snippet in the request schema.
///
/// Not populated anywhere in this file: `AutocompleteRequest::retrieval_chunks` is
/// always sent empty.
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
    pub file_path: String,
    pub start_line: usize,
    pub end_line: usize,
    pub content: String,
    pub timestamp: u64,
}
300
/// A single editor action in the request schema.
///
/// Not populated anywhere in this file: `AutocompleteRequest::recent_user_actions` is
/// always sent empty.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
309
/// Kind of editor action in a `UserAction`.
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `CursorMovement` -> `"CURSOR_MOVEMENT"`).
// No variant is constructed yet (user actions are never sent), hence `dead_code`.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
320
/// JSON response from the Sweep autocomplete endpoint.
///
/// `completion` is the proposed replacement for the active buffer's text between
/// `start_index` and `end_index`. These are offsets into the snapshot the request was
/// built from.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Used as the `EditPredictionId` of the resulting prediction.
    pub autocomplete_id: String,
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    // The remaining fields are deserialized to match the response schema but are not
    // read anywhere yet, hence the `dead_code` allowances.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative completions, read from the `completions` key (empty if absent).
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
339
/// One entry of the response's `completions` list. Its fields mirror the primary
/// completion in `AutocompleteResponse`.
// Deserialized but never read yet, hence `dead_code`.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
351
352fn write_event(
353    event: &cloud_llm_client::predict_edits_v3::Event,
354    f: &mut impl fmt::Write,
355) -> fmt::Result {
356    match event {
357        cloud_llm_client::predict_edits_v3::Event::BufferChange {
358            old_path,
359            path,
360            diff,
361            ..
362        } => {
363            if old_path != path {
364                // TODO confirm how to do this for sweep
365                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
366            }
367
368            if !diff.is_empty() {
369                write!(f, "File: {}:\n{}\n", path.display(), diff)?
370            }
371
372            fmt::Result::Ok(())
373        }
374    }
375}
376
377fn debug_info(cx: &gpui::App) -> Arc<str> {
378    format!(
379        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
380        version = release_channel::AppVersion::global(cx),
381        sha = release_channel::AppCommitSha::try_global(cx)
382            .map_or("unknown".to_string(), |sha| sha.full()),
383        os = client::telemetry::os_name(),
384    )
385    .into()
386}