1use anyhow::{Context as _, Result};
2use credentials_provider::CredentialsProvider;
3use futures::{AsyncReadExt as _, FutureExt, future::Shared};
4use gpui::{
5 App, AppContext as _, Task,
6 http_client::{self, AsyncBody, Method},
7};
8use language::{Point, ToOffset as _};
9use lsp::DiagnosticSeverity;
10use serde::{Deserialize, Serialize};
11use std::{
12 fmt::{self, Write as _},
13 path::Path,
14 sync::Arc,
15 time::Instant,
16};
17
18use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
19
/// Endpoint that receives Brotli-compressed JSON [`AutocompleteRequest`]s via `POST` and
/// answers with an [`AutocompleteResponse`].
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
21
/// Client for Sweep's next-edit autocomplete service.
pub struct SweepAi {
    /// Resolves to the API token, or `None` when no token could be found.
    ///
    /// Wrapped in `Shared` so it can be cloned and polled (e.g. with `now_or_never`) without
    /// consuming the pending load.
    pub api_token: Shared<Task<Option<String>>>,
    /// Zed version / commit / OS description sent as `debug_info` with every request.
    pub debug_info: Arc<str>,
}
26
27impl SweepAi {
28 pub fn new(cx: &App) -> Self {
29 SweepAi {
30 api_token: load_api_token(cx).shared(),
31 debug_info: debug_info(cx),
32 }
33 }
34
35 pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
36 self.api_token = Task::ready(api_token.clone()).shared();
37 store_api_token_in_keychain(api_token, cx)
38 }
39
40 pub fn request_prediction_with_sweep(
41 &self,
42 inputs: EditPredictionModelInput,
43 cx: &mut App,
44 ) -> Task<Result<Option<EditPredictionResult>>> {
45 let debug_info = self.debug_info.clone();
46 let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
47 return Task::ready(Ok(None));
48 };
49 let full_path: Arc<Path> = inputs
50 .snapshot
51 .file()
52 .map(|file| file.full_path(cx))
53 .unwrap_or_else(|| "untitled".into())
54 .into();
55
56 let project_file = project::File::from_dyn(inputs.snapshot.file());
57 let repo_name = project_file
58 .map(|file| file.worktree.read(cx).root_name_str())
59 .unwrap_or("untitled")
60 .into();
61 let offset = inputs.position.to_offset(&inputs.snapshot);
62
63 let recent_buffers = inputs.recent_paths.iter().cloned();
64 let http_client = cx.http_client();
65
66 let recent_buffer_snapshots = recent_buffers
67 .filter_map(|project_path| {
68 let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
69 if inputs.buffer == buffer {
70 None
71 } else {
72 Some(buffer.read(cx).snapshot())
73 }
74 })
75 .take(3)
76 .collect::<Vec<_>>();
77
78 let buffer_snapshotted_at = Instant::now();
79
80 let result = cx.background_spawn(async move {
81 let text = inputs.snapshot.text();
82
83 let mut recent_changes = String::new();
84 for event in &inputs.events {
85 write_event(event.as_ref(), &mut recent_changes).unwrap();
86 }
87
88 let mut file_chunks = recent_buffer_snapshots
89 .into_iter()
90 .map(|snapshot| {
91 let end_point = Point::new(30, 0).min(snapshot.max_point());
92 FileChunk {
93 content: snapshot.text_for_range(Point::zero()..end_point).collect(),
94 file_path: snapshot
95 .file()
96 .map(|f| f.path().as_unix_str())
97 .unwrap_or("untitled")
98 .to_string(),
99 start_line: 0,
100 end_line: end_point.row as usize,
101 timestamp: snapshot.file().and_then(|file| {
102 Some(
103 file.disk_state()
104 .mtime()?
105 .to_seconds_and_nanos_for_persistence()?
106 .0,
107 )
108 }),
109 }
110 })
111 .collect::<Vec<_>>();
112
113 let retrieval_chunks = inputs
114 .related_files
115 .iter()
116 .flat_map(|related_file| {
117 related_file.excerpts.iter().map(|excerpt| FileChunk {
118 file_path: related_file.path.to_string_lossy().to_string(),
119 start_line: excerpt.row_range.start as usize,
120 end_line: excerpt.row_range.end as usize,
121 content: excerpt.text.to_string(),
122 timestamp: None,
123 })
124 })
125 .collect();
126
127 let diagnostic_entries = inputs
128 .snapshot
129 .diagnostics_in_range(inputs.diagnostic_search_range, false);
130 let mut diagnostic_content = String::new();
131 let mut diagnostic_count = 0;
132
133 for entry in diagnostic_entries {
134 let start_point: Point = entry.range.start;
135
136 let severity = match entry.diagnostic.severity {
137 DiagnosticSeverity::ERROR => "error",
138 DiagnosticSeverity::WARNING => "warning",
139 DiagnosticSeverity::INFORMATION => "info",
140 DiagnosticSeverity::HINT => "hint",
141 _ => continue,
142 };
143
144 diagnostic_count += 1;
145
146 writeln!(
147 &mut diagnostic_content,
148 "{} at line {}: {}",
149 severity,
150 start_point.row + 1,
151 entry.diagnostic.message
152 )?;
153 }
154
155 if !diagnostic_content.is_empty() {
156 file_chunks.push(FileChunk {
157 file_path: format!("Diagnostics for {}", full_path.display()),
158 start_line: 0,
159 end_line: diagnostic_count,
160 content: diagnostic_content,
161 timestamp: None,
162 });
163 }
164
165 let request_body = AutocompleteRequest {
166 debug_info,
167 repo_name,
168 file_path: full_path.clone(),
169 file_contents: text.clone(),
170 original_file_contents: text,
171 cursor_position: offset,
172 recent_changes: recent_changes.clone(),
173 changes_above_cursor: true,
174 multiple_suggestions: false,
175 branch: None,
176 file_chunks,
177 retrieval_chunks,
178 recent_user_actions: vec![],
179 use_bytes: true,
180 // TODO
181 privacy_mode_enabled: false,
182 };
183
184 let mut buf: Vec<u8> = Vec::new();
185 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
186 serde_json::to_writer(writer, &request_body)?;
187 let body: AsyncBody = buf.into();
188
189 let ep_inputs = zeta_prompt::ZetaPromptInput {
190 events: inputs.events,
191 related_files: inputs.related_files.clone(),
192 cursor_path: full_path.clone(),
193 cursor_excerpt: request_body.file_contents.into(),
194 // we actually don't know
195 editable_range_in_excerpt: 0..inputs.snapshot.len(),
196 cursor_offset_in_excerpt: request_body.cursor_position,
197 };
198
199 let request = http_client::Request::builder()
200 .uri(SWEEP_API_URL)
201 .header("Content-Type", "application/json")
202 .header("Authorization", format!("Bearer {}", api_token))
203 .header("Connection", "keep-alive")
204 .header("Content-Encoding", "br")
205 .method(Method::POST)
206 .body(body)?;
207
208 let mut response = http_client.send(request).await?;
209
210 let mut body: Vec<u8> = Vec::new();
211 response.body_mut().read_to_end(&mut body).await?;
212
213 let response_received_at = Instant::now();
214 if !response.status().is_success() {
215 anyhow::bail!(
216 "Request failed with status: {:?}\nBody: {}",
217 response.status(),
218 String::from_utf8_lossy(&body),
219 );
220 };
221
222 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
223
224 let old_text = inputs
225 .snapshot
226 .text_for_range(response.start_index..response.end_index)
227 .collect::<String>();
228 let edits = language::text_diff(&old_text, &response.completion)
229 .into_iter()
230 .map(|(range, text)| {
231 (
232 inputs
233 .snapshot
234 .anchor_after(response.start_index + range.start)
235 ..inputs
236 .snapshot
237 .anchor_before(response.start_index + range.end),
238 text,
239 )
240 })
241 .collect::<Vec<_>>();
242
243 anyhow::Ok((
244 response.autocomplete_id,
245 edits,
246 inputs.snapshot,
247 response_received_at,
248 ep_inputs,
249 ))
250 });
251
252 let buffer = inputs.buffer.clone();
253
254 cx.spawn(async move |cx| {
255 let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
256 anyhow::Ok(Some(
257 EditPredictionResult::new(
258 EditPredictionId(id.into()),
259 &buffer,
260 &old_snapshot,
261 edits.into(),
262 buffer_snapshotted_at,
263 response_received_at,
264 inputs,
265 cx,
266 )
267 .await,
268 ))
269 })
270 }
271}
272
/// Key ("URL") under which the Sweep API token is stored in the system credential store.
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
/// Username written alongside the Sweep API token in the credential store.
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
275
276pub fn load_api_token(cx: &App) -> Task<Option<String>> {
277 if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
278 .ok()
279 .filter(|value| !value.is_empty())
280 {
281 return Task::ready(Some(api_token));
282 }
283 let credentials_provider = <dyn CredentialsProvider>::global(cx);
284 cx.spawn(async move |cx| {
285 let (_, credentials) = credentials_provider
286 .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
287 .await
288 .ok()??;
289 String::from_utf8(credentials).ok()
290 })
291}
292
293fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
294 let credentials_provider = <dyn CredentialsProvider>::global(cx);
295
296 cx.spawn(async move |cx| {
297 if let Some(api_token) = api_token {
298 credentials_provider
299 .write_credentials(
300 SWEEP_CREDENTIALS_URL,
301 SWEEP_CREDENTIALS_USERNAME,
302 api_token.as_bytes(),
303 cx,
304 )
305 .await
306 .context("Failed to save Sweep API token to system keychain")
307 } else {
308 credentials_provider
309 .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
310 .await
311 .context("Failed to delete Sweep API token from system keychain")
312 }
313 })
314}
315
/// JSON body of a request to [`SWEEP_API_URL`].
///
/// serde serializes fields in declaration order, so reordering them changes the exact JSON
/// sent to the server.
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
    /// Zed version, commit SHA, and OS description.
    pub debug_info: Arc<str>,
    /// Root name of the worktree containing the file, or `"untitled"`.
    pub repo_name: String,
    /// Git branch; `request_prediction_with_sweep` always sends `None`.
    pub branch: Option<String>,
    /// Full path of the edited file, or `"untitled"` when the buffer has no file.
    pub file_path: Arc<Path>,
    /// Complete text of the buffer.
    pub file_contents: String,
    /// Recent edit events, rendered by `write_event`.
    pub recent_changes: String,
    /// Cursor offset into `file_contents`. It is sent together with `use_bytes: true`, so it is
    /// presumably interpreted as a byte offset (confirm against Sweep's API).
    pub cursor_position: usize,
    /// Currently sent identical to `file_contents`.
    pub original_file_contents: String,
    /// Leading rows of other open buffers, plus a diagnostics chunk when there are diagnostics.
    pub file_chunks: Vec<FileChunk>,
    /// Excerpts from related files.
    pub retrieval_chunks: Vec<FileChunk>,
    /// Fine-grained user actions; currently always sent empty.
    pub recent_user_actions: Vec<UserAction>,
    /// Presumably requests alternative completions; always `false` here.
    pub multiple_suggestions: bool,
    /// Always `false` for now (marked TODO where the request is built).
    pub privacy_mode_enabled: bool,
    /// Always `true` here; meaning is defined by the Sweep API.
    pub changes_above_cursor: bool,
    /// Always `true` here; presumably tells the server that offsets are byte-based (confirm).
    pub use_bytes: bool,
}
334
/// A snippet of file content sent as context in `file_chunks` or `retrieval_chunks`.
///
/// Field order determines the serialized JSON key order.
#[derive(Debug, Clone, Serialize)]
struct FileChunk {
    /// What the chunk belongs to, depending on its source:
    /// - open buffers: the buffer file's `path()`;
    /// - retrieval excerpts: the related file's path;
    /// - the diagnostics chunk: `"Diagnostics for <path>"`.
    pub file_path: String,
    /// Starting row of `content` (`0` for open-buffer and diagnostics chunks).
    pub start_line: usize,
    /// Ending row of `content`. For the diagnostics chunk this holds the number of diagnostics.
    pub end_line: usize,
    /// Text of the chunk.
    pub content: String,
    /// Seconds component of the file's on-disk mtime when known; `None` otherwise.
    pub timestamp: Option<u64>,
}
343
/// One entry of `recent_user_actions` in the Sweep request schema.
///
/// Not constructed anywhere in this file yet; requests currently send an empty list.
#[derive(Debug, Clone, Serialize)]
struct UserAction {
    /// Kind of action performed.
    pub action_type: ActionType,
    /// Line of the action. Whether it is 0- or 1-based is not established here (confirm).
    pub line_number: usize,
    /// Offset of the action. Units are not established here (confirm against Sweep's API).
    pub offset: usize,
    /// Path of the file the action happened in.
    pub file_path: String,
    /// Time of the action. Units are not established here (confirm).
    pub timestamp: u64,
}
352
/// Kind of a [`UserAction`], serialized in `SCREAMING_SNAKE_CASE` (e.g. `CURSOR_MOVEMENT`).
// The variants are never constructed because user actions are not sent yet; hence the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
    CursorMovement,
    InsertChar,
    DeleteChar,
    InsertSelection,
    DeleteSelection,
}
363
/// JSON body returned by [`SWEEP_API_URL`].
///
/// `confidence` and `elapsed_time_ms` have no serde default, so a response missing them fails
/// to deserialize.
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
    /// Identifier of this prediction; used as the `EditPredictionId`.
    pub autocomplete_id: String,
    /// Start of the replaced range in the text that was sent.
    pub start_index: usize,
    /// Exclusive end of the replaced range.
    pub end_index: usize,
    /// Replacement text for `start_index..end_index`.
    pub completion: String,
    // The remaining fields are deserialized but not read yet, hence `allow(dead_code)`.
    #[allow(dead_code)]
    pub confidence: f64,
    #[allow(dead_code)]
    pub logprobs: Option<serde_json::Value>,
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    #[allow(dead_code)]
    pub elapsed_time_ms: u64,
    /// Alternative completions, read from the `completions` key; empty when the key is absent.
    #[allow(dead_code)]
    #[serde(default, rename = "completions")]
    pub additional_completions: Vec<AdditionalCompletion>,
}
382
/// An alternative completion from the `completions` array of an [`AutocompleteResponse`].
// Deserialized but never read yet, hence `allow(dead_code)`.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AdditionalCompletion {
    /// Start of the replaced range in the text that was sent.
    pub start_index: usize,
    /// Exclusive end of the replaced range.
    pub end_index: usize,
    /// Replacement text.
    pub completion: String,
    /// Score reported by the server.
    pub confidence: f64,
    /// Identifier of this alternative.
    pub autocomplete_id: String,
    /// Optional token log-probabilities, kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Optional finish reason reported by the server.
    pub finish_reason: Option<String>,
}
394
395fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
396 match event {
397 zeta_prompt::Event::BufferChange {
398 old_path,
399 path,
400 diff,
401 ..
402 } => {
403 if old_path != path {
404 // TODO confirm how to do this for sweep
405 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
406 }
407
408 if !diff.is_empty() {
409 write!(f, "File: {}:\n{}\n", path.display(), diff)?
410 }
411
412 fmt::Result::Ok(())
413 }
414 }
415}
416
417fn debug_info(cx: &gpui::App) -> Arc<str> {
418 format!(
419 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
420 version = release_channel::AppVersion::global(cx),
421 sha = release_channel::AppCommitSha::try_global(cx)
422 .map_or("unknown".to_string(), |sha| sha.full()),
423 os = client::telemetry::os_name(),
424 )
425 .into()
426}