sweep_ai.rs

  1use anyhow::Result;
  2use futures::AsyncReadExt as _;
  3use gpui::{
  4    App, AppContext as _, Entity, Global, SharedString, Task,
  5    http_client::{self, AsyncBody, Method},
  6};
  7use language::{Point, ToOffset as _};
  8use language_model::{ApiKeyState, EnvVar, env_var};
  9use lsp::DiagnosticSeverity;
 10use serde::{Deserialize, Serialize};
 11use std::{
 12    fmt::{self, Write as _},
 13    path::Path,
 14    sync::Arc,
 15    time::Instant,
 16};
 17
 18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
 19
/// Endpoint that `SweepAi::request_prediction_with_sweep` POSTs autocomplete requests to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 21
/// Client for Sweep AI's next-edit autocomplete service.
pub struct SweepAi {
    /// Shared API-key state; obtained from `sweep_api_token`, so every `SweepAi`
    /// in the app refers to the same entity.
    pub api_token: Entity<ApiKeyState>,
    /// Zed version / commit / OS description, sent as `debug_info` in each request.
    pub debug_info: Arc<str>,
}
 26
 27impl SweepAi {
 28    pub fn new(cx: &mut App) -> Self {
 29        SweepAi {
 30            api_token: sweep_api_token(cx),
 31            debug_info: debug_info(cx),
 32        }
 33    }
 34
 35    pub fn request_prediction_with_sweep(
 36        &self,
 37        inputs: EditPredictionModelInput,
 38        cx: &mut App,
 39    ) -> Task<Result<Option<EditPredictionResult>>> {
 40        let debug_info = self.debug_info.clone();
 41        self.api_token.update(cx, |key_state, cx| {
 42            _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
 43        });
 44        let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
 45            return Task::ready(Ok(None));
 46        };
 47        let full_path: Arc<Path> = inputs
 48            .snapshot
 49            .file()
 50            .map(|file| file.full_path(cx))
 51            .unwrap_or_else(|| "untitled".into())
 52            .into();
 53
 54        let project_file = project::File::from_dyn(inputs.snapshot.file());
 55        let repo_name = project_file
 56            .map(|file| file.worktree.read(cx).root_name_str())
 57            .unwrap_or("untitled")
 58            .into();
 59        let offset = inputs.position.to_offset(&inputs.snapshot);
 60
 61        let recent_buffers = inputs.recent_paths.iter().cloned();
 62        let http_client = cx.http_client();
 63
 64        let recent_buffer_snapshots = recent_buffers
 65            .filter_map(|project_path| {
 66                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 67                if inputs.buffer == buffer {
 68                    None
 69                } else {
 70                    Some(buffer.read(cx).snapshot())
 71                }
 72            })
 73            .take(3)
 74            .collect::<Vec<_>>();
 75
 76        let buffer_snapshotted_at = Instant::now();
 77
 78        let result = cx.background_spawn(async move {
 79            let text = inputs.snapshot.text();
 80
 81            let mut recent_changes = String::new();
 82            for event in &inputs.events {
 83                write_event(event.as_ref(), &mut recent_changes).unwrap();
 84            }
 85
 86            let mut file_chunks = recent_buffer_snapshots
 87                .into_iter()
 88                .map(|snapshot| {
 89                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 90                    FileChunk {
 91                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
 92                        file_path: snapshot
 93                            .file()
 94                            .map(|f| f.path().as_unix_str())
 95                            .unwrap_or("untitled")
 96                            .to_string(),
 97                        start_line: 0,
 98                        end_line: end_point.row as usize,
 99                        timestamp: snapshot.file().and_then(|file| {
100                            Some(
101                                file.disk_state()
102                                    .mtime()?
103                                    .to_seconds_and_nanos_for_persistence()?
104                                    .0,
105                            )
106                        }),
107                    }
108                })
109                .collect::<Vec<_>>();
110
111            let retrieval_chunks = inputs
112                .related_files
113                .iter()
114                .flat_map(|related_file| {
115                    related_file.excerpts.iter().map(|excerpt| FileChunk {
116                        file_path: related_file.path.to_string_lossy().to_string(),
117                        start_line: excerpt.row_range.start as usize,
118                        end_line: excerpt.row_range.end as usize,
119                        content: excerpt.text.to_string(),
120                        timestamp: None,
121                    })
122                })
123                .collect();
124
125            let diagnostic_entries = inputs
126                .snapshot
127                .diagnostics_in_range(inputs.diagnostic_search_range, false);
128            let mut diagnostic_content = String::new();
129            let mut diagnostic_count = 0;
130
131            for entry in diagnostic_entries {
132                let start_point: Point = entry.range.start;
133
134                let severity = match entry.diagnostic.severity {
135                    DiagnosticSeverity::ERROR => "error",
136                    DiagnosticSeverity::WARNING => "warning",
137                    DiagnosticSeverity::INFORMATION => "info",
138                    DiagnosticSeverity::HINT => "hint",
139                    _ => continue,
140                };
141
142                diagnostic_count += 1;
143
144                writeln!(
145                    &mut diagnostic_content,
146                    "{} at line {}: {}",
147                    severity,
148                    start_point.row + 1,
149                    entry.diagnostic.message
150                )?;
151            }
152
153            if !diagnostic_content.is_empty() {
154                file_chunks.push(FileChunk {
155                    file_path: format!("Diagnostics for {}", full_path.display()),
156                    start_line: 0,
157                    end_line: diagnostic_count,
158                    content: diagnostic_content,
159                    timestamp: None,
160                });
161            }
162
163            let request_body = AutocompleteRequest {
164                debug_info,
165                repo_name,
166                file_path: full_path.clone(),
167                file_contents: text.clone(),
168                original_file_contents: text,
169                cursor_position: offset,
170                recent_changes: recent_changes.clone(),
171                changes_above_cursor: true,
172                multiple_suggestions: false,
173                branch: None,
174                file_chunks,
175                retrieval_chunks,
176                recent_user_actions: vec![],
177                use_bytes: true,
178                // TODO
179                privacy_mode_enabled: false,
180            };
181
182            let mut buf: Vec<u8> = Vec::new();
183            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
184            serde_json::to_writer(writer, &request_body)?;
185            let body: AsyncBody = buf.into();
186
187            let ep_inputs = zeta_prompt::ZetaPromptInput {
188                events: inputs.events,
189                related_files: inputs.related_files.clone(),
190                cursor_path: full_path.clone(),
191                cursor_excerpt: request_body.file_contents.into(),
192                // we actually don't know
193                editable_range_in_excerpt: 0..inputs.snapshot.len(),
194                cursor_offset_in_excerpt: request_body.cursor_position,
195            };
196
197            let request = http_client::Request::builder()
198                .uri(SWEEP_API_URL)
199                .header("Content-Type", "application/json")
200                .header("Authorization", format!("Bearer {}", api_token))
201                .header("Connection", "keep-alive")
202                .header("Content-Encoding", "br")
203                .method(Method::POST)
204                .body(body)?;
205
206            let mut response = http_client.send(request).await?;
207
208            let mut body: Vec<u8> = Vec::new();
209            response.body_mut().read_to_end(&mut body).await?;
210
211            let response_received_at = Instant::now();
212            if !response.status().is_success() {
213                anyhow::bail!(
214                    "Request failed with status: {:?}\nBody: {}",
215                    response.status(),
216                    String::from_utf8_lossy(&body),
217                );
218            };
219
220            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
221
222            let old_text = inputs
223                .snapshot
224                .text_for_range(response.start_index..response.end_index)
225                .collect::<String>();
226            let edits = language::text_diff(&old_text, &response.completion)
227                .into_iter()
228                .map(|(range, text)| {
229                    (
230                        inputs
231                            .snapshot
232                            .anchor_after(response.start_index + range.start)
233                            ..inputs
234                                .snapshot
235                                .anchor_before(response.start_index + range.end),
236                        text,
237                    )
238                })
239                .collect::<Vec<_>>();
240
241            anyhow::Ok((
242                response.autocomplete_id,
243                edits,
244                inputs.snapshot,
245                response_received_at,
246                ep_inputs,
247            ))
248        });
249
250        let buffer = inputs.buffer.clone();
251
252        cx.spawn(async move |cx| {
253            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
254            anyhow::Ok(Some(
255                EditPredictionResult::new(
256                    EditPredictionId(id.into()),
257                    &buffer,
258                    &old_snapshot,
259                    edits.into(),
260                    buffer_snapshotted_at,
261                    response_received_at,
262                    inputs,
263                    cx,
264                )
265                .await,
266            ))
267        })
268    }
269}
270
/// URL identifying the Sweep API key in `ApiKeyState`; passed to `ApiKeyState::new`,
/// `load_if_needed`, and `key` in this module.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username associated with the stored Sweep token. Not referenced in this
/// module's visible code — presumably used by callers that store credentials.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// `SWEEP_AI_TOKEN` environment variable, handed to `ApiKeyState::new` as the
/// key's environment-variable source.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
275
/// App-global holder for the single Sweep `ApiKeyState` entity created by
/// `sweep_api_token`.
struct GlobalSweepApiKey(Entity<ApiKeyState>);

impl Global for GlobalSweepApiKey {}
279
280pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
281    if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
282        return global.0.clone();
283    }
284    let entity =
285        cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
286    cx.set_global(GlobalSweepApiKey(entity.clone()));
287    entity
288}
289
290pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
291    sweep_api_token(cx).update(cx, |key_state, cx| {
292        key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
293    })
294}
295
/// JSON body sent (Brotli-compressed) to `SWEEP_API_URL`.
///
/// Field names and order are the serialized wire format.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Zed version / commit / OS description.
    pub debug_info: Arc<str>,
    /// Root name of the file's worktree, or "untitled".
    pub repo_name: String,
    /// Sent as `None` by this module.
    pub branch: Option<String>,
    /// Full path of the edited file, or "untitled".
    pub file_path: Arc<Path>,
    /// Complete text of the edited buffer.
    pub file_contents: String,
    /// Recent edit events rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`.
    pub cursor_position: usize,
    /// Also the complete buffer text; this module sends the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading lines of other recent open buffers, plus a diagnostics pseudo-chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Sent empty by this module.
    pub recent_user_actions: Vec<UserAction>,
    pub multiple_suggestions: bool,
    pub privacy_mode_enabled: bool,
    pub changes_above_cursor: bool,
    // NOTE(review): presumably asks the server to treat offsets as byte offsets — confirm with Sweep's API.
    pub use_bytes: bool,
}
314
/// A span of text sent to Sweep as extra context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Source file path, or a label such as "Diagnostics for <path>".
    pub file_path: String,
    /// Starting row of the chunk.
    pub start_line: usize,
    /// Ending row of the chunk; for the diagnostics chunk, the number of diagnostics.
    pub end_line: usize,
    pub content: String,
    /// File mtime (first component of `to_seconds_and_nanos_for_persistence`), when known.
    pub timestamp: Option<u64>,
}
323
/// A single user action reported to Sweep.
///
/// Not constructed in this module; `recent_user_actions` is sent empty.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
332
// No variant is constructed yet (`recent_user_actions` is sent empty), hence `dead_code` is allowed.
#[allow(dead_code)]
/// Kind of a `UserAction`, serialized in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
343
/// Successful response body from `SWEEP_API_URL`.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Becomes the `EditPredictionId` of the resulting prediction.
    pub autocomplete_id: String,
    /// Start offset (into the submitted buffer text) of the region `completion` replaces.
    pub start_index: usize,
    /// End offset (exclusive) of the replaced region.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The remaining fields are deserialized but not read in this module, hence `dead_code`.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    // Read from the JSON key `completions`; defaults to empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
362
// Deserialized from the response but not used in this module, hence `dead_code`.
#[allow(dead_code)]
/// An alternative completion listed under `completions` in an `AutocompleteResponse`.
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
374
375fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
376    match event {
377        zeta_prompt::Event::BufferChange {
378            old_path,
379            path,
380            diff,
381            ..
382        } => {
383            if old_path != path {
384                // TODO confirm how to do this for sweep
385                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
386            }
387
388            if !diff.is_empty() {
389                write!(f, "File: {}:\n{}\n", path.display(), diff)?
390            }
391
392            fmt::Result::Ok(())
393        }
394    }
395}
396
397fn debug_info(cx: &gpui::App) -> Arc<str> {
398    format!(
399        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
400        version = release_channel::AppVersion::global(cx),
401        sha = release_channel::AppCommitSha::try_global(cx)
402            .map_or("unknown".to_string(), |sha| sha.full()),
403        os = client::telemetry::os_name(),
404    )
405    .into()
406}