1use anyhow::{Context as _, Result};
2use cloud_llm_client::predict_edits_v3::Event;
3use credentials_provider::CredentialsProvider;
4use futures::{AsyncReadExt as _, FutureExt, future::Shared};
5use gpui::{
6 App, AppContext as _, Entity, Task,
7 http_client::{self, AsyncBody, Method},
8};
9use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
10use lsp::DiagnosticSeverity;
11use project::{Project, ProjectPath};
12use serde::{Deserialize, Serialize};
13use std::{
14 collections::VecDeque,
15 fmt::{self, Write as _},
16 ops::Range,
17 path::Path,
18 sync::Arc,
19 time::Instant,
20};
21
22use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
23
/// Endpoint that receives next-edit autocomplete requests (Brotli-compressed JSON POSTs).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
25
26pub struct SweepAi {
27 pub api_token: Shared<Task<Option<String>>>,
28 pub debug_info: Arc<str>,
29}
30
31impl SweepAi {
32 pub fn new(cx: &App) -> Self {
33 SweepAi {
34 api_token: load_api_token(cx).shared(),
35 debug_info: debug_info(cx),
36 }
37 }
38
39 pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
40 self.api_token = Task::ready(api_token.clone()).shared();
41 store_api_token_in_keychain(api_token, cx)
42 }
43
44 pub fn request_prediction_with_sweep(
45 &self,
46 project: &Entity<Project>,
47 active_buffer: &Entity<Buffer>,
48 snapshot: BufferSnapshot,
49 position: language::Anchor,
50 events: Vec<Arc<Event>>,
51 recent_paths: &VecDeque<ProjectPath>,
52 diagnostic_search_range: Range<Point>,
53 cx: &mut App,
54 ) -> Task<Result<Option<EditPredictionResult>>> {
55 let debug_info = self.debug_info.clone();
56 let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
57 return Task::ready(Ok(None));
58 };
59 let full_path: Arc<Path> = snapshot
60 .file()
61 .map(|file| file.full_path(cx))
62 .unwrap_or_else(|| "untitled".into())
63 .into();
64
65 let project_file = project::File::from_dyn(snapshot.file());
66 let repo_name = project_file
67 .map(|file| file.worktree.read(cx).root_name_str())
68 .unwrap_or("untitled")
69 .into();
70 let offset = position.to_offset(&snapshot);
71
72 let recent_buffers = recent_paths.iter().cloned();
73 let http_client = cx.http_client();
74
75 let recent_buffer_snapshots = recent_buffers
76 .filter_map(|project_path| {
77 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
78 if active_buffer == &buffer {
79 None
80 } else {
81 Some(buffer.read(cx).snapshot())
82 }
83 })
84 .take(3)
85 .collect::<Vec<_>>();
86
87 let cursor_point = position.to_point(&snapshot);
88 let buffer_snapshotted_at = Instant::now();
89
90 let result = cx.background_spawn(async move {
91 let text = snapshot.text();
92
93 let mut recent_changes = String::new();
94 for event in &events {
95 write_event(event.as_ref(), &mut recent_changes).unwrap();
96 }
97
98 let mut file_chunks = recent_buffer_snapshots
99 .into_iter()
100 .map(|snapshot| {
101 let end_point = Point::new(30, 0).min(snapshot.max_point());
102 FileChunk {
103 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
104 file_path: snapshot
105 .file()
106 .map(|f| f.path().as_unix_str())
107 .unwrap_or("untitled")
108 .to_string(),
109 start_line: 0,
110 end_line: end_point.row as usize,
111 timestamp: snapshot.file().and_then(|file| {
112 Some(
113 file.disk_state()
114 .mtime()?
115 .to_seconds_and_nanos_for_persistence()?
116 .0,
117 )
118 }),
119 }
120 })
121 .collect::<Vec<_>>();
122
123 let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
124 let mut diagnostic_content = String::new();
125 let mut diagnostic_count = 0;
126
127 for entry in diagnostic_entries {
128 let start_point: Point = entry.range.start;
129
130 let severity = match entry.diagnostic.severity {
131 DiagnosticSeverity::ERROR => "error",
132 DiagnosticSeverity::WARNING => "warning",
133 DiagnosticSeverity::INFORMATION => "info",
134 DiagnosticSeverity::HINT => "hint",
135 _ => continue,
136 };
137
138 diagnostic_count += 1;
139
140 writeln!(
141 &mut diagnostic_content,
142 "{} at line {}: {}",
143 severity,
144 start_point.row + 1,
145 entry.diagnostic.message
146 )?;
147 }
148
149 if !diagnostic_content.is_empty() {
150 file_chunks.push(FileChunk {
151 file_path: format!("Diagnostics for {}", full_path.display()),
152 start_line: 0,
153 end_line: diagnostic_count,
154 content: diagnostic_content,
155 timestamp: None,
156 });
157 }
158
159 let request_body = AutocompleteRequest {
160 debug_info,
161 repo_name,
162 file_path: full_path.clone(),
163 file_contents: text.clone(),
164 original_file_contents: text,
165 cursor_position: offset,
166 recent_changes: recent_changes.clone(),
167 changes_above_cursor: true,
168 multiple_suggestions: false,
169 branch: None,
170 file_chunks,
171 retrieval_chunks: vec![],
172 recent_user_actions: vec![],
173 use_bytes: true,
174 // TODO
175 privacy_mode_enabled: false,
176 };
177
178 let mut buf: Vec<u8> = Vec::new();
179 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
180 serde_json::to_writer(writer, &request_body)?;
181 let body: AsyncBody = buf.into();
182
183 let inputs = EditPredictionInputs {
184 events,
185 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
186 path: full_path.clone(),
187 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
188 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
189 start_line: cloud_llm_client::predict_edits_v3::Line(0),
190 text: request_body.file_contents.into(),
191 }],
192 }],
193 cursor_point: cloud_llm_client::predict_edits_v3::Point {
194 column: cursor_point.column,
195 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
196 },
197 cursor_path: full_path.clone(),
198 };
199
200 let request = http_client::Request::builder()
201 .uri(SWEEP_API_URL)
202 .header("Content-Type", "application/json")
203 .header("Authorization", format!("Bearer {}", api_token))
204 .header("Connection", "keep-alive")
205 .header("Content-Encoding", "br")
206 .method(Method::POST)
207 .body(body)?;
208
209 let mut response = http_client.send(request).await?;
210
211 let mut body: Vec<u8> = Vec::new();
212 response.body_mut().read_to_end(&mut body).await?;
213
214 let response_received_at = Instant::now();
215 if !response.status().is_success() {
216 anyhow::bail!(
217 "Request failed with status: {:?}\nBody: {}",
218 response.status(),
219 String::from_utf8_lossy(&body),
220 );
221 };
222
223 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
224
225 let old_text = snapshot
226 .text_for_range(response.start_index..response.end_index)
227 .collect::<String>();
228 let edits = language::text_diff(&old_text, &response.completion)
229 .into_iter()
230 .map(|(range, text)| {
231 (
232 snapshot.anchor_after(response.start_index + range.start)
233 ..snapshot.anchor_before(response.start_index + range.end),
234 text,
235 )
236 })
237 .collect::<Vec<_>>();
238
239 anyhow::Ok((
240 response.autocomplete_id,
241 edits,
242 snapshot,
243 response_received_at,
244 inputs,
245 ))
246 });
247
248 let buffer = active_buffer.clone();
249
250 cx.spawn(async move |cx| {
251 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
252 anyhow::Ok(Some(
253 EditPredictionResult::new(
254 EditPredictionId(id.into()),
255 &buffer,
256 &old_snapshot,
257 edits.into(),
258 buffer_snapshotted_at,
259 response_received_at,
260 inputs,
261 cx,
262 )
263 .await,
264 ))
265 })
266 }
267}
268
/// Keychain URL under which the Sweep API token is stored, looked up, and deleted.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username recorded alongside the API token when it is written to the keychain.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
271
272pub fn load_api_token(cx: &App) -> Task<Option<String>> {
273 if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
274 .ok()
275 .filter(|value| !value.is_empty())
276 {
277 return Task::ready(Some(api_token));
278 }
279 let credentials_provider = <dyn CredentialsProvider>::global(cx);
280 cx.spawn(async move |cx| {
281 let (_, credentials) = credentials_provider
282 .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
283 .await
284 .ok()??;
285 String::from_utf8(credentials).ok()
286 })
287}
288
289fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
290 let credentials_provider = <dyn CredentialsProvider>::global(cx);
291
292 cx.spawn(async move |cx| {
293 if let Some(api_token) = api_token {
294 credentials_provider
295 .write_credentials(
296 SWEEP_CREDENTIALS_URL,
297 SWEEP_CREDENTIALS_USERNAME,
298 api_token.as_bytes(),
299 cx,
300 )
301 .await
302 .context("Failed to save Sweep API token to system keychain")
303 } else {
304 credentials_provider
305 .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
306 .await
307 .context("Failed to delete Sweep API token from system keychain")
308 }
309 })
310}
311
/// JSON body POSTed (Brotli-compressed) to [`SWEEP_API_URL`].
///
/// Field names are serialized verbatim, in declaration order.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Zed build / OS identification string.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the active file, or `"untitled"`.
    pub repo_name: String,
    /// Always sent as `None` by this file.
    pub branch: Option<String>,
    /// Full path of the active file, or `"untitled"` when the buffer has no file.
    pub file_path: Arc<Path>,
    /// Complete text of the active buffer.
    pub file_contents: String,
    /// Textual log of recent buffer changes, as produced by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`; presumably interpreted as bytes because
    /// `use_bytes` is set — confirm against the Sweep API.
    pub cursor_position: usize,
    /// Currently sent identical to `file_contents`.
    pub original_file_contents: String,
    /// Context snippets: the heads of recently used open buffers plus a diagnostics summary.
    pub file_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub retrieval_chunks: Vec<RetrievalChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    pub multiple_suggestions: bool,
    pub privacy_mode_enabled: bool,
    pub changes_above_cursor: bool,
    pub use_bytes: bool,
}
330
/// A snippet of context text sent alongside the active file.
///
/// Built both from the first lines of recently used buffers and as a synthetic
/// "Diagnostics for <path>" chunk summarizing diagnostics in the active file.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Label for the snippet: a buffer path, `"untitled"`, or a diagnostics label.
    pub file_path: String,
    /// First line covered by `content` (always 0 as built in this file).
    pub start_line: usize,
    /// For buffer snippets, the row where the snippet ends; for the diagnostics chunk,
    /// the number of diagnostics listed.
    pub end_line: usize,
    pub content: String,
    /// Seconds part of the file's on-disk modification time, when known.
    pub timestamp: Option<u64>,
}
339
/// Retrieved context snippet in the Sweep request schema.
///
/// Only used as the element type of `AutocompleteRequest::retrieval_chunks`, which this
/// file always sends empty.
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
    pub file_path: String,
    pub start_line: usize,
    pub end_line: usize,
    pub content: String,
    pub timestamp: u64,
}
348
/// A single user action in the Sweep request schema.
///
/// Only used as the element type of `AutocompleteRequest::recent_user_actions`, which
/// this file always sends empty.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
357
/// Kind of a [`UserAction`], serialized in SCREAMING_SNAKE_CASE
/// (e.g. `CursorMovement` becomes `"CURSOR_MOVEMENT"`).
// `dead_code` is allowed because no variant is constructed yet: user actions are
// always sent as an empty list.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
368
/// Response body returned by [`SWEEP_API_URL`].
///
/// The suggested edit replaces the text at `start_index..end_index` of the snapshot that
/// was sent with `completion`. Fields marked `allow(dead_code)` are deserialized but not
/// read in this file.
// NOTE(review): `confidence` and `elapsed_time_ms` are unused yet required, so a response
// that omits either fails to parse; consider `#[serde(default)]` if the API allows that.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of this suggestion; becomes the `EditPredictionId`.
    pub autocomplete_id: String,
    pub start_index: usize,
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Extra suggestions from the JSON `completions` array; empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
387
/// An alternative suggestion from the response's `completions` array.
// `dead_code` is allowed because these are deserialized but never read in this file.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
399
400fn write_event(
401 event: &cloud_llm_client::predict_edits_v3::Event,
402 f: &mut impl fmt::Write,
403) -> fmt::Result {
404 match event {
405 cloud_llm_client::predict_edits_v3::Event::BufferChange {
406 old_path,
407 path,
408 diff,
409 ..
410 } => {
411 if old_path != path {
412 // TODO confirm how to do this for sweep
413 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
414 }
415
416 if !diff.is_empty() {
417 write!(f, "File: {}:\n{}\n", path.display(), diff)?
418 }
419
420 fmt::Result::Ok(())
421 }
422 }
423}
424
425fn debug_info(cx: &gpui::App) -> Arc<str> {
426 format!(
427 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
428 version = release_channel::AppVersion::global(cx),
429 sha = release_channel::AppCommitSha::try_global(cx)
430 .map_or("unknown".to_string(), |sha| sha.full()),
431 os = client::telemetry::os_name(),
432 )
433 .into()
434}