sweep_ai.rs

  1use anyhow::Result;
  2use futures::AsyncReadExt as _;
  3use gpui::{
  4    App, AppContext as _, Entity, SharedString, Task,
  5    http_client::{self, AsyncBody, Method},
  6};
  7use language::{Point, ToOffset as _};
  8use language_model::{ApiKeyState, EnvVar, env_var};
  9use lsp::DiagnosticSeverity;
 10use serde::{Deserialize, Serialize};
 11use std::{
 12    fmt::{self, Write as _},
 13    path::Path,
 14    sync::Arc,
 15    time::Instant,
 16};
 17
 18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
 19
 20const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 21
 22pub struct SweepAi {
 23    pub api_token: Entity<ApiKeyState>,
 24    pub debug_info: Arc<str>,
 25}
 26
 27impl SweepAi {
 28    pub fn new(cx: &mut App) -> Self {
 29        SweepAi {
 30            api_token: sweep_api_token(cx),
 31            debug_info: debug_info(cx),
 32        }
 33    }
 34
 35    pub fn request_prediction_with_sweep(
 36        &self,
 37        inputs: EditPredictionModelInput,
 38        cx: &mut App,
 39    ) -> Task<Result<Option<EditPredictionResult>>> {
 40        let debug_info = self.debug_info.clone();
 41        self.api_token.update(cx, |key_state, cx| {
 42            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 43        });
 44        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 45            return Task::ready(Ok(None));
 46        };
 47        let full_path: Arc<Path> = inputs
 48            .snapshot
 49            .file()
 50            .map(|file| file.full_path(cx))
 51            .unwrap_or_else(|| "untitled".into())
 52            .into();
 53
 54        let project_file = project::File::from_dyn(inputs.snapshot.file());
 55        let repo_name = project_file
 56            .map(|file| file.worktree.read(cx).root_name_str())
 57            .unwrap_or("untitled")
 58            .into();
 59        let offset = inputs.position.to_offset(&inputs.snapshot);
 60
 61        let recent_buffers = inputs.recent_paths.iter().cloned();
 62        let http_client = cx.http_client();
 63
 64        let recent_buffer_snapshots = recent_buffers
 65            .filter_map(|project_path| {
 66                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 67                if inputs.buffer == buffer {
 68                    None
 69                } else {
 70                    Some(buffer.read(cx).snapshot())
 71                }
 72            })
 73            .take(3)
 74            .collect::<Vec<_>>();
 75
 76        let buffer_snapshotted_at = Instant::now();
 77
 78        let result = cx.background_spawn(async move {
 79            let text = inputs.snapshot.text();
 80
 81            let mut recent_changes = String::new();
 82            for event in &inputs.events {
 83                write_event(event.as_ref(), &mut recent_changes).unwrap();
 84            }
 85
 86            let mut file_chunks = recent_buffer_snapshots
 87                .into_iter()
 88                .map(|snapshot| {
 89                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 90                    FileChunk {
 91                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
 92                        file_path: snapshot
 93                            .file()
 94                            .map(|f| f.path().as_unix_str())
 95                            .unwrap_or("untitled")
 96                            .to_string(),
 97                        start_line: 0,
 98                        end_line: end_point.row as usize,
 99                        timestamp: snapshot.file().and_then(|file| {
100                            Some(
101                                file.disk_state()
102                                    .mtime()?
103                                    .to_seconds_and_nanos_for_persistence()?
104                                    .0,
105                            )
106                        }),
107                    }
108                })
109                .collect::<Vec<_>>();
110
111            let retrieval_chunks = inputs
112                .related_files
113                .iter()
114                .flat_map(|related_file| {
115                    related_file.excerpts.iter().map(|excerpt| FileChunk {
116                        file_path: related_file.path.to_string_lossy().to_string(),
117                        start_line: excerpt.row_range.start as usize,
118                        end_line: excerpt.row_range.end as usize,
119                        content: excerpt.text.to_string(),
120                        timestamp: None,
121                    })
122                })
123                .collect();
124
125            let diagnostic_entries = inputs
126                .snapshot
127                .diagnostics_in_range(inputs.diagnostic_search_range, false);
128            let mut diagnostic_content = String::new();
129            let mut diagnostic_count = 0;
130
131            for entry in diagnostic_entries {
132                let start_point: Point = entry.range.start;
133
134                let severity = match entry.diagnostic.severity {
135                    DiagnosticSeverity::ERROR => "error",
136                    DiagnosticSeverity::WARNING => "warning",
137                    DiagnosticSeverity::INFORMATION => "info",
138                    DiagnosticSeverity::HINT => "hint",
139                    _ => continue,
140                };
141
142                diagnostic_count += 1;
143
144                writeln!(
145                    &mut diagnostic_content,
146                    "{} at line {}: {}",
147                    severity,
148                    start_point.row + 1,
149                    entry.diagnostic.message
150                )?;
151            }
152
153            if !diagnostic_content.is_empty() {
154                file_chunks.push(FileChunk {
155                    file_path: format!("Diagnostics for {}", full_path.display()),
156                    start_line: 0,
157                    end_line: diagnostic_count,
158                    content: diagnostic_content,
159                    timestamp: None,
160                });
161            }
162
163            let request_body = AutocompleteRequest {
164                debug_info,
165                repo_name,
166                file_path: full_path.clone(),
167                file_contents: text.clone(),
168                original_file_contents: text,
169                cursor_position: offset,
170                recent_changes: recent_changes.clone(),
171                changes_above_cursor: true,
172                multiple_suggestions: false,
173                branch: None,
174                file_chunks,
175                retrieval_chunks,
176                recent_user_actions: vec![],
177                use_bytes: true,
178                // TODO
179                privacy_mode_enabled: false,
180            };
181
182            let mut buf: Vec<u8> = Vec::new();
183            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
184            serde_json::to_writer(writer, &request_body)?;
185            let body: AsyncBody = buf.into();
186
187            let ep_inputs = zeta_prompt::ZetaPromptInput {
188                events: inputs.events,
189                related_files: inputs.related_files.clone(),
190                cursor_path: full_path.clone(),
191                cursor_excerpt: request_body.file_contents.into(),
192                // we actually don't know
193                editable_range_in_excerpt: 0..inputs.snapshot.len(),
194                cursor_offset_in_excerpt: request_body.cursor_position,
195            };
196
197            let request = http_client::Request::builder()
198                .uri(SWEEP_API_URL)
199                .header("Content-Type", "application/json")
200                .header("Authorization", format!("Bearer {}", api_token))
201                .header("Connection", "keep-alive")
202                .header("Content-Encoding", "br")
203                .method(Method::POST)
204                .body(body)?;
205
206            let mut response = http_client.send(request).await?;
207
208            let mut body: Vec<u8> = Vec::new();
209            response.body_mut().read_to_end(&mut body).await?;
210
211            let response_received_at = Instant::now();
212            if !response.status().is_success() {
213                anyhow::bail!(
214                    "Request failed with status: {:?}\nBody: {}",
215                    response.status(),
216                    String::from_utf8_lossy(&body),
217                );
218            };
219
220            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
221
222            let old_text = inputs
223                .snapshot
224                .text_for_range(response.start_index..response.end_index)
225                .collect::<String>();
226            let edits = language::text_diff(&old_text, &response.completion)
227                .into_iter()
228                .map(|(range, text)| {
229                    (
230                        inputs
231                            .snapshot
232                            .anchor_after(response.start_index + range.start)
233                            ..inputs
234                                .snapshot
235                                .anchor_before(response.start_index + range.end),
236                        text,
237                    )
238                })
239                .collect::<Vec<_>>();
240
241            anyhow::Ok((
242                response.autocomplete_id,
243                edits,
244                inputs.snapshot,
245                response_received_at,
246                ep_inputs,
247            ))
248        });
249
250        let buffer = inputs.buffer.clone();
251
252        cx.spawn(async move |cx| {
253            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
254            anyhow::Ok(Some(
255                EditPredictionResult::new(
256                    EditPredictionId(id.into()),
257                    &buffer,
258                    &old_snapshot,
259                    edits.into(),
260                    buffer_snapshotted_at,
261                    response_received_at,
262                    inputs,
263                    cx,
264                )
265                .await,
266            ))
267        })
268    }
269}
270
271pub const SWEEP_CREDENTIALS_URL: SharedString =
272    SharedString::new_static("https://autocomplete.sweep.dev");
273pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
274pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
275pub static SWEEP_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
276
277pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
278    SWEEP_API_KEY
279        .get_or_init(|| {
280            cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()))
281        })
282        .clone()
283}
284
285#[derive(Debug, Clone, Serialize)]
286struct AutocompleteRequest {
287    pub debug_info: Arc<str>,
288    pub repo_name: String,
289    pub branch: Option<String>,
290    pub file_path: Arc<Path>,
291    pub file_contents: String,
292    pub recent_changes: String,
293    pub cursor_position: usize,
294    pub original_file_contents: String,
295    pub file_chunks: Vec<FileChunk>,
296    pub retrieval_chunks: Vec<FileChunk>,
297    pub recent_user_actions: Vec<UserAction>,
298    pub multiple_suggestions: bool,
299    pub privacy_mode_enabled: bool,
300    pub changes_above_cursor: bool,
301    pub use_bytes: bool,
302}
303
304#[derive(Debug, Clone, Serialize)]
305struct FileChunk {
306    pub file_path: String,
307    pub start_line: usize,
308    pub end_line: usize,
309    pub content: String,
310    pub timestamp: Option<u64>,
311}
312
313#[derive(Debug, Clone, Serialize)]
314struct UserAction {
315    pub action_type: ActionType,
316    pub line_number: usize,
317    pub offset: usize,
318    pub file_path: String,
319    pub timestamp: u64,
320}
321
322#[allow(dead_code)]
323#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
324#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
325enum ActionType {
326    CursorMovement,
327    InsertChar,
328    DeleteChar,
329    InsertSelection,
330    DeleteSelection,
331}
332
333#[derive(Debug, Clone, Deserialize)]
334struct AutocompleteResponse {
335    pub autocomplete_id: String,
336    pub start_index: usize,
337    pub end_index: usize,
338    pub completion: String,
339    #[allow(dead_code)]
340    pub confidence: f64,
341    #[allow(dead_code)]
342    pub logprobs: Option<serde_json::Value>,
343    #[allow(dead_code)]
344    pub finish_reason: Option<String>,
345    #[allow(dead_code)]
346    pub elapsed_time_ms: u64,
347    #[allow(dead_code)]
348    #[serde(default, rename = "completions")]
349    pub additional_completions: Vec<AdditionalCompletion>,
350}
351
352#[allow(dead_code)]
353#[derive(Debug, Clone, Deserialize)]
354struct AdditionalCompletion {
355    pub start_index: usize,
356    pub end_index: usize,
357    pub completion: String,
358    pub confidence: f64,
359    pub autocomplete_id: String,
360    pub logprobs: Option<serde_json::Value>,
361    pub finish_reason: Option<String>,
362}
363
364fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
365    match event {
366        zeta_prompt::Event::BufferChange {
367            old_path,
368            path,
369            diff,
370            ..
371        } => {
372            if old_path != path {
373                // TODO confirm how to do this for sweep
374                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
375            }
376
377            if !diff.is_empty() {
378                write!(f, "File: {}:\n{}\n", path.display(), diff)?
379            }
380
381            fmt::Result::Ok(())
382        }
383    }
384}
385
386fn debug_info(cx: &gpui::App) -> Arc<str> {
387    format!(
388        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
389        version = release_channel::AppVersion::global(cx),
390        sha = release_channel::AppCommitSha::try_global(cx)
391            .map_or("unknown".to_string(), |sha| sha.full()),
392        os = client::telemetry::os_name(),
393    )
394    .into()
395}