1mod api;
2
3use anyhow::{Context as _, Result};
4use arrayvec::ArrayVec;
5use client::telemetry;
6use collections::HashMap;
7use feature_flags::FeatureFlag;
8use futures::AsyncReadExt as _;
9use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
10use http_client::{AsyncBody, Method};
11use language::{Anchor, Buffer, BufferSnapshot, EditPreview, ToOffset as _, ToPoint, text_diff};
12use project::Project;
13use release_channel::{AppCommitSha, AppVersion};
14use std::collections::{VecDeque, hash_map};
15use std::fmt::{self, Display};
16use std::mem;
17use std::{
18 cmp,
19 fmt::Write,
20 ops::Range,
21 path::Path,
22 sync::Arc,
23 time::{Duration, Instant},
24};
25use util::ResultExt;
26use util::rel_path::RelPath;
27use workspace::Workspace;
28
29use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
30
/// Consecutive edits to the same buffer that arrive within this window may be merged into a
/// single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Number of edit events a project's history holds before its oldest half is discarded.
const MAX_EVENT_COUNT: usize = 16;

/// Endpoint that serves Sweep's next-edit autocomplete predictions.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
35
36pub struct SweepFeatureFlag;
37
38impl FeatureFlag for SweepFeatureFlag {
39 const NAME: &str = "sweep-ai";
40}
41
42#[derive(Clone)]
43struct SweepAiGlobal(Entity<SweepAi>);
44
45impl Global for SweepAiGlobal {}
46
47#[derive(Clone)]
48pub struct EditPrediction {
49 id: EditPredictionId,
50 path: Arc<Path>,
51 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
52 snapshot: BufferSnapshot,
53 edit_preview: EditPreview,
54}
55
56impl EditPrediction {
57 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
58 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
59 }
60}
61
62impl fmt::Debug for EditPrediction {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("EditPrediction")
65 .field("path", &self.path)
66 .field("edits", &self.edits)
67 .finish_non_exhaustive()
68 }
69}
70
/// Server-assigned identifier of an edit prediction (Sweep's `autocomplete_id`).
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(String);

impl Display for EditPredictionId {
    /// Writes the raw identifier; width/alignment flags of the outer formatter are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
79
80pub struct SweepAi {
81 projects: HashMap<EntityId, SweepAiProject>,
82 debug_info: Arc<str>,
83 api_token: Option<String>,
84}
85
86struct SweepAiProject {
87 events: VecDeque<Event>,
88 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
89}
90
91impl SweepAi {
92 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
93 cx.try_global::<SweepAiGlobal>()
94 .map(|global| global.0.clone())
95 }
96
97 pub fn register(cx: &mut App) -> Entity<Self> {
98 Self::global(cx).unwrap_or_else(|| {
99 let entity = cx.new(|cx| Self::new(cx));
100 cx.set_global(SweepAiGlobal(entity.clone()));
101 entity
102 })
103 }
104
105 pub fn clear_history(&mut self) {
106 for sweep_ai_project in self.projects.values_mut() {
107 sweep_ai_project.events.clear();
108 }
109 }
110
111 fn new(cx: &mut Context<Self>) -> Self {
112 Self {
113 api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
114 projects: HashMap::default(),
115 debug_info: format!(
116 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
117 version = AppVersion::global(cx),
118 sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
119 os = telemetry::os_name(),
120 )
121 .into(),
122 }
123 }
124
125 fn get_or_init_sweep_ai_project(
126 &mut self,
127 project: &Entity<Project>,
128 cx: &mut Context<Self>,
129 ) -> &mut SweepAiProject {
130 let project_id = project.entity_id();
131 match self.projects.entry(project_id) {
132 hash_map::Entry::Occupied(entry) => entry.into_mut(),
133 hash_map::Entry::Vacant(entry) => {
134 cx.observe_release(project, move |this, _, _cx| {
135 this.projects.remove(&project_id);
136 })
137 .detach();
138 entry.insert(SweepAiProject {
139 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
140 registered_buffers: HashMap::default(),
141 })
142 }
143 }
144 }
145
146 fn push_event(sweep_ai_project: &mut SweepAiProject, event: Event) {
147 let events = &mut sweep_ai_project.events;
148
149 if let Some(Event::BufferChange {
150 new_snapshot: last_new_snapshot,
151 timestamp: last_timestamp,
152 ..
153 }) = events.back_mut()
154 {
155 // Coalesce edits for the same buffer when they happen one after the other.
156 let Event::BufferChange {
157 old_snapshot,
158 new_snapshot,
159 timestamp,
160 } = &event;
161
162 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
163 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
164 && old_snapshot.version == last_new_snapshot.version
165 {
166 *last_new_snapshot = new_snapshot.clone();
167 *last_timestamp = *timestamp;
168 return;
169 }
170 }
171
172 if events.len() >= MAX_EVENT_COUNT {
173 // These are halved instead of popping to improve prompt caching.
174 events.drain(..MAX_EVENT_COUNT / 2);
175 }
176
177 events.push_back(event);
178 }
179
180 pub fn register_buffer(
181 &mut self,
182 buffer: &Entity<Buffer>,
183 project: &Entity<Project>,
184 cx: &mut Context<Self>,
185 ) {
186 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
187 Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
188 }
189
190 fn register_buffer_impl<'a>(
191 sweep_ai_project: &'a mut SweepAiProject,
192 buffer: &Entity<Buffer>,
193 project: &Entity<Project>,
194 cx: &mut Context<Self>,
195 ) -> &'a mut RegisteredBuffer {
196 let buffer_id = buffer.entity_id();
197 match sweep_ai_project.registered_buffers.entry(buffer_id) {
198 hash_map::Entry::Occupied(entry) => entry.into_mut(),
199 hash_map::Entry::Vacant(entry) => {
200 let snapshot = buffer.read(cx).snapshot();
201 let project_entity_id = project.entity_id();
202 entry.insert(RegisteredBuffer {
203 snapshot,
204 _subscriptions: [
205 cx.subscribe(buffer, {
206 let project = project.downgrade();
207 move |this, buffer, event, cx| {
208 if let language::BufferEvent::Edited = event
209 && let Some(project) = project.upgrade()
210 {
211 this.report_changes_for_buffer(&buffer, &project, cx);
212 }
213 }
214 }),
215 cx.observe_release(buffer, move |this, _buffer, _cx| {
216 let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
217 else {
218 return;
219 };
220 sweep_ai_project.registered_buffers.remove(&buffer_id);
221 }),
222 ],
223 })
224 }
225 }
226 }
227
228 pub fn request_completion(
229 &mut self,
230 workspace: &WeakEntity<Workspace>,
231 project: &Entity<Project>,
232 active_buffer: &Entity<Buffer>,
233 position: language::Anchor,
234 cx: &mut Context<Self>,
235 ) -> Task<Result<Option<EditPrediction>>> {
236 let snapshot = active_buffer.read(cx).snapshot();
237 let debug_info = self.debug_info.clone();
238 let Some(api_token) = self.api_token.clone() else {
239 return Task::ready(Ok(None));
240 };
241 let full_path: Arc<Path> = snapshot
242 .file()
243 .map(|file| file.full_path(cx))
244 .unwrap_or_else(|| "untitled".into())
245 .into();
246
247 let project_file = project::File::from_dyn(snapshot.file());
248 let repo_name = project_file
249 .map(|file| file.worktree.read(cx).root_name_str())
250 .unwrap_or("untitled")
251 .into();
252 let offset = position.to_offset(&snapshot);
253
254 let project_state = self.get_or_init_sweep_ai_project(project, cx);
255 let events = project_state.events.clone();
256 let http_client = cx.http_client();
257
258 let Some(recent_buffers) = workspace
259 .read_with(cx, |workspace, cx| {
260 workspace
261 .recent_navigation_history_iter(cx)
262 .filter_map(|(project_path, _)| {
263 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
264
265 if active_buffer == &buffer {
266 None
267 } else {
268 Some(buffer.read(cx).snapshot())
269 }
270 })
271 .take(3)
272 .collect::<Vec<_>>()
273 })
274 .log_err()
275 else {
276 return Task::ready(Ok(None));
277 };
278
279 let result = cx.background_spawn({
280 let full_path = full_path.clone();
281 async move {
282 let text = snapshot.text();
283
284 let mut recent_changes = String::new();
285
286 for event in events {
287 writeln!(&mut recent_changes, "{event}")?;
288 }
289
290 let file_chunks = recent_buffers
291 .into_iter()
292 .map(|snapshot| {
293 let end_point = language::Point::new(30, 0).min(snapshot.max_point());
294 FileChunk {
295 content: snapshot
296 .text_for_range(language::Point::zero()..end_point)
297 .collect(),
298 file_path: snapshot
299 .file()
300 .map(|f| f.path().as_unix_str())
301 .unwrap_or("untitled")
302 .to_string(),
303 start_line: 0,
304 end_line: end_point.row as usize,
305 timestamp: snapshot.file().and_then(|file| {
306 Some(
307 file.disk_state()
308 .mtime()?
309 .to_seconds_and_nanos_for_persistence()?
310 .0,
311 )
312 }),
313 }
314 })
315 .collect();
316
317 let request_body = AutocompleteRequest {
318 debug_info,
319 repo_name,
320 file_path: full_path.clone(),
321 file_contents: text.clone(),
322 original_file_contents: text,
323 cursor_position: offset,
324 recent_changes: recent_changes.clone(),
325 changes_above_cursor: true,
326 multiple_suggestions: false,
327 branch: None,
328 file_chunks,
329 retrieval_chunks: vec![],
330 recent_user_actions: vec![],
331 // TODO
332 privacy_mode_enabled: false,
333 };
334
335 let mut buf: Vec<u8> = Vec::new();
336 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
337 serde_json::to_writer(writer, &request_body)?;
338 let body: AsyncBody = buf.into();
339
340 let request = http_client::Request::builder()
341 .uri(SWEEP_API_URL)
342 .header("Content-Type", "application/json")
343 .header("Authorization", format!("Bearer {}", api_token))
344 .header("Connection", "keep-alive")
345 .header("Content-Encoding", "br")
346 .method(Method::POST)
347 .body(body)?;
348
349 let mut response = http_client.send(request).await?;
350
351 let mut body: Vec<u8> = Vec::new();
352 response.body_mut().read_to_end(&mut body).await?;
353
354 if !response.status().is_success() {
355 anyhow::bail!(
356 "Request failed with status: {:?}\nBody: {}",
357 response.status(),
358 String::from_utf8_lossy(&body),
359 );
360 };
361
362 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
363
364 let old_text = snapshot
365 .text_for_range(response.start_index..response.end_index)
366 .collect::<String>();
367 let edits = text_diff(&old_text, &response.completion)
368 .into_iter()
369 .map(|(range, text)| {
370 (
371 snapshot.anchor_after(response.start_index + range.start)
372 ..snapshot.anchor_before(response.start_index + range.end),
373 text,
374 )
375 })
376 .collect::<Vec<_>>();
377
378 anyhow::Ok((response.autocomplete_id, edits, snapshot))
379 }
380 });
381
382 let buffer = active_buffer.clone();
383
384 cx.spawn(async move |_, cx| {
385 let (id, edits, old_snapshot) = result.await?;
386
387 if edits.is_empty() {
388 return anyhow::Ok(None);
389 }
390
391 let Some((edits, new_snapshot, preview_task)) =
392 buffer.read_with(cx, |buffer, cx| {
393 let new_snapshot = buffer.snapshot();
394
395 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
396 edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
397 .into();
398 let preview_task = buffer.preview_edits(edits.clone(), cx);
399
400 Some((edits, new_snapshot, preview_task))
401 })?
402 else {
403 return anyhow::Ok(None);
404 };
405
406 let prediction = EditPrediction {
407 id: EditPredictionId(id),
408 path: full_path,
409 edits,
410 snapshot: new_snapshot,
411 edit_preview: preview_task.await,
412 };
413
414 anyhow::Ok(Some(prediction))
415 })
416 }
417
418 fn report_changes_for_buffer(
419 &mut self,
420 buffer: &Entity<Buffer>,
421 project: &Entity<Project>,
422 cx: &mut Context<Self>,
423 ) -> BufferSnapshot {
424 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
425 let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
426
427 let new_snapshot = buffer.read(cx).snapshot();
428 if new_snapshot.version != registered_buffer.snapshot.version {
429 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
430 Self::push_event(
431 sweep_ai_project,
432 Event::BufferChange {
433 old_snapshot,
434 new_snapshot: new_snapshot.clone(),
435 timestamp: Instant::now(),
436 },
437 );
438 }
439
440 new_snapshot
441 }
442}
443
444struct RegisteredBuffer {
445 snapshot: BufferSnapshot,
446 _subscriptions: [gpui::Subscription; 2],
447}
448
449#[derive(Clone)]
450pub enum Event {
451 BufferChange {
452 old_snapshot: BufferSnapshot,
453 new_snapshot: BufferSnapshot,
454 timestamp: Instant,
455 },
456}
457
458impl Display for Event {
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 match self {
461 Event::BufferChange {
462 old_snapshot,
463 new_snapshot,
464 ..
465 } => {
466 let old_path = old_snapshot
467 .file()
468 .map(|f| f.path().as_ref())
469 .unwrap_or(RelPath::unix("untitled").unwrap());
470 let new_path = new_snapshot
471 .file()
472 .map(|f| f.path().as_ref())
473 .unwrap_or(RelPath::unix("untitled").unwrap());
474 if old_path != new_path {
475 // TODO confirm how to do this for sweep
476 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
477 }
478
479 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
480 if !diff.is_empty() {
481 write!(
482 f,
483 "File: {}:\n{}\n",
484 new_path.display(util::paths::PathStyle::Posix),
485 diff
486 )?
487 }
488
489 fmt::Result::Ok(())
490 }
491 }
492 }
493}
494
495#[derive(Debug, Clone)]
496struct CurrentEditPrediction {
497 buffer_id: EntityId,
498 completion: EditPrediction,
499}
500
501impl CurrentEditPrediction {
502 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
503 if self.buffer_id != old_completion.buffer_id {
504 return true;
505 }
506
507 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
508 return true;
509 };
510 let Some(new_edits) = self.completion.interpolate(snapshot) else {
511 return false;
512 };
513
514 if old_edits.len() == 1 && new_edits.len() == 1 {
515 let (old_range, old_text) = &old_edits[0];
516 let (new_range, new_text) = &new_edits[0];
517 new_range == old_range && new_text.starts_with(old_text.as_ref())
518 } else {
519 true
520 }
521 }
522}
523
524struct PendingCompletion {
525 id: usize,
526 _task: Task<()>,
527}
528
529pub struct SweepAiEditPredictionProvider {
530 workspace: WeakEntity<Workspace>,
531 sweep_ai: Entity<SweepAi>,
532 pending_completions: ArrayVec<PendingCompletion, 2>,
533 next_pending_completion_id: usize,
534 current_completion: Option<CurrentEditPrediction>,
535 last_request_timestamp: Instant,
536 project: Entity<Project>,
537}
538
539impl SweepAiEditPredictionProvider {
540 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
541
542 pub fn new(
543 sweep_ai: Entity<SweepAi>,
544 workspace: WeakEntity<Workspace>,
545 project: Entity<Project>,
546 ) -> Self {
547 Self {
548 sweep_ai,
549 pending_completions: ArrayVec::new(),
550 next_pending_completion_id: 0,
551 current_completion: None,
552 last_request_timestamp: Instant::now(),
553 project,
554 workspace,
555 }
556 }
557}
558
559impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
560 fn name() -> &'static str {
561 "zed-predict"
562 }
563
564 fn display_name() -> &'static str {
565 "Zed's Edit Predictions"
566 }
567
568 fn show_completions_in_menu() -> bool {
569 true
570 }
571
572 fn show_tab_accept_marker() -> bool {
573 true
574 }
575
576 fn is_enabled(
577 &self,
578 _buffer: &Entity<Buffer>,
579 _cursor_position: language::Anchor,
580 cx: &App,
581 ) -> bool {
582 self.sweep_ai.read(cx).api_token.is_some()
583 }
584
585 fn is_refreshing(&self) -> bool {
586 !self.pending_completions.is_empty()
587 }
588
589 fn refresh(
590 &mut self,
591 buffer: Entity<Buffer>,
592 position: language::Anchor,
593 _debounce: bool,
594 cx: &mut Context<Self>,
595 ) {
596 if let Some(current_completion) = self.current_completion.as_ref() {
597 let snapshot = buffer.read(cx).snapshot();
598 if current_completion
599 .completion
600 .interpolate(&snapshot)
601 .is_some()
602 {
603 return;
604 }
605 }
606
607 let pending_completion_id = self.next_pending_completion_id;
608 self.next_pending_completion_id += 1;
609 let last_request_timestamp = self.last_request_timestamp;
610
611 let project = self.project.clone();
612 let workspace = self.workspace.clone();
613 let task = cx.spawn(async move |this, cx| {
614 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
615 .checked_duration_since(Instant::now())
616 {
617 cx.background_executor().timer(timeout).await;
618 }
619
620 let completion_request = this.update(cx, |this, cx| {
621 this.last_request_timestamp = Instant::now();
622 this.sweep_ai.update(cx, |sweep_ai, cx| {
623 sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
624 })
625 });
626
627 let completion = match completion_request {
628 Ok(completion_request) => {
629 let completion_request = completion_request.await;
630 completion_request.map(|c| {
631 c.map(|completion| CurrentEditPrediction {
632 buffer_id: buffer.entity_id(),
633 completion,
634 })
635 })
636 }
637 Err(error) => Err(error),
638 };
639
640 let Some(new_completion) = completion
641 .context("edit prediction failed")
642 .log_err()
643 .flatten()
644 else {
645 this.update(cx, |this, cx| {
646 if this.pending_completions[0].id == pending_completion_id {
647 this.pending_completions.remove(0);
648 } else {
649 this.pending_completions.clear();
650 }
651
652 cx.notify();
653 })
654 .ok();
655 return;
656 };
657
658 this.update(cx, |this, cx| {
659 if this.pending_completions[0].id == pending_completion_id {
660 this.pending_completions.remove(0);
661 } else {
662 this.pending_completions.clear();
663 }
664
665 if let Some(old_completion) = this.current_completion.as_ref() {
666 let snapshot = buffer.read(cx).snapshot();
667 if new_completion.should_replace_completion(old_completion, &snapshot) {
668 this.current_completion = Some(new_completion);
669 }
670 } else {
671 this.current_completion = Some(new_completion);
672 }
673
674 cx.notify();
675 })
676 .ok();
677 });
678
679 // We always maintain at most two pending completions. When we already
680 // have two, we replace the newest one.
681 if self.pending_completions.len() <= 1 {
682 self.pending_completions.push(PendingCompletion {
683 id: pending_completion_id,
684 _task: task,
685 });
686 } else if self.pending_completions.len() == 2 {
687 self.pending_completions.pop();
688 self.pending_completions.push(PendingCompletion {
689 id: pending_completion_id,
690 _task: task,
691 });
692 }
693 }
694
695 fn cycle(
696 &mut self,
697 _buffer: Entity<Buffer>,
698 _cursor_position: language::Anchor,
699 _direction: edit_prediction::Direction,
700 _cx: &mut Context<Self>,
701 ) {
702 // Right now we don't support cycling.
703 }
704
705 fn accept(&mut self, _cx: &mut Context<Self>) {
706 self.pending_completions.clear();
707 }
708
709 fn discard(&mut self, _cx: &mut Context<Self>) {
710 self.pending_completions.clear();
711 self.current_completion.take();
712 }
713
714 fn suggest(
715 &mut self,
716 buffer: &Entity<Buffer>,
717 cursor_position: language::Anchor,
718 cx: &mut Context<Self>,
719 ) -> Option<edit_prediction::EditPrediction> {
720 let CurrentEditPrediction {
721 buffer_id,
722 completion,
723 ..
724 } = self.current_completion.as_mut()?;
725
726 // Invalidate previous completion if it was generated for a different buffer.
727 if *buffer_id != buffer.entity_id() {
728 self.current_completion.take();
729 return None;
730 }
731
732 let buffer = buffer.read(cx);
733 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
734 self.current_completion.take();
735 return None;
736 };
737
738 let cursor_row = cursor_position.to_point(buffer).row;
739 let (closest_edit_ix, (closest_edit_range, _)) =
740 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
741 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
742 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
743 cmp::min(distance_from_start, distance_from_end)
744 })?;
745
746 let mut edit_start_ix = closest_edit_ix;
747 for (range, _) in edits[..edit_start_ix].iter().rev() {
748 let distance_from_closest_edit =
749 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
750 if distance_from_closest_edit <= 1 {
751 edit_start_ix -= 1;
752 } else {
753 break;
754 }
755 }
756
757 let mut edit_end_ix = closest_edit_ix + 1;
758 for (range, _) in &edits[edit_end_ix..] {
759 let distance_from_closest_edit =
760 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
761 if distance_from_closest_edit <= 1 {
762 edit_end_ix += 1;
763 } else {
764 break;
765 }
766 }
767
768 Some(edit_prediction::EditPrediction::Local {
769 id: Some(completion.id.to_string().into()),
770 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
771 edit_preview: Some(completion.edit_preview.clone()),
772 })
773 }
774}