1use anyhow::Result;
2use futures::AsyncReadExt as _;
3use gpui::{
4 App, AppContext as _, Entity, Global, SharedString, Task,
5 http_client::{self, AsyncBody, Method},
6};
7use language::{Point, ToOffset as _};
8use language_model::{ApiKeyState, EnvVar, env_var};
9use lsp::DiagnosticSeverity;
10use serde::{Deserialize, Serialize};
11use std::{
12 fmt::{self, Write as _},
13 path::Path,
14 sync::Arc,
15 time::Instant,
16};
17
18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
19
20const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
21
22pub struct SweepAi {
23 pub api_token: Entity<ApiKeyState>,
24 pub debug_info: Arc<str>,
25}
26
27impl SweepAi {
28 pub fn new(cx: &mut App) -> Self {
29 SweepAi {
30 api_token: sweep_api_token(cx),
31 debug_info: debug_info(cx),
32 }
33 }
34
35 pub fn request_prediction_with_sweep(
36 &self,
37 inputs: EditPredictionModelInput,
38 cx: &mut App,
39 ) -> Task<Result<Option<EditPredictionResult>>> {
40 let debug_info = self.debug_info.clone();
41 self.api_token.update(cx, |key_state, cx| {
42 _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
43 });
44 let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
45 return Task::ready(Ok(None));
46 };
47 let full_path: Arc<Path> = inputs
48 .snapshot
49 .file()
50 .map(|file| file.full_path(cx))
51 .unwrap_or_else(|| "untitled".into())
52 .into();
53
54 let project_file = project::File::from_dyn(inputs.snapshot.file());
55 let repo_name = project_file
56 .map(|file| file.worktree.read(cx).root_name_str())
57 .unwrap_or("untitled")
58 .into();
59 let offset = inputs.position.to_offset(&inputs.snapshot);
60
61 let recent_buffers = inputs.recent_paths.iter().cloned();
62 let http_client = cx.http_client();
63
64 let recent_buffer_snapshots = recent_buffers
65 .filter_map(|project_path| {
66 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
67 if inputs.buffer == buffer {
68 None
69 } else {
70 Some(buffer.read(cx).snapshot())
71 }
72 })
73 .take(3)
74 .collect::<Vec<_>>();
75
76 let buffer_snapshotted_at = Instant::now();
77
78 let result = cx.background_spawn(async move {
79 let text = inputs.snapshot.text();
80
81 let mut recent_changes = String::new();
82 for event in &inputs.events {
83 write_event(event.as_ref(), &mut recent_changes).unwrap();
84 }
85
86 let mut file_chunks = recent_buffer_snapshots
87 .into_iter()
88 .map(|snapshot| {
89 let end_point = Point::new(30, 0).min(snapshot.max_point());
90 FileChunk {
91 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
92 file_path: snapshot
93 .file()
94 .map(|f| f.path().as_unix_str())
95 .unwrap_or("untitled")
96 .to_string(),
97 start_line: 0,
98 end_line: end_point.row as usize,
99 timestamp: snapshot.file().and_then(|file| {
100 Some(
101 file.disk_state()
102 .mtime()?
103 .to_seconds_and_nanos_for_persistence()?
104 .0,
105 )
106 }),
107 }
108 })
109 .collect::<Vec<_>>();
110
111 let retrieval_chunks = inputs
112 .related_files
113 .iter()
114 .flat_map(|related_file| {
115 related_file.excerpts.iter().map(|excerpt| FileChunk {
116 file_path: related_file.path.to_string_lossy().to_string(),
117 start_line: excerpt.row_range.start as usize,
118 end_line: excerpt.row_range.end as usize,
119 content: excerpt.text.to_string(),
120 timestamp: None,
121 })
122 })
123 .collect();
124
125 let diagnostic_entries = inputs
126 .snapshot
127 .diagnostics_in_range(inputs.diagnostic_search_range, false);
128 let mut diagnostic_content = String::new();
129 let mut diagnostic_count = 0;
130
131 for entry in diagnostic_entries {
132 let start_point: Point = entry.range.start;
133
134 let severity = match entry.diagnostic.severity {
135 DiagnosticSeverity::ERROR => "error",
136 DiagnosticSeverity::WARNING => "warning",
137 DiagnosticSeverity::INFORMATION => "info",
138 DiagnosticSeverity::HINT => "hint",
139 _ => continue,
140 };
141
142 diagnostic_count += 1;
143
144 writeln!(
145 &mut diagnostic_content,
146 "{} at line {}: {}",
147 severity,
148 start_point.row + 1,
149 entry.diagnostic.message
150 )?;
151 }
152
153 if !diagnostic_content.is_empty() {
154 file_chunks.push(FileChunk {
155 file_path: format!("Diagnostics for {}", full_path.display()),
156 start_line: 0,
157 end_line: diagnostic_count,
158 content: diagnostic_content,
159 timestamp: None,
160 });
161 }
162
163 let request_body = AutocompleteRequest {
164 debug_info,
165 repo_name,
166 file_path: full_path.clone(),
167 file_contents: text.clone(),
168 original_file_contents: text,
169 cursor_position: offset,
170 recent_changes: recent_changes.clone(),
171 changes_above_cursor: true,
172 multiple_suggestions: false,
173 branch: None,
174 file_chunks,
175 retrieval_chunks,
176 recent_user_actions: vec![],
177 use_bytes: true,
178 // TODO
179 privacy_mode_enabled: false,
180 };
181
182 let mut buf: Vec<u8> = Vec::new();
183 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
184 serde_json::to_writer(writer, &request_body)?;
185 let body: AsyncBody = buf.into();
186
187 let ep_inputs = zeta_prompt::ZetaPromptInput {
188 events: inputs.events,
189 related_files: inputs.related_files.clone(),
190 cursor_path: full_path.clone(),
191 cursor_excerpt: request_body.file_contents.into(),
192 // we actually don't know
193 editable_range_in_excerpt: 0..inputs.snapshot.len(),
194 cursor_offset_in_excerpt: request_body.cursor_position,
195 };
196
197 let request = http_client::Request::builder()
198 .uri(SWEEP_API_URL)
199 .header("Content-Type", "application/json")
200 .header("Authorization", format!("Bearer {}", api_token))
201 .header("Connection", "keep-alive")
202 .header("Content-Encoding", "br")
203 .method(Method::POST)
204 .body(body)?;
205
206 let mut response = http_client.send(request).await?;
207
208 let mut body: Vec<u8> = Vec::new();
209 response.body_mut().read_to_end(&mut body).await?;
210
211 let response_received_at = Instant::now();
212 if !response.status().is_success() {
213 anyhow::bail!(
214 "Request failed with status: {:?}\nBody: {}",
215 response.status(),
216 String::from_utf8_lossy(&body),
217 );
218 };
219
220 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
221
222 let old_text = inputs
223 .snapshot
224 .text_for_range(response.start_index..response.end_index)
225 .collect::<String>();
226 let edits = language::text_diff(&old_text, &response.completion)
227 .into_iter()
228 .map(|(range, text)| {
229 (
230 inputs
231 .snapshot
232 .anchor_after(response.start_index + range.start)
233 ..inputs
234 .snapshot
235 .anchor_before(response.start_index + range.end),
236 text,
237 )
238 })
239 .collect::<Vec<_>>();
240
241 anyhow::Ok((
242 response.autocomplete_id,
243 edits,
244 inputs.snapshot,
245 response_received_at,
246 ep_inputs,
247 ))
248 });
249
250 let buffer = inputs.buffer.clone();
251
252 cx.spawn(async move |cx| {
253 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
254 anyhow::Ok(Some(
255 EditPredictionResult::new(
256 EditPredictionId(id.into()),
257 &buffer,
258 &old_snapshot,
259 edits.into(),
260 buffer_snapshotted_at,
261 response_received_at,
262 inputs,
263 cx,
264 )
265 .await,
266 ))
267 })
268 }
269}
270
271pub const SWEEP_CREDENTIALS_URL: SharedString =
272 SharedString::new_static("https://autocomplete.sweep.dev");
273pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
274pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
275
276struct GlobalSweepApiKey(Entity<ApiKeyState>);
277
278impl Global for GlobalSweepApiKey {}
279
280pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
281 if let Some(global) = cx.try_global::<GlobalSweepApiKey>() {
282 return global.0.clone();
283 }
284 let entity =
285 cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()));
286 cx.set_global(GlobalSweepApiKey(entity.clone()));
287 entity
288}
289
290pub fn load_sweep_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
291 sweep_api_token(cx).update(cx, |key_state, cx| {
292 key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx)
293 })
294}
295
296#[derive(Debug, Clone, Serialize)]
297struct AutocompleteRequest {
298 pub debug_info: Arc<str>,
299 pub repo_name: String,
300 pub branch: Option<String>,
301 pub file_path: Arc<Path>,
302 pub file_contents: String,
303 pub recent_changes: String,
304 pub cursor_position: usize,
305 pub original_file_contents: String,
306 pub file_chunks: Vec<FileChunk>,
307 pub retrieval_chunks: Vec<FileChunk>,
308 pub recent_user_actions: Vec<UserAction>,
309 pub multiple_suggestions: bool,
310 pub privacy_mode_enabled: bool,
311 pub changes_above_cursor: bool,
312 pub use_bytes: bool,
313}
314
315#[derive(Debug, Clone, Serialize)]
316struct FileChunk {
317 pub file_path: String,
318 pub start_line: usize,
319 pub end_line: usize,
320 pub content: String,
321 pub timestamp: Option<u64>,
322}
323
324#[derive(Debug, Clone, Serialize)]
325struct UserAction {
326 pub action_type: ActionType,
327 pub line_number: usize,
328 pub offset: usize,
329 pub file_path: String,
330 pub timestamp: u64,
331}
332
333#[allow(dead_code)]
334#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
335#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
336enum ActionType {
337 CursorMovement,
338 InsertChar,
339 DeleteChar,
340 InsertSelection,
341 DeleteSelection,
342}
343
344#[derive(Debug, Clone, Deserialize)]
345struct AutocompleteResponse {
346 pub autocomplete_id: String,
347 pub start_index: usize,
348 pub end_index: usize,
349 pub completion: String,
350 #[allow(dead_code)]
351 pub confidence: f64,
352 #[allow(dead_code)]
353 pub logprobs: Option<serde_json::Value>,
354 #[allow(dead_code)]
355 pub finish_reason: Option<String>,
356 #[allow(dead_code)]
357 pub elapsed_time_ms: u64,
358 #[allow(dead_code)]
359 #[serde(default, rename = "completions")]
360 pub additional_completions: Vec<AdditionalCompletion>,
361}
362
363#[allow(dead_code)]
364#[derive(Debug, Clone, Deserialize)]
365struct AdditionalCompletion {
366 pub start_index: usize,
367 pub end_index: usize,
368 pub completion: String,
369 pub confidence: f64,
370 pub autocomplete_id: String,
371 pub logprobs: Option<serde_json::Value>,
372 pub finish_reason: Option<String>,
373}
374
375fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
376 match event {
377 zeta_prompt::Event::BufferChange {
378 old_path,
379 path,
380 diff,
381 ..
382 } => {
383 if old_path != path {
384 // TODO confirm how to do this for sweep
385 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
386 }
387
388 if !diff.is_empty() {
389 write!(f, "File: {}:\n{}\n", path.display(), diff)?
390 }
391
392 fmt::Result::Ok(())
393 }
394 }
395}
396
397fn debug_info(cx: &gpui::App) -> Arc<str> {
398 format!(
399 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
400 version = release_channel::AppVersion::global(cx),
401 sha = release_channel::AppCommitSha::try_global(cx)
402 .map_or("unknown".to_string(), |sha| sha.full()),
403 os = client::telemetry::os_name(),
404 )
405 .into()
406}