1use anyhow::Result;
2use cloud_llm_client::predict_edits_v3::Event;
3use futures::AsyncReadExt as _;
4use gpui::{
5 App, AppContext as _, Entity, Task,
6 http_client::{self, AsyncBody, Method},
7};
8use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
9use lsp::DiagnosticSeverity;
10use project::{Project, ProjectPath};
11use serde::{Deserialize, Serialize};
12use std::{
13 collections::VecDeque,
14 fmt::{self, Write as _},
15 ops::Range,
16 path::Path,
17 sync::Arc,
18 time::Instant,
19};
20
21use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
22
/// Endpoint that receives next-edit autocomplete requests (a brotli-compressed JSON `POST`).
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
24
/// Client for the Sweep next-edit autocomplete service.
pub struct SweepAi {
    /// Bearer token for the Sweep API. When `None`, no requests are made and
    /// predictions resolve to `Ok(None)`.
    pub api_token: Option<String>,
    /// Client/version description sent as `debug_info` with every request.
    pub debug_info: Arc<str>,
}
29
30impl SweepAi {
31 pub fn new(cx: &App) -> Self {
32 SweepAi {
33 api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
34 debug_info: debug_info(cx),
35 }
36 }
37
38 pub fn request_prediction_with_sweep(
39 &self,
40 project: &Entity<Project>,
41 active_buffer: &Entity<Buffer>,
42 snapshot: BufferSnapshot,
43 position: language::Anchor,
44 events: Vec<Arc<Event>>,
45 recent_paths: &VecDeque<ProjectPath>,
46 diagnostic_search_range: Range<Point>,
47 cx: &mut App,
48 ) -> Task<Result<Option<EditPredictionResult>>> {
49 let debug_info = self.debug_info.clone();
50 let Some(api_token) = self.api_token.clone() else {
51 return Task::ready(Ok(None));
52 };
53 let full_path: Arc<Path> = snapshot
54 .file()
55 .map(|file| file.full_path(cx))
56 .unwrap_or_else(|| "untitled".into())
57 .into();
58
59 let project_file = project::File::from_dyn(snapshot.file());
60 let repo_name = project_file
61 .map(|file| file.worktree.read(cx).root_name_str())
62 .unwrap_or("untitled")
63 .into();
64 let offset = position.to_offset(&snapshot);
65
66 let recent_buffers = recent_paths.iter().cloned();
67 let http_client = cx.http_client();
68
69 let recent_buffer_snapshots = recent_buffers
70 .filter_map(|project_path| {
71 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
72 if active_buffer == &buffer {
73 None
74 } else {
75 Some(buffer.read(cx).snapshot())
76 }
77 })
78 .take(3)
79 .collect::<Vec<_>>();
80
81 let cursor_point = position.to_point(&snapshot);
82 let buffer_snapshotted_at = Instant::now();
83
84 let result = cx.background_spawn(async move {
85 let text = snapshot.text();
86
87 let mut recent_changes = String::new();
88 for event in &events {
89 write_event(event.as_ref(), &mut recent_changes).unwrap();
90 }
91
92 let mut file_chunks = recent_buffer_snapshots
93 .into_iter()
94 .map(|snapshot| {
95 let end_point = Point::new(30, 0).min(snapshot.max_point());
96 FileChunk {
97 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
98 file_path: snapshot
99 .file()
100 .map(|f| f.path().as_unix_str())
101 .unwrap_or("untitled")
102 .to_string(),
103 start_line: 0,
104 end_line: end_point.row as usize,
105 timestamp: snapshot.file().and_then(|file| {
106 Some(
107 file.disk_state()
108 .mtime()?
109 .to_seconds_and_nanos_for_persistence()?
110 .0,
111 )
112 }),
113 }
114 })
115 .collect::<Vec<_>>();
116
117 let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
118 let mut diagnostic_content = String::new();
119 let mut diagnostic_count = 0;
120
121 for entry in diagnostic_entries {
122 let start_point: Point = entry.range.start;
123
124 let severity = match entry.diagnostic.severity {
125 DiagnosticSeverity::ERROR => "error",
126 DiagnosticSeverity::WARNING => "warning",
127 DiagnosticSeverity::INFORMATION => "info",
128 DiagnosticSeverity::HINT => "hint",
129 _ => continue,
130 };
131
132 diagnostic_count += 1;
133
134 writeln!(
135 &mut diagnostic_content,
136 "{} at line {}: {}",
137 severity,
138 start_point.row + 1,
139 entry.diagnostic.message
140 )?;
141 }
142
143 if !diagnostic_content.is_empty() {
144 file_chunks.push(FileChunk {
145 file_path: format!("Diagnostics for {}", full_path.display()),
146 start_line: 0,
147 end_line: diagnostic_count,
148 content: diagnostic_content,
149 timestamp: None,
150 });
151 }
152
153 let request_body = AutocompleteRequest {
154 debug_info,
155 repo_name,
156 file_path: full_path.clone(),
157 file_contents: text.clone(),
158 original_file_contents: text,
159 cursor_position: offset,
160 recent_changes: recent_changes.clone(),
161 changes_above_cursor: true,
162 multiple_suggestions: false,
163 branch: None,
164 file_chunks,
165 retrieval_chunks: vec![],
166 recent_user_actions: vec![],
167 use_bytes: true,
168 // TODO
169 privacy_mode_enabled: false,
170 };
171
172 let mut buf: Vec<u8> = Vec::new();
173 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
174 serde_json::to_writer(writer, &request_body)?;
175 let body: AsyncBody = buf.into();
176
177 let inputs = EditPredictionInputs {
178 events,
179 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
180 path: full_path.clone(),
181 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
182 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
183 start_line: cloud_llm_client::predict_edits_v3::Line(0),
184 text: request_body.file_contents.into(),
185 }],
186 }],
187 cursor_point: cloud_llm_client::predict_edits_v3::Point {
188 column: cursor_point.column,
189 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
190 },
191 cursor_path: full_path.clone(),
192 };
193
194 let request = http_client::Request::builder()
195 .uri(SWEEP_API_URL)
196 .header("Content-Type", "application/json")
197 .header("Authorization", format!("Bearer {}", api_token))
198 .header("Connection", "keep-alive")
199 .header("Content-Encoding", "br")
200 .method(Method::POST)
201 .body(body)?;
202
203 let mut response = http_client.send(request).await?;
204
205 let mut body: Vec<u8> = Vec::new();
206 response.body_mut().read_to_end(&mut body).await?;
207
208 let response_received_at = Instant::now();
209 if !response.status().is_success() {
210 anyhow::bail!(
211 "Request failed with status: {:?}\nBody: {}",
212 response.status(),
213 String::from_utf8_lossy(&body),
214 );
215 };
216
217 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
218
219 let old_text = snapshot
220 .text_for_range(response.start_index..response.end_index)
221 .collect::<String>();
222 let edits = language::text_diff(&old_text, &response.completion)
223 .into_iter()
224 .map(|(range, text)| {
225 (
226 snapshot.anchor_after(response.start_index + range.start)
227 ..snapshot.anchor_before(response.start_index + range.end),
228 text,
229 )
230 })
231 .collect::<Vec<_>>();
232
233 anyhow::Ok((
234 response.autocomplete_id,
235 edits,
236 snapshot,
237 response_received_at,
238 inputs,
239 ))
240 });
241
242 let buffer = active_buffer.clone();
243
244 cx.spawn(async move |cx| {
245 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
246 anyhow::Ok(Some(
247 EditPredictionResult::new(
248 EditPredictionId(id.into()),
249 &buffer,
250 &old_snapshot,
251 edits.into(),
252 buffer_snapshotted_at,
253 response_received_at,
254 inputs,
255 cx,
256 )
257 .await,
258 ))
259 })
260 }
261}
262
/// JSON body of a request to the Sweep autocomplete endpoint.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client/version description of the editor making the request.
    pub debug_info: Arc<str>,
    /// Name of the worktree root containing the file (`"untitled"` when there is none).
    pub repo_name: String,
    /// Git branch; this file currently always sends `None`.
    pub branch: Option<String>,
    /// Full path of the file being edited.
    pub file_path: Arc<Path>,
    /// Full text of the file being edited.
    pub file_contents: String,
    /// Recent edit events rendered as plain text.
    pub recent_changes: String,
    /// Cursor offset into `file_contents` (presumably in bytes when `use_bytes` is set —
    /// confirm against the Sweep API).
    pub cursor_position: usize,
    /// Original text of the file; this file currently sends the same text as
    /// `file_contents`.
    pub original_file_contents: String,
    /// Extra context excerpts (other recent files, diagnostics).
    pub file_chunks: Vec<FileChunk>,
    /// Retrieved context excerpts; currently sent empty.
    pub retrieval_chunks: Vec<RetrievalChunk>,
    /// Recent cursor/typing actions; currently sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Whether multiple alternative suggestions are requested.
    pub multiple_suggestions: bool,
    /// Privacy-mode flag; currently hard-coded to `false`.
    pub privacy_mode_enabled: bool,
    // NOTE(review): exact server-side meaning not visible here — confirm with Sweep.
    pub changes_above_cursor: bool,
    /// Presumably tells the server that offsets are byte offsets — confirm with Sweep.
    pub use_bytes: bool,
}
281
/// An excerpt of context text sent alongside the active file.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the source file, or a descriptive label for synthetic chunks.
    pub file_path: String,
    /// First line covered by `content`.
    pub start_line: usize,
    /// Last line covered by `content` (for synthetic chunks this may not be a real row).
    pub end_line: usize,
    /// The excerpt text.
    pub content: String,
    /// File modification time (seconds component), when known.
    pub timestamp: Option<u64>,
}
290
/// A retrieved context excerpt in the request's `retrieval_chunks` list.
///
/// NOTE(review): never constructed in this file; the request sends an empty list.
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
    pub file_path: String,
    pub start_line: usize,
    pub end_line: usize,
    pub content: String,
    pub timestamp: u64,
}
299
/// A recent editor action in the request's `recent_user_actions` list.
///
/// NOTE(review): never constructed in this file; the request sends an empty list.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    pub timestamp: u64,
}
308
/// Kind of a [`UserAction`], serialized as `SCREAMING_SNAKE_CASE` (e.g. `INSERT_CHAR`).
// `dead_code` is allowed because no variant is constructed yet: user actions are not sent.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
319
/// JSON body returned by the Sweep autocomplete endpoint.
///
/// Fields marked `#[allow(dead_code)]` are deserialized to match the response shape but
/// are not read anywhere in this file.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of this suggestion.
    pub autocomplete_id: String,
    /// Start offset of the region in the sent file text that `completion` replaces.
    pub start_index: usize,
    /// End offset (exclusive) of the replaced region.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative suggestions, read from the `completions` key (empty when absent).
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
338
/// An alternative suggestion returned in an [`AutocompleteResponse`].
// `dead_code` is allowed because these fields are deserialized but not read in this file.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
350
351fn write_event(
352 event: &cloud_llm_client::predict_edits_v3::Event,
353 f: &mut impl fmt::Write,
354) -> fmt::Result {
355 match event {
356 cloud_llm_client::predict_edits_v3::Event::BufferChange {
357 old_path,
358 path,
359 diff,
360 ..
361 } => {
362 if old_path != path {
363 // TODO confirm how to do this for sweep
364 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
365 }
366
367 if !diff.is_empty() {
368 write!(f, "File: {}:\n{}\n", path.display(), diff)?
369 }
370
371 fmt::Result::Ok(())
372 }
373 }
374}
375
376fn debug_info(cx: &gpui::App) -> Arc<str> {
377 format!(
378 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
379 version = release_channel::AppVersion::global(cx),
380 sha = release_channel::AppCommitSha::try_global(cx)
381 .map_or("unknown".to_string(), |sha| sha.full()),
382 os = client::telemetry::os_name(),
383 )
384 .into()
385}