1use anyhow::{Context as _, Result};
2use cloud_llm_client::predict_edits_v3::Event;
3use credentials_provider::CredentialsProvider;
4use edit_prediction_context::RelatedFile;
5use futures::{AsyncReadExt as _, FutureExt, future::Shared};
6use gpui::{
7 App, AppContext as _, Entity, Task,
8 http_client::{self, AsyncBody, Method},
9};
10use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
11use lsp::DiagnosticSeverity;
12use project::{Project, ProjectPath};
13use serde::{Deserialize, Serialize};
14use std::{
15 collections::VecDeque,
16 fmt::{self, Write as _},
17 ops::Range,
18 path::Path,
19 sync::Arc,
20 time::Instant,
21};
22
23use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
24
/// Endpoint that `SweepAi::request_prediction_with_sweep` POSTs its brotli-compressed JSON
/// autocomplete requests to.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
26
/// Client state for Sweep's next-edit autocomplete service.
pub struct SweepAi {
    /// API token, resolved by a shared background task. `None` means no token was found.
    /// Predictions are skipped while this task is still pending or resolved to `None`.
    pub api_token: Shared<Task<Option<String>>>,
    /// Client/version description sent as `debug_info` with every request.
    pub debug_info: Arc<str>,
}
31
32impl SweepAi {
33 pub fn new(cx: &App) -> Self {
34 SweepAi {
35 api_token: load_api_token(cx).shared(),
36 debug_info: debug_info(cx),
37 }
38 }
39
40 pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
41 self.api_token = Task::ready(api_token.clone()).shared();
42 store_api_token_in_keychain(api_token, cx)
43 }
44
45 pub fn request_prediction_with_sweep(
46 &self,
47 project: &Entity<Project>,
48 active_buffer: &Entity<Buffer>,
49 snapshot: BufferSnapshot,
50 position: language::Anchor,
51 events: Vec<Arc<Event>>,
52 recent_paths: &VecDeque<ProjectPath>,
53 related_files: Vec<RelatedFile>,
54 diagnostic_search_range: Range<Point>,
55 cx: &mut App,
56 ) -> Task<Result<Option<EditPredictionResult>>> {
57 let debug_info = self.debug_info.clone();
58 let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
59 return Task::ready(Ok(None));
60 };
61 let full_path: Arc<Path> = snapshot
62 .file()
63 .map(|file| file.full_path(cx))
64 .unwrap_or_else(|| "untitled".into())
65 .into();
66
67 let project_file = project::File::from_dyn(snapshot.file());
68 let repo_name = project_file
69 .map(|file| file.worktree.read(cx).root_name_str())
70 .unwrap_or("untitled")
71 .into();
72 let offset = position.to_offset(&snapshot);
73
74 let recent_buffers = recent_paths.iter().cloned();
75 let http_client = cx.http_client();
76
77 let recent_buffer_snapshots = recent_buffers
78 .filter_map(|project_path| {
79 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
80 if active_buffer == &buffer {
81 None
82 } else {
83 Some(buffer.read(cx).snapshot())
84 }
85 })
86 .take(3)
87 .collect::<Vec<_>>();
88
89 let cursor_point = position.to_point(&snapshot);
90 let buffer_snapshotted_at = Instant::now();
91
92 let result = cx.background_spawn(async move {
93 let text = snapshot.text();
94
95 let mut recent_changes = String::new();
96 for event in &events {
97 write_event(event.as_ref(), &mut recent_changes).unwrap();
98 }
99
100 let mut file_chunks = recent_buffer_snapshots
101 .into_iter()
102 .map(|snapshot| {
103 let end_point = Point::new(30, 0).min(snapshot.max_point());
104 FileChunk {
105 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
106 file_path: snapshot
107 .file()
108 .map(|f| f.path().as_unix_str())
109 .unwrap_or("untitled")
110 .to_string(),
111 start_line: 0,
112 end_line: end_point.row as usize,
113 timestamp: snapshot.file().and_then(|file| {
114 Some(
115 file.disk_state()
116 .mtime()?
117 .to_seconds_and_nanos_for_persistence()?
118 .0,
119 )
120 }),
121 }
122 })
123 .collect::<Vec<_>>();
124
125 let retrieval_chunks = related_files
126 .iter()
127 .flat_map(|related_file| {
128 related_file.excerpts.iter().map(|excerpt| FileChunk {
129 file_path: related_file.path.path.as_unix_str().to_string(),
130 start_line: excerpt.point_range.start.row as usize,
131 end_line: excerpt.point_range.end.row as usize,
132 content: excerpt.text.to_string(),
133 timestamp: None,
134 })
135 })
136 .collect();
137
138 let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
139 let mut diagnostic_content = String::new();
140 let mut diagnostic_count = 0;
141
142 for entry in diagnostic_entries {
143 let start_point: Point = entry.range.start;
144
145 let severity = match entry.diagnostic.severity {
146 DiagnosticSeverity::ERROR => "error",
147 DiagnosticSeverity::WARNING => "warning",
148 DiagnosticSeverity::INFORMATION => "info",
149 DiagnosticSeverity::HINT => "hint",
150 _ => continue,
151 };
152
153 diagnostic_count += 1;
154
155 writeln!(
156 &mut diagnostic_content,
157 "{} at line {}: {}",
158 severity,
159 start_point.row + 1,
160 entry.diagnostic.message
161 )?;
162 }
163
164 if !diagnostic_content.is_empty() {
165 file_chunks.push(FileChunk {
166 file_path: format!("Diagnostics for {}", full_path.display()),
167 start_line: 0,
168 end_line: diagnostic_count,
169 content: diagnostic_content,
170 timestamp: None,
171 });
172 }
173
174 let request_body = AutocompleteRequest {
175 debug_info,
176 repo_name,
177 file_path: full_path.clone(),
178 file_contents: text.clone(),
179 original_file_contents: text,
180 cursor_position: offset,
181 recent_changes: recent_changes.clone(),
182 changes_above_cursor: true,
183 multiple_suggestions: false,
184 branch: None,
185 file_chunks,
186 retrieval_chunks,
187 recent_user_actions: vec![],
188 use_bytes: true,
189 // TODO
190 privacy_mode_enabled: false,
191 };
192
193 let mut buf: Vec<u8> = Vec::new();
194 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
195 serde_json::to_writer(writer, &request_body)?;
196 let body: AsyncBody = buf.into();
197
198 let inputs = EditPredictionInputs {
199 events,
200 included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
201 path: full_path.clone(),
202 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
203 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
204 start_line: cloud_llm_client::predict_edits_v3::Line(0),
205 text: request_body.file_contents.into(),
206 }],
207 }],
208 cursor_point: cloud_llm_client::predict_edits_v3::Point {
209 column: cursor_point.column,
210 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
211 },
212 cursor_path: full_path.clone(),
213 };
214
215 let request = http_client::Request::builder()
216 .uri(SWEEP_API_URL)
217 .header("Content-Type", "application/json")
218 .header("Authorization", format!("Bearer {}", api_token))
219 .header("Connection", "keep-alive")
220 .header("Content-Encoding", "br")
221 .method(Method::POST)
222 .body(body)?;
223
224 let mut response = http_client.send(request).await?;
225
226 let mut body: Vec<u8> = Vec::new();
227 response.body_mut().read_to_end(&mut body).await?;
228
229 let response_received_at = Instant::now();
230 if !response.status().is_success() {
231 anyhow::bail!(
232 "Request failed with status: {:?}\nBody: {}",
233 response.status(),
234 String::from_utf8_lossy(&body),
235 );
236 };
237
238 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
239
240 let old_text = snapshot
241 .text_for_range(response.start_index..response.end_index)
242 .collect::<String>();
243 let edits = language::text_diff(&old_text, &response.completion)
244 .into_iter()
245 .map(|(range, text)| {
246 (
247 snapshot.anchor_after(response.start_index + range.start)
248 ..snapshot.anchor_before(response.start_index + range.end),
249 text,
250 )
251 })
252 .collect::<Vec<_>>();
253
254 anyhow::Ok((
255 response.autocomplete_id,
256 edits,
257 snapshot,
258 response_received_at,
259 inputs,
260 ))
261 });
262
263 let buffer = active_buffer.clone();
264
265 cx.spawn(async move |cx| {
266 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
267 anyhow::Ok(Some(
268 EditPredictionResult::new(
269 EditPredictionId(id.into()),
270 &buffer,
271 &old_snapshot,
272 edits.into(),
273 buffer_snapshotted_at,
274 response_received_at,
275 inputs,
276 cx,
277 )
278 .await,
279 ))
280 })
281 }
282}
283
/// URL key under which the Sweep API token is read from, written to, and deleted from the
/// global credentials provider.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username recorded alongside the token when it is written to the credentials provider.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
286
287pub fn load_api_token(cx: &App) -> Task<Option<String>> {
288 if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
289 .ok()
290 .filter(|value| !value.is_empty())
291 {
292 return Task::ready(Some(api_token));
293 }
294 let credentials_provider = <dyn CredentialsProvider>::global(cx);
295 cx.spawn(async move |cx| {
296 let (_, credentials) = credentials_provider
297 .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
298 .await
299 .ok()??;
300 String::from_utf8(credentials).ok()
301 })
302}
303
304fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
305 let credentials_provider = <dyn CredentialsProvider>::global(cx);
306
307 cx.spawn(async move |cx| {
308 if let Some(api_token) = api_token {
309 credentials_provider
310 .write_credentials(
311 SWEEP_CREDENTIALS_URL,
312 SWEEP_CREDENTIALS_USERNAME,
313 api_token.as_bytes(),
314 cx,
315 )
316 .await
317 .context("Failed to save Sweep API token to system keychain")
318 } else {
319 credentials_provider
320 .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
321 .await
322 .context("Failed to delete Sweep API token from system keychain")
323 }
324 })
325}
326
/// JSON body POSTed (brotli-compressed) to [`SWEEP_API_URL`].
///
/// Field order and names define the serialized wire format.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client/version description produced by `debug_info`.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Always `None` as currently populated.
    pub branch: Option<String>,
    /// Full path of the active buffer's file, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Current text of the active buffer.
    pub file_contents: String,
    /// Buffer events rendered by `write_event`, concatenated.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`, as computed by `to_offset`.
    /// NOTE(review): presumably interpreted as bytes because `use_bytes` is true. Confirm with
    /// the Sweep API.
    pub cursor_position: usize,
    /// Currently the same text as `file_contents`.
    pub original_file_contents: String,
    /// Leading text of recently used open buffers, plus an optional diagnostics pseudo-chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Currently always `false`.
    pub multiple_suggestions: bool,
    /// Currently hard-coded to `false` (see the TODO at the construction site).
    pub privacy_mode_enabled: bool,
    /// Currently always `true`. NOTE(review): exact server-side meaning unknown; confirm.
    pub changes_above_cursor: bool,
    /// Currently always `true`. NOTE(review): presumably selects byte offsets; confirm.
    pub use_bytes: bool,
}
345
/// A piece of file-like context sent in `file_chunks` or `retrieval_chunks`.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the source file (unix-style), or a label such as `"Diagnostics for <path>"`.
    pub file_path: String,
    /// First row covered by `content` (0 for whole-file prefixes and diagnostics).
    pub start_line: usize,
    /// Row where `content` ends; for the diagnostics chunk, the number of diagnostic lines.
    /// NOTE(review): inclusive/exclusive semantics differ between producers; confirm
    /// what Sweep expects.
    pub end_line: usize,
    /// The chunk's text.
    pub content: String,
    /// File modification time in seconds, when the file has one on disk.
    pub timestamp: Option<u64>,
}
354
/// Entry type for `AutocompleteRequest::recent_user_actions`.
///
/// Not populated yet: requests currently send an empty list.
/// NOTE(review): field semantics below are inferred from names only; confirm with the Sweep API.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of edit or cursor action.
    pub action_type: ActionType,
    /// Presumably the line where the action happened.
    pub line_number: usize,
    /// Presumably the offset where the action happened.
    pub offset: usize,
    /// Presumably the file the action happened in.
    pub file_path: String,
    /// Presumably when the action happened; unit unconfirmed.
    pub timestamp: u64,
}
363
/// Kind of [`UserAction`], serialized in SCREAMING_SNAKE_CASE (e.g. `CURSOR_MOVEMENT`).
// `dead_code` is allowed because no variant is constructed in this file yet.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
374
/// Body of a successful response from [`SWEEP_API_URL`].
///
/// Only `autocomplete_id`, `start_index`, `end_index`, and `completion` are used. The
/// remaining fields are deserialized but ignored, hence the `dead_code` allowances.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of the suggestion; becomes the prediction's `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start of the replaced range, as an offset into the buffer text that was sent.
    pub start_index: usize,
    /// End (exclusive) of the replaced range, as an offset into the buffer text that was sent.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative suggestions, read from the `completions` key; empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
393
/// An alternative suggestion from `AutocompleteResponse::additional_completions`.
///
/// Deserialized but not used yet (hence `dead_code`). Fields mirror the primary suggestion
/// fields of `AutocompleteResponse`.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
405
406fn write_event(
407 event: &cloud_llm_client::predict_edits_v3::Event,
408 f: &mut impl fmt::Write,
409) -> fmt::Result {
410 match event {
411 cloud_llm_client::predict_edits_v3::Event::BufferChange {
412 old_path,
413 path,
414 diff,
415 ..
416 } => {
417 if old_path != path {
418 // TODO confirm how to do this for sweep
419 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
420 }
421
422 if !diff.is_empty() {
423 write!(f, "File: {}:\n{}\n", path.display(), diff)?
424 }
425
426 fmt::Result::Ok(())
427 }
428 }
429}
430
431fn debug_info(cx: &gpui::App) -> Arc<str> {
432 format!(
433 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
434 version = release_channel::AppVersion::global(cx),
435 sha = release_channel::AppCommitSha::try_global(cx)
436 .map_or("unknown".to_string(), |sha| sha.full()),
437 os = client::telemetry::os_name(),
438 )
439 .into()
440}