1use anyhow::{Context as _, Result};
2use cloud_llm_client::predict_edits_v3::Event;
3use futures::AsyncReadExt as _;
4use gpui::{
5 App, AppContext as _, Entity, Task,
6 http_client::{self, AsyncBody, Method},
7};
8use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
9use lsp::DiagnosticSeverity;
10use project::{Project, ProjectPath};
11use serde::{Deserialize, Serialize};
12use std::{
13 collections::VecDeque,
14 fmt::{self, Write as _},
15 ops::Range,
16 path::Path,
17 sync::Arc,
18 time::Instant,
19};
20use util::ResultExt as _;
21
22use crate::{EditPrediction, EditPredictionId, EditPredictionInputs};
23
24const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
25
/// Client state for requesting edit predictions from the Sweep autocomplete API.
pub struct SweepAi {
    /// Token sent as `Authorization: Bearer <token>`. `None` when the
    /// `SWEEP_AI_TOKEN` environment variable was not set, in which case
    /// `request_prediction_with_sweep` resolves to `Ok(None)` without a request.
    pub api_token: Option<String>,
    /// Identification string (Zed version, commit SHA, OS name) sent as the
    /// request's `debug_info` field.
    pub debug_info: Arc<str>,
}
30
31impl SweepAi {
32 pub fn new(cx: &App) -> Self {
33 SweepAi {
34 api_token: std::env::var("SWEEP_AI_TOKEN")
35 .context("No SWEEP_AI_TOKEN environment variable set")
36 .log_err(),
37 debug_info: debug_info(cx),
38 }
39 }
40
41 pub fn request_prediction_with_sweep(
42 &self,
43 project: &Entity<Project>,
44 active_buffer: &Entity<Buffer>,
45 snapshot: BufferSnapshot,
46 position: language::Anchor,
47 events: Vec<Arc<Event>>,
48 recent_paths: &VecDeque<ProjectPath>,
49 diagnostic_search_range: Range<Point>,
50 cx: &mut App,
51 ) -> Task<Result<Option<EditPrediction>>> {
52 let debug_info = self.debug_info.clone();
53 let Some(api_token) = self.api_token.clone() else {
54 return Task::ready(Ok(None));
55 };
56 let full_path: Arc<Path> = snapshot
57 .file()
58 .map(|file| file.full_path(cx))
59 .unwrap_or_else(|| "untitled".into())
60 .into();
61
62 let project_file = project::File::from_dyn(snapshot.file());
63 let repo_name = project_file
64 .map(|file| file.worktree.read(cx).root_name_str())
65 .unwrap_or("untitled")
66 .into();
67 let offset = position.to_offset(&snapshot);
68
69 let recent_buffers = recent_paths.iter().cloned();
70 let http_client = cx.http_client();
71
72 let recent_buffer_snapshots = recent_buffers
73 .filter_map(|project_path| {
74 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
75 if active_buffer == &buffer {
76 None
77 } else {
78 Some(buffer.read(cx).snapshot())
79 }
80 })
81 .take(3)
82 .collect::<Vec<_>>();
83
84 let cursor_point = position.to_point(&snapshot);
85 let buffer_snapshotted_at = Instant::now();
86
87 let result = cx.background_spawn(async move {
88 let text = snapshot.text();
89
90 let mut recent_changes = String::new();
91 for event in &events {
92 write_event(event.as_ref(), &mut recent_changes).unwrap();
93 }
94
95 let mut file_chunks = recent_buffer_snapshots
96 .into_iter()
97 .map(|snapshot| {
98 let end_point = Point::new(30, 0).min(snapshot.max_point());
99 FileChunk {
100 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
101 file_path: snapshot
102 .file()
103 .map(|f| f.path().as_unix_str())
104 .unwrap_or("untitled")
105 .to_string(),
106 start_line: 0,
107 end_line: end_point.row as usize,
108 timestamp: snapshot.file().and_then(|file| {
109 Some(
110 file.disk_state()
111 .mtime()?
112 .to_seconds_and_nanos_for_persistence()?
113 .0,
114 )
115 }),
116 }
117 })
118 .collect::<Vec<_>>();
119
120 let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
121 let mut diagnostic_content = String::new();
122 let mut diagnostic_count = 0;
123
124 for entry in diagnostic_entries {
125 let start_point: Point = entry.range.start;
126
127 let severity = match entry.diagnostic.severity {
128 DiagnosticSeverity::ERROR => "error",
129 DiagnosticSeverity::WARNING => "warning",
130 DiagnosticSeverity::INFORMATION => "info",
131 DiagnosticSeverity::HINT => "hint",
132 _ => continue,
133 };
134
135 diagnostic_count += 1;
136
137 writeln!(
138 &mut diagnostic_content,
139 "{} at line {}: {}",
140 severity,
141 start_point.row + 1,
142 entry.diagnostic.message
143 )?;
144 }
145
146 if !diagnostic_content.is_empty() {
147 file_chunks.push(FileChunk {
148 file_path: format!("Diagnostics for {}", full_path.display()),
149 start_line: 0,
150 end_line: diagnostic_count,
151 content: diagnostic_content,
152 timestamp: None,
153 });
154 }
155
156 let request_body = AutocompleteRequest {
157 debug_info,
158 repo_name,
159 file_path: full_path.clone(),
160 file_contents: text.clone(),
161 original_file_contents: text,
162 cursor_position: offset,
163 recent_changes: recent_changes.clone(),
164 changes_above_cursor: true,
165 multiple_suggestions: false,
166 branch: None,
167 file_chunks,
168 retrieval_chunks: vec![],
169 recent_user_actions: vec![],
170 // TODO
171 privacy_mode_enabled: false,
172 };
173
174 let mut buf: Vec<u8> = Vec::new();
175 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
176 serde_json::to_writer(writer, &request_body)?;
177 let body: AsyncBody = buf.into();
178
179 let inputs = EditPredictionInputs {
180 events,
181 included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
182 path: full_path.clone(),
183 max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
184 excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
185 start_line: cloud_llm_client::predict_edits_v3::Line(0),
186 text: request_body.file_contents.into(),
187 }],
188 }],
189 cursor_point: cloud_llm_client::predict_edits_v3::Point {
190 column: cursor_point.column,
191 line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
192 },
193 cursor_path: full_path.clone(),
194 };
195
196 let request = http_client::Request::builder()
197 .uri(SWEEP_API_URL)
198 .header("Content-Type", "application/json")
199 .header("Authorization", format!("Bearer {}", api_token))
200 .header("Connection", "keep-alive")
201 .header("Content-Encoding", "br")
202 .method(Method::POST)
203 .body(body)?;
204
205 let mut response = http_client.send(request).await?;
206
207 let mut body: Vec<u8> = Vec::new();
208 response.body_mut().read_to_end(&mut body).await?;
209
210 let response_received_at = Instant::now();
211 if !response.status().is_success() {
212 anyhow::bail!(
213 "Request failed with status: {:?}\nBody: {}",
214 response.status(),
215 String::from_utf8_lossy(&body),
216 );
217 };
218
219 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
220
221 let old_text = snapshot
222 .text_for_range(response.start_index..response.end_index)
223 .collect::<String>();
224 let edits = language::text_diff(&old_text, &response.completion)
225 .into_iter()
226 .map(|(range, text)| {
227 (
228 snapshot.anchor_after(response.start_index + range.start)
229 ..snapshot.anchor_before(response.start_index + range.end),
230 text,
231 )
232 })
233 .collect::<Vec<_>>();
234
235 anyhow::Ok((
236 response.autocomplete_id,
237 edits,
238 snapshot,
239 response_received_at,
240 inputs,
241 ))
242 });
243
244 let buffer = active_buffer.clone();
245
246 cx.spawn(async move |cx| {
247 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
248 anyhow::Ok(
249 EditPrediction::new(
250 EditPredictionId(id.into()),
251 &buffer,
252 &old_snapshot,
253 edits.into(),
254 buffer_snapshotted_at,
255 response_received_at,
256 inputs,
257 cx,
258 )
259 .await,
260 )
261 })
262 }
263}
264
/// JSON body POSTed (brotli-compressed) to the Sweep `next_edit_autocomplete` endpoint.
///
/// serde serializes fields in declaration order, so reordering fields changes
/// the emitted JSON key order.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client identification string (Zed version, commit SHA, OS).
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the active file, or `"untitled"`.
    pub repo_name: String,
    /// Always `None` as built in this file.
    pub branch: Option<String>,
    /// Full path of the active file, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Full text of the active buffer snapshot.
    pub file_contents: String,
    /// Edit events formatted by `write_event`, concatenated.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`, as produced by `to_offset`
    /// (presumably a byte offset — confirm with the Sweep API).
    pub cursor_position: usize,
    /// Set to the same text as `file_contents` by this file.
    pub original_file_contents: String,
    /// Excerpts of other recent buffers plus an optional diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Always sent empty by this file.
    pub retrieval_chunks: Vec<RetrievalChunk>,
    /// Always sent empty by this file.
    pub recent_user_actions: Vec<UserAction>,
    /// Sent as `false` by this file.
    pub multiple_suggestions: bool,
    /// Currently hard-coded to `false` (see the TODO at the construction site).
    pub privacy_mode_enabled: bool,
    /// Sent as `true` by this file.
    pub changes_above_cursor: bool,
}
282
/// A piece of context text sent in `AutocompleteRequest::file_chunks`.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the excerpted file. For the synthetic diagnostics chunk this is
    /// `"Diagnostics for <path>"`.
    pub file_path: String,
    /// First row of the excerpt; always 0 in this file.
    pub start_line: usize,
    /// For buffer excerpts, the row of the excerpt's end point. For the
    /// diagnostics chunk, the number of diagnostic lines in `content`.
    pub end_line: usize,
    /// The excerpt text (or formatted diagnostics).
    pub content: String,
    /// Seconds component of the file's mtime, when known; `None` otherwise.
    pub timestamp: Option<u64>,
}
291
/// A retrieved context snippet for `AutocompleteRequest::retrieval_chunks`.
///
/// This file never constructs one; the request always sends an empty list.
#[derive(Debug, Clone, Serialize)]
struct RetrievalChunk {
    pub file_path: String,
    pub start_line: usize,
    pub end_line: usize,
    pub content: String,
    // NOTE(review): unit not established here; presumably seconds like
    // `FileChunk::timestamp` — confirm with the Sweep API.
    pub timestamp: u64,
}
300
/// An editor action for `AutocompleteRequest::recent_user_actions`.
///
/// This file never constructs one; the request always sends an empty list.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    pub action_type: ActionType,
    pub line_number: usize,
    pub offset: usize,
    pub file_path: String,
    // NOTE(review): unit not established here — confirm with the Sweep API.
    pub timestamp: u64,
}
309
// No variant is constructed in this file (`recent_user_actions` is always
// empty), hence `dead_code` is allowed.
#[allow(dead_code)]
/// Kind of a [`UserAction`]; serialized as e.g. `"CURSOR_MOVEMENT"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
320
/// JSON body returned by the Sweep `next_edit_autocomplete` endpoint.
///
/// Only `autocomplete_id`, `start_index`, `end_index` and `completion` are read.
/// The remaining fields are deserialized but unused, hence their `dead_code` allowances.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Used as the `EditPredictionId` of the resulting prediction.
    pub autocomplete_id: String,
    /// Start of the span of the request snapshot that `completion` replaces.
    pub start_index: usize,
    /// End (exclusive) of the replaced span.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Read from the JSON key `completions`; defaults to empty when absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
339
// Deserialized as part of `AutocompleteResponse::additional_completions` but
// never read in this file.
#[allow(dead_code)]
/// An alternative completion entry from the response's `completions` list.
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    pub start_index: usize,
    pub end_index: usize,
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
351
352fn write_event(
353 event: &cloud_llm_client::predict_edits_v3::Event,
354 f: &mut impl fmt::Write,
355) -> fmt::Result {
356 match event {
357 cloud_llm_client::predict_edits_v3::Event::BufferChange {
358 old_path,
359 path,
360 diff,
361 ..
362 } => {
363 if old_path != path {
364 // TODO confirm how to do this for sweep
365 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
366 }
367
368 if !diff.is_empty() {
369 write!(f, "File: {}:\n{}\n", path.display(), diff)?
370 }
371
372 fmt::Result::Ok(())
373 }
374 }
375}
376
377fn debug_info(cx: &gpui::App) -> Arc<str> {
378 format!(
379 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
380 version = release_channel::AppVersion::global(cx),
381 sha = release_channel::AppCommitSha::try_global(cx)
382 .map_or("unknown".to_string(), |sha| sha.full()),
383 os = client::telemetry::os_name(),
384 )
385 .into()
386}