1use anyhow::Result;
2use futures::AsyncReadExt as _;
3use gpui::{
4 App, AppContext as _, Entity, SharedString, Task,
5 http_client::{self, AsyncBody, Method},
6};
7use language::{Point, ToOffset as _};
8use language_model::{ApiKeyState, EnvVar, env_var};
9use lsp::DiagnosticSeverity;
10use serde::{Deserialize, Serialize};
11use std::{
12 fmt::{self, Write as _},
13 path::Path,
14 sync::Arc,
15 time::Instant,
16};
17
18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
19
/// Endpoint that receives the brotli-compressed JSON next-edit autocomplete requests.
const SWEEP_API_URL: &str = concat!(
    "https://autocomplete.sweep.dev",
    "/backend/next_edit_autocomplete"
);
21
/// Edit-prediction model client that talks to the Sweep autocomplete service.
pub struct SweepAi {
    /// Shared API-key state for Sweep, obtained from `sweep_api_token`.
    pub api_token: Entity<ApiKeyState>,
    /// Client description string sent as `debug_info` with every request.
    pub debug_info: Arc<str>,
}
26
27impl SweepAi {
28 pub fn new(cx: &mut App) -> Self {
29 SweepAi {
30 api_token: sweep_api_token(cx),
31 debug_info: debug_info(cx),
32 }
33 }
34
35 pub fn request_prediction_with_sweep(
36 &self,
37 inputs: EditPredictionModelInput,
38 cx: &mut App,
39 ) -> Task<Result<Option<EditPredictionResult>>> {
40 let debug_info = self.debug_info.clone();
41 self.api_token.update(cx, |key_state, cx| {
42 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
43 });
44 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
45 return Task::ready(Ok(None));
46 };
47 let full_path: Arc<Path> = inputs
48 .snapshot
49 .file()
50 .map(|file| file.full_path(cx))
51 .unwrap_or_else(|| "untitled".into())
52 .into();
53
54 let project_file = project::File::from_dyn(inputs.snapshot.file());
55 let repo_name = project_file
56 .map(|file| file.worktree.read(cx).root_name_str())
57 .unwrap_or("untitled")
58 .into();
59 let offset = inputs.position.to_offset(&inputs.snapshot);
60
61 let recent_buffers = inputs.recent_paths.iter().cloned();
62 let http_client = cx.http_client();
63
64 let recent_buffer_snapshots = recent_buffers
65 .filter_map(|project_path| {
66 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
67 if inputs.buffer == buffer {
68 None
69 } else {
70 Some(buffer.read(cx).snapshot())
71 }
72 })
73 .take(3)
74 .collect::<Vec<_>>();
75
76 let buffer_snapshotted_at = Instant::now();
77
78 let result = cx.background_spawn(async move {
79 let text = inputs.snapshot.text();
80
81 let mut recent_changes = String::new();
82 for event in &inputs.events {
83 write_event(event.as_ref(), &mut recent_changes).unwrap();
84 }
85
86 let mut file_chunks = recent_buffer_snapshots
87 .into_iter()
88 .map(|snapshot| {
89 let end_point = Point::new(30, 0).min(snapshot.max_point());
90 FileChunk {
91 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
92 file_path: snapshot
93 .file()
94 .map(|f| f.path().as_unix_str())
95 .unwrap_or("untitled")
96 .to_string(),
97 start_line: 0,
98 end_line: end_point.row as usize,
99 timestamp: snapshot.file().and_then(|file| {
100 Some(
101 file.disk_state()
102 .mtime()?
103 .to_seconds_and_nanos_for_persistence()?
104 .0,
105 )
106 }),
107 }
108 })
109 .collect::<Vec<_>>();
110
111 let retrieval_chunks = inputs
112 .related_files
113 .iter()
114 .flat_map(|related_file| {
115 related_file.excerpts.iter().map(|excerpt| FileChunk {
116 file_path: related_file.path.to_string_lossy().to_string(),
117 start_line: excerpt.row_range.start as usize,
118 end_line: excerpt.row_range.end as usize,
119 content: excerpt.text.to_string(),
120 timestamp: None,
121 })
122 })
123 .collect();
124
125 let diagnostic_entries = inputs
126 .snapshot
127 .diagnostics_in_range(inputs.diagnostic_search_range, false);
128 let mut diagnostic_content = String::new();
129 let mut diagnostic_count = 0;
130
131 for entry in diagnostic_entries {
132 let start_point: Point = entry.range.start;
133
134 let severity = match entry.diagnostic.severity {
135 DiagnosticSeverity::ERROR => "error",
136 DiagnosticSeverity::WARNING => "warning",
137 DiagnosticSeverity::INFORMATION => "info",
138 DiagnosticSeverity::HINT => "hint",
139 _ => continue,
140 };
141
142 diagnostic_count += 1;
143
144 writeln!(
145 &mut diagnostic_content,
146 "{} at line {}: {}",
147 severity,
148 start_point.row + 1,
149 entry.diagnostic.message
150 )?;
151 }
152
153 if !diagnostic_content.is_empty() {
154 file_chunks.push(FileChunk {
155 file_path: format!("Diagnostics for {}", full_path.display()),
156 start_line: 0,
157 end_line: diagnostic_count,
158 content: diagnostic_content,
159 timestamp: None,
160 });
161 }
162
163 let request_body = AutocompleteRequest {
164 debug_info,
165 repo_name,
166 file_path: full_path.clone(),
167 file_contents: text.clone(),
168 original_file_contents: text,
169 cursor_position: offset,
170 recent_changes: recent_changes.clone(),
171 changes_above_cursor: true,
172 multiple_suggestions: false,
173 branch: None,
174 file_chunks,
175 retrieval_chunks,
176 recent_user_actions: vec![],
177 use_bytes: true,
178 // TODO
179 privacy_mode_enabled: false,
180 };
181
182 let mut buf: Vec<u8> = Vec::new();
183 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
184 serde_json::to_writer(writer, &request_body)?;
185 let body: AsyncBody = buf.into();
186
187 let ep_inputs = zeta_prompt::ZetaPromptInput {
188 events: inputs.events,
189 related_files: inputs.related_files.clone(),
190 cursor_path: full_path.clone(),
191 cursor_excerpt: request_body.file_contents.into(),
192 // we actually don't know
193 editable_range_in_excerpt: 0..inputs.snapshot.len(),
194 cursor_offset_in_excerpt: request_body.cursor_position,
195 };
196
197 let request = http_client::Request::builder()
198 .uri(SWEEP_API_URL)
199 .header("Content-Type", "application/json")
200 .header("Authorization", format!("Bearer {}", api_token))
201 .header("Connection", "keep-alive")
202 .header("Content-Encoding", "br")
203 .method(Method::POST)
204 .body(body)?;
205
206 let mut response = http_client.send(request).await?;
207
208 let mut body: Vec<u8> = Vec::new();
209 response.body_mut().read_to_end(&mut body).await?;
210
211 let response_received_at = Instant::now();
212 if !response.status().is_success() {
213 anyhow::bail!(
214 "Request failed with status: {:?}\nBody: {}",
215 response.status(),
216 String::from_utf8_lossy(&body),
217 );
218 };
219
220 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
221
222 let old_text = inputs
223 .snapshot
224 .text_for_range(response.start_index..response.end_index)
225 .collect::<String>();
226 let edits = language::text_diff(&old_text, &response.completion)
227 .into_iter()
228 .map(|(range, text)| {
229 (
230 inputs
231 .snapshot
232 .anchor_after(response.start_index + range.start)
233 ..inputs
234 .snapshot
235 .anchor_before(response.start_index + range.end),
236 text,
237 )
238 })
239 .collect::<Vec<_>>();
240
241 anyhow::Ok((
242 response.autocomplete_id,
243 edits,
244 inputs.snapshot,
245 response_received_at,
246 ep_inputs,
247 ))
248 });
249
250 let buffer = inputs.buffer.clone();
251
252 cx.spawn(async move |cx| {
253 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
254 anyhow::Ok(Some(
255 EditPredictionResult::new(
256 EditPredictionId(id.into()),
257 &buffer,
258 &old_snapshot,
259 edits.into(),
260 buffer_snapshotted_at,
261 response_received_at,
262 inputs,
263 cx,
264 )
265 .await,
266 ))
267 })
268 }
269}
270
/// URL identifying the Sweep credentials.
///
/// Passed to `ApiKeyState::new`, `load_if_needed`, and `key` when the Sweep
/// API key is set up and read.
pub const SWEEP_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://autocomplete.sweep.dev");
/// Username associated with the stored Sweep API key.
// NOTE(review): not referenced in this file; presumably used where the key is
// written to the credential store. Confirm at the call sites.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
/// The `SWEEP_AI_TOKEN` environment variable.
///
/// Handed to `ApiKeyState::new` together with [`SWEEP_CREDENTIALS_URL`];
/// presumably an alternative source for the key.
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
/// Process-wide cache for the API-key entity created by `sweep_api_token`.
pub static SWEEP_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
276
277pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
278 SWEEP_API_KEY
279 .get_or_init(|| {
280 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()))
281 })
282 .clone()
283}
284
/// JSON body of a Sweep `next_edit_autocomplete` request.
///
/// Fields are serialized in declaration order; keep that in mind before
/// reordering them.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Client description string (built by `debug_info`).
    pub debug_info: Arc<str>,
    /// Name of the worktree root containing the edited file, or `"untitled"`.
    pub repo_name: String,
    /// Git branch; `request_prediction_with_sweep` always sends `None`.
    pub branch: Option<String>,
    /// Full path of the edited file, or `"untitled"`.
    pub file_path: Arc<Path>,
    /// Full text of the edited buffer.
    pub file_contents: String,
    /// Recent edit events, formatted by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`.
    // NOTE(review): a byte offset when `use_bytes` is true, as sent by this file.
    pub cursor_position: usize,
    /// Buffer text before the edit; currently sent identical to `file_contents`.
    pub original_file_contents: String,
    /// Context chunks: prefixes of recent buffers plus a diagnostics chunk.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Fine-grained user actions; currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Whether the server may return alternative completions.
    pub multiple_suggestions: bool,
    /// Privacy flag; currently always `false` (see the TODO at the call site).
    pub privacy_mode_enabled: bool,
    // NOTE(review): meaning is defined by the Sweep API; this file always sends `true`.
    pub changes_above_cursor: bool,
    // NOTE(review): presumably tells the server that offsets are byte offsets.
    // This file always sends `true`.
    pub use_bytes: bool,
}
303
/// A snippet of context text sent alongside the request.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// Path of the file the snippet came from, or a label such as
    /// `"Diagnostics for <path>"`.
    pub file_path: String,
    /// Starting row of `content` (0 for file prefixes and the diagnostics chunk).
    pub start_line: usize,
    /// Ending row of `content`.
    // NOTE(review): for the diagnostics chunk this is the number of diagnostic
    // lines, not a file row. Inclusivity is not established here; confirm
    // against the Sweep API.
    pub end_line: usize,
    /// The snippet text.
    pub content: String,
    /// Modification time of the source file (the seconds component), when known.
    pub timestamp: Option<u64>,
}
312
/// A single user action reported to Sweep.
///
/// Not constructed in this file: `recent_user_actions` is currently sent empty.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action.
    pub action_type: ActionType,
    // NOTE(review): presumably the row and offset where the action happened; confirm.
    pub line_number: usize,
    pub offset: usize,
    /// File in which the action happened.
    pub file_path: String,
    // NOTE(review): units not established here (seconds vs. milliseconds); confirm.
    pub timestamp: u64,
}
321
// The variants are never constructed yet (no `UserAction`s are sent), so
// dead-code warnings are silenced while the request shape is kept.
#[allow(dead_code)]
/// Kind of a `UserAction`.
///
/// Serialized in SCREAMING_SNAKE_CASE, e.g. `CURSOR_MOVEMENT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
332
/// JSON body returned by the Sweep `next_edit_autocomplete` endpoint.
///
/// Only the id, the replaced range, and the completion text are used; the
/// remaining fields are deserialized but unused.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier used as the resulting `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start of the range of the sent file text that `completion` replaces.
    pub start_index: usize,
    /// End (exclusive) of the replaced range.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative completions, read from the `completions` key; empty when
    /// the key is absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
351
// Deserialized so the response shape is accepted, but not used yet.
#[allow(dead_code)]
/// An alternative completion listed under a response's `completions` key.
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    /// Start of the replaced range in the sent file text.
    pub start_index: usize,
    /// End of the replaced range.
    pub end_index: usize,
    /// Replacement text.
    pub completion: String,
    pub confidence: f64,
    pub autocomplete_id: String,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}
363
364fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
365 match event {
366 zeta_prompt::Event::BufferChange {
367 old_path,
368 path,
369 diff,
370 ..
371 } => {
372 if old_path != path {
373 // TODO confirm how to do this for sweep
374 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
375 }
376
377 if !diff.is_empty() {
378 write!(f, "File: {}:\n{}\n", path.display(), diff)?
379 }
380
381 fmt::Result::Ok(())
382 }
383 }
384}
385
386fn debug_info(cx: &gpui::App) -> Arc<str> {
387 format!(
388 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
389 version = release_channel::AppVersion::global(cx),
390 sha = release_channel::AppCommitSha::try_global(cx)
391 .map_or("unknown".to_string(), |sha| sha.full()),
392 os = client::telemetry::os_name(),
393 )
394 .into()
395}