sweep_ai.rs

  1use anyhow::{Context as _, Result};
  2use credentials_provider::CredentialsProvider;
  3use futures::{AsyncReadExt as _, FutureExt, future::Shared};
  4use gpui::{
  5    App, AppContext as _, Task,
  6    http_client::{self, AsyncBody, Method},
  7};
  8use language::{Point, ToOffset as _};
  9use lsp::DiagnosticSeverity;
 10use serde::{Deserialize, Serialize};
 11use std::{
 12    fmt::{self, Write as _},
 13    path::Path,
 14    sync::Arc,
 15    time::Instant,
 16};
 17
 18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
 19
/// Endpoint that `SweepAi::request_prediction_with_sweep` POSTs autocomplete requests to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 21
/// State needed to request edit predictions from the Sweep autocomplete backend.
pub struct SweepAi {
    /// The API token, wrapped in a shared task so it can be cloned and polled
    /// per request. Resolves to `None` when no token is available.
    pub api_token: Shared<Task<Option<String>>>,
    /// Client description string that is sent as `debug_info` with every request.
    pub debug_info: Arc<str>,
}
 26
 27impl SweepAi {
 28    pub fn new(cx: &App) -> Self {
 29        SweepAi {
 30            api_token: load_api_token(cx).shared(),
 31            debug_info: debug_info(cx),
 32        }
 33    }
 34
 35    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
 36        self.api_token = Task::ready(api_token.clone()).shared();
 37        store_api_token_in_keychain(api_token, cx)
 38    }
 39
 40    pub fn request_prediction_with_sweep(
 41        &self,
 42        inputs: EditPredictionModelInput,
 43        cx: &mut App,
 44    ) -> Task<Result<Option<EditPredictionResult>>> {
 45        let debug_info = self.debug_info.clone();
 46        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
 47            return Task::ready(Ok(None));
 48        };
 49        let full_path: Arc<Path> = inputs
 50            .snapshot
 51            .file()
 52            .map(|file| file.full_path(cx))
 53            .unwrap_or_else(|| "untitled".into())
 54            .into();
 55
 56        let project_file = project::File::from_dyn(inputs.snapshot.file());
 57        let repo_name = project_file
 58            .map(|file| file.worktree.read(cx).root_name_str())
 59            .unwrap_or("untitled")
 60            .into();
 61        let offset = inputs.position.to_offset(&inputs.snapshot);
 62
 63        let recent_buffers = inputs.recent_paths.iter().cloned();
 64        let http_client = cx.http_client();
 65
 66        let recent_buffer_snapshots = recent_buffers
 67            .filter_map(|project_path| {
 68                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
 69                if inputs.buffer == buffer {
 70                    None
 71                } else {
 72                    Some(buffer.read(cx).snapshot())
 73                }
 74            })
 75            .take(3)
 76            .collect::<Vec<_>>();
 77
 78        let buffer_snapshotted_at = Instant::now();
 79
 80        let result = cx.background_spawn(async move {
 81            let text = inputs.snapshot.text();
 82
 83            let mut recent_changes = String::new();
 84            for event in &inputs.events {
 85                write_event(event.as_ref(), &mut recent_changes).unwrap();
 86            }
 87
 88            let mut file_chunks = recent_buffer_snapshots
 89                .into_iter()
 90                .map(|snapshot| {
 91                    let end_point = Point::new(30, 0).min(snapshot.max_point());
 92                    FileChunk {
 93                        content: snapshot.text_for_range(Point::zero()..end_point).collect(),
 94                        file_path: snapshot
 95                            .file()
 96                            .map(|f| f.path().as_unix_str())
 97                            .unwrap_or("untitled")
 98                            .to_string(),
 99                        start_line: 0,
100                        end_line: end_point.row as usize,
101                        timestamp: snapshot.file().and_then(|file| {
102                            Some(
103                                file.disk_state()
104                                    .mtime()?
105                                    .to_seconds_and_nanos_for_persistence()?
106                                    .0,
107                            )
108                        }),
109                    }
110                })
111                .collect::<Vec<_>>();
112
113            let retrieval_chunks = inputs
114                .related_files
115                .iter()
116                .flat_map(|related_file| {
117                    related_file.excerpts.iter().map(|excerpt| FileChunk {
118                        file_path: related_file.path.to_string_lossy().to_string(),
119                        start_line: excerpt.row_range.start as usize,
120                        end_line: excerpt.row_range.end as usize,
121                        content: excerpt.text.to_string(),
122                        timestamp: None,
123                    })
124                })
125                .collect();
126
127            let diagnostic_entries = inputs
128                .snapshot
129                .diagnostics_in_range(inputs.diagnostic_search_range, false);
130            let mut diagnostic_content = String::new();
131            let mut diagnostic_count = 0;
132
133            for entry in diagnostic_entries {
134                let start_point: Point = entry.range.start;
135
136                let severity = match entry.diagnostic.severity {
137                    DiagnosticSeverity::ERROR => "error",
138                    DiagnosticSeverity::WARNING => "warning",
139                    DiagnosticSeverity::INFORMATION => "info",
140                    DiagnosticSeverity::HINT => "hint",
141                    _ => continue,
142                };
143
144                diagnostic_count += 1;
145
146                writeln!(
147                    &mut diagnostic_content,
148                    "{} at line {}: {}",
149                    severity,
150                    start_point.row + 1,
151                    entry.diagnostic.message
152                )?;
153            }
154
155            if !diagnostic_content.is_empty() {
156                file_chunks.push(FileChunk {
157                    file_path: format!("Diagnostics for {}", full_path.display()),
158                    start_line: 0,
159                    end_line: diagnostic_count,
160                    content: diagnostic_content,
161                    timestamp: None,
162                });
163            }
164
165            let request_body = AutocompleteRequest {
166                debug_info,
167                repo_name,
168                file_path: full_path.clone(),
169                file_contents: text.clone(),
170                original_file_contents: text,
171                cursor_position: offset,
172                recent_changes: recent_changes.clone(),
173                changes_above_cursor: true,
174                multiple_suggestions: false,
175                branch: None,
176                file_chunks,
177                retrieval_chunks,
178                recent_user_actions: vec![],
179                use_bytes: true,
180                // TODO
181                privacy_mode_enabled: false,
182            };
183
184            let mut buf: Vec<u8> = Vec::new();
185            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
186            serde_json::to_writer(writer, &request_body)?;
187            let body: AsyncBody = buf.into();
188
189            let ep_inputs = zeta_prompt::ZetaPromptInput {
190                events: inputs.events,
191                related_files: inputs.related_files.clone(),
192                cursor_path: full_path.clone(),
193                cursor_excerpt: request_body.file_contents.into(),
194                // we actually don't know
195                editable_range_in_excerpt: 0..inputs.snapshot.len(),
196                cursor_offset_in_excerpt: request_body.cursor_position,
197            };
198
199            let request = http_client::Request::builder()
200                .uri(SWEEP_API_URL)
201                .header("Content-Type", "application/json")
202                .header("Authorization", format!("Bearer {}", api_token))
203                .header("Connection", "keep-alive")
204                .header("Content-Encoding", "br")
205                .method(Method::POST)
206                .body(body)?;
207
208            let mut response = http_client.send(request).await?;
209
210            let mut body: Vec<u8> = Vec::new();
211            response.body_mut().read_to_end(&mut body).await?;
212
213            let response_received_at = Instant::now();
214            if !response.status().is_success() {
215                anyhow::bail!(
216                    "Request failed with status: {:?}\nBody: {}",
217                    response.status(),
218                    String::from_utf8_lossy(&body),
219                );
220            };
221
222            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
223
224            let old_text = inputs
225                .snapshot
226                .text_for_range(response.start_index..response.end_index)
227                .collect::<String>();
228            let edits = language::text_diff(&old_text, &response.completion)
229                .into_iter()
230                .map(|(range, text)| {
231                    (
232                        inputs
233                            .snapshot
234                            .anchor_after(response.start_index + range.start)
235                            ..inputs
236                                .snapshot
237                                .anchor_before(response.start_index + range.end),
238                        text,
239                    )
240                })
241                .collect::<Vec<_>>();
242
243            anyhow::Ok((
244                response.autocomplete_id,
245                edits,
246                inputs.snapshot,
247                response_received_at,
248                ep_inputs,
249            ))
250        });
251
252        let buffer = inputs.buffer.clone();
253
254        cx.spawn(async move |cx| {
255            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
256            anyhow::Ok(Some(
257                EditPredictionResult::new(
258                    EditPredictionId(id.into()),
259                    &buffer,
260                    &old_snapshot,
261                    edits.into(),
262                    buffer_snapshotted_at,
263                    response_received_at,
264                    inputs,
265                    cx,
266                )
267                .await,
268            ))
269        })
270    }
271}
272
/// URL under which the Sweep API token is read from, written to and deleted from
/// the credentials provider.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username stored alongside the Sweep API token when it is written to the
/// credentials provider.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
275
276pub fn load_api_token(cx: &App) -> Task<Option<String>> {
277    if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
278        .ok()
279        .filter(|value| !value.is_empty())
280    {
281        return Task::ready(Some(api_token));
282    }
283    let credentials_provider = <dyn CredentialsProvider>::global(cx);
284    cx.spawn(async move |cx| {
285        let (_, credentials) = credentials_provider
286            .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
287            .await
288            .ok()??;
289        String::from_utf8(credentials).ok()
290    })
291}
292
293fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
294    let credentials_provider = <dyn CredentialsProvider>::global(cx);
295
296    cx.spawn(async move |cx| {
297        if let Some(api_token) = api_token {
298            credentials_provider
299                .write_credentials(
300                    SWEEP_CREDENTIALS_URL,
301                    SWEEP_CREDENTIALS_USERNAME,
302                    api_token.as_bytes(),
303                    cx,
304                )
305                .await
306                .context("Failed to save Sweep API token to system keychain")
307        } else {
308            credentials_provider
309                .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
310                .await
311                .context("Failed to delete Sweep API token from system keychain")
312        }
313    })
314}
315
/// Request body that is serialized to JSON and POSTed to `SWEEP_API_URL`.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description string (see `debug_info`).
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Currently always sent as `None`.
    pub branch: Option<String>,
    /// Full path of the file being edited.
    pub file_path: Arc<Path>,
    /// Text of the buffer snapshot the prediction is requested for.
    pub file_contents: String,
    /// Recent buffer changes, formatted by `write_event`.
    pub recent_changes: String,
    /// Offset of the cursor within `file_contents`.
    pub cursor_position: usize,
    /// Currently filled with the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading excerpts of recently used buffers, plus an optional diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts taken from the related files of the prediction input.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always sent as `false`.
    pub multiple_suggestions: bool,
    /// Currently always sent as `false` (marked TODO at the call site).
    pub privacy_mode_enabled: bool,
    /// Currently always sent as `true`.
    pub changes_above_cursor: bool,
    /// Currently always sent as `true`.
    // NOTE(review): presumably tells the backend that offsets are byte offsets — confirm.
    pub use_bytes: bool,
}
334
/// A piece of file content sent to the backend as additional context.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path (or descriptive label) identifying where `content` comes from.
    pub file_path: String,
    /// First line covered by the chunk.
    pub start_line: usize,
    /// Last line covered by the chunk.
    pub end_line: usize,
    /// Text of the chunk.
    pub content: String,
    /// Modification time of the file in seconds, when known.
    pub timestamp: Option<u64>,
}
343
/// A single user action entry of `AutocompleteRequest::recent_user_actions`.
///
/// No value of this type is constructed in this file; the request always sends an
/// empty list.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action performed.
    pub action_type: ActionType,
    /// Line the action applies to.
    pub line_number: usize,
    /// Offset the action applies to.
    pub offset: usize,
    /// File the action was performed in.
    pub file_path: String,
    /// Time of the action.
    // NOTE(review): unit is not visible here — confirm against the Sweep API.
    pub timestamp: u64,
}
352
/// Kind of a [`UserAction`]; serialized in `SCREAMING_SNAKE_CASE`
/// (for example `CURSOR_MOVEMENT`).
// The variants are never constructed in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
363
/// JSON response returned by the Sweep autocomplete endpoint.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of the completion; becomes the `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start offset, in the requested buffer, of the text to replace.
    pub start_index: usize,
    /// End offset, in the requested buffer, of the text to replace.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The remaining fields are deserialized but not read by this file.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Read from the `completions` key; empty when the key is absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
382
/// One entry of the `completions` array of an [`AutocompleteResponse`].
///
/// Its fields are deserialized but not read by this file, hence the `dead_code`
/// allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
394
395fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
396    match event {
397        zeta_prompt::Event::BufferChange {
398            old_path,
399            path,
400            diff,
401            ..
402        } => {
403            if old_path != path {
404                // TODO confirm how to do this for sweep
405                // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
406            }
407
408            if !diff.is_empty() {
409                write!(f, "File: {}:\n{}\n", path.display(), diff)?
410            }
411
412            fmt::Result::Ok(())
413        }
414    }
415}
416
417fn debug_info(cx: &gpui::App) -> Arc<str> {
418    format!(
419        "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
420        version = release_channel::AppVersion::global(cx),
421        sha = release_channel::AppCommitSha::try_global(cx)
422            .map_or("unknown".to_string(), |sha| sha.full()),
423        os = client::telemetry::os_name(),
424    )
425    .into()
426}