1mod api;
2
3use anyhow::{Context as _, Result};
4use arrayvec::ArrayVec;
5use client::telemetry;
6use collections::HashMap;
7use feature_flags::FeatureFlag;
8use futures::AsyncReadExt as _;
9use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
10use http_client::{AsyncBody, Method};
11use language::{Anchor, Buffer, BufferSnapshot, EditPreview, ToOffset as _, ToPoint, text_diff};
12use project::Project;
13use release_channel::{AppCommitSha, AppVersion};
14use std::collections::{VecDeque, hash_map};
15use std::fmt::{self, Display};
16use std::mem;
17use std::{
18 cmp,
19 fmt::Write,
20 ops::Range,
21 path::Path,
22 sync::Arc,
23 time::{Duration, Instant},
24};
25use util::ResultExt;
26use util::rel_path::RelPath;
27use workspace::Workspace;
28
29use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
30
/// Consecutive edits to the same buffer that occur within this interval of the previous
/// recorded change are merged into a single history event (see `SweepAi::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Capacity of each project's event history; when it is full, the oldest half is discarded.
const MAX_EVENT_COUNT: usize = 16;

/// Endpoint that receives the next-edit autocomplete requests built by `request_completion`.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
35
/// Feature flag named `sweep-ai`; presumably gates this Sweep integration (it is checked by
/// callers outside this file — confirm).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";

    // Staff are not opted in automatically.
    fn enabled_for_staff() -> bool {
        false
    }
}
45
/// gpui global wrapping the single shared [`SweepAi`] entity; installed by `SweepAi::register`
/// and read back by `SweepAi::global`.
#[derive(Clone)]
struct SweepAiGlobal(Entity<SweepAi>);

impl Global for SweepAiGlobal {}
50
51#[derive(Clone)]
52pub struct EditPrediction {
53 id: EditPredictionId,
54 path: Arc<Path>,
55 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
56 snapshot: BufferSnapshot,
57 edit_preview: EditPreview,
58}
59
60impl EditPrediction {
61 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
62 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
63 }
64}
65
66impl fmt::Debug for EditPrediction {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.debug_struct("EditPrediction")
69 .field("path", &self.path)
70 .field("edits", &self.edits)
71 .finish_non_exhaustive()
72 }
73}
74
/// Identifier of a prediction, taken from the `autocomplete_id` of a Sweep response.
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(String);

impl Display for EditPredictionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit the raw identifier text as-is; formatter options such as width are not applied.
        f.write_str(self.0.as_str())
    }
}
83
/// Tracks per-project edit history and buffer registrations used to build Sweep requests.
pub struct SweepAi {
    // Keyed by the `Project` entity id; an entry is removed when its project is released.
    projects: HashMap<EntityId, SweepAiProject>,
    // Version/commit/OS description sent as `debug_info` with every request.
    debug_info: Arc<str>,
}

/// Per-project state kept by [`SweepAi`].
struct SweepAiProject {
    // Recorded buffer changes, oldest first, holding at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    // Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
93
94impl SweepAi {
95 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
96 cx.try_global::<SweepAiGlobal>()
97 .map(|global| global.0.clone())
98 }
99
100 pub fn register(cx: &mut App) -> Entity<Self> {
101 Self::global(cx).unwrap_or_else(|| {
102 let entity = cx.new(|cx| Self::new(cx));
103 cx.set_global(SweepAiGlobal(entity.clone()));
104 entity
105 })
106 }
107
108 pub fn clear_history(&mut self) {
109 for sweep_ai_project in self.projects.values_mut() {
110 sweep_ai_project.events.clear();
111 }
112 }
113
114 fn new(cx: &mut Context<Self>) -> Self {
115 Self {
116 projects: HashMap::default(),
117 debug_info: format!(
118 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
119 version = AppVersion::global(cx),
120 sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
121 os = telemetry::os_name(),
122 )
123 .into(),
124 }
125 }
126
127 fn get_or_init_sweep_ai_project(
128 &mut self,
129 project: &Entity<Project>,
130 cx: &mut Context<Self>,
131 ) -> &mut SweepAiProject {
132 let project_id = project.entity_id();
133 match self.projects.entry(project_id) {
134 hash_map::Entry::Occupied(entry) => entry.into_mut(),
135 hash_map::Entry::Vacant(entry) => {
136 cx.observe_release(project, move |this, _, _cx| {
137 this.projects.remove(&project_id);
138 })
139 .detach();
140 entry.insert(SweepAiProject {
141 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
142 registered_buffers: HashMap::default(),
143 })
144 }
145 }
146 }
147
148 fn push_event(sweep_ai_project: &mut SweepAiProject, event: Event) {
149 let events = &mut sweep_ai_project.events;
150
151 if let Some(Event::BufferChange {
152 new_snapshot: last_new_snapshot,
153 timestamp: last_timestamp,
154 ..
155 }) = events.back_mut()
156 {
157 // Coalesce edits for the same buffer when they happen one after the other.
158 let Event::BufferChange {
159 old_snapshot,
160 new_snapshot,
161 timestamp,
162 } = &event;
163
164 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
165 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
166 && old_snapshot.version == last_new_snapshot.version
167 {
168 *last_new_snapshot = new_snapshot.clone();
169 *last_timestamp = *timestamp;
170 return;
171 }
172 }
173
174 if events.len() >= MAX_EVENT_COUNT {
175 // These are halved instead of popping to improve prompt caching.
176 events.drain(..MAX_EVENT_COUNT / 2);
177 }
178
179 events.push_back(event);
180 }
181
182 pub fn register_buffer(
183 &mut self,
184 buffer: &Entity<Buffer>,
185 project: &Entity<Project>,
186 cx: &mut Context<Self>,
187 ) {
188 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
189 Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
190 }
191
192 fn register_buffer_impl<'a>(
193 sweep_ai_project: &'a mut SweepAiProject,
194 buffer: &Entity<Buffer>,
195 project: &Entity<Project>,
196 cx: &mut Context<Self>,
197 ) -> &'a mut RegisteredBuffer {
198 let buffer_id = buffer.entity_id();
199 match sweep_ai_project.registered_buffers.entry(buffer_id) {
200 hash_map::Entry::Occupied(entry) => entry.into_mut(),
201 hash_map::Entry::Vacant(entry) => {
202 let snapshot = buffer.read(cx).snapshot();
203 let project_entity_id = project.entity_id();
204 entry.insert(RegisteredBuffer {
205 snapshot,
206 _subscriptions: [
207 cx.subscribe(buffer, {
208 let project = project.downgrade();
209 move |this, buffer, event, cx| {
210 if let language::BufferEvent::Edited = event
211 && let Some(project) = project.upgrade()
212 {
213 this.report_changes_for_buffer(&buffer, &project, cx);
214 }
215 }
216 }),
217 cx.observe_release(buffer, move |this, _buffer, _cx| {
218 let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
219 else {
220 return;
221 };
222 sweep_ai_project.registered_buffers.remove(&buffer_id);
223 }),
224 ],
225 })
226 }
227 }
228 }
229
230 pub fn request_completion(
231 &mut self,
232 workspace: &WeakEntity<Workspace>,
233 project: &Entity<Project>,
234 active_buffer: &Entity<Buffer>,
235 position: language::Anchor,
236 cx: &mut Context<Self>,
237 ) -> Task<Result<Option<EditPrediction>>> {
238 let snapshot = active_buffer.read(cx).snapshot();
239 let debug_info = self.debug_info.clone();
240 let full_path: Arc<Path> = snapshot
241 .file()
242 .map(|file| file.full_path(cx))
243 .unwrap_or_else(|| "untitled".into())
244 .into();
245
246 let project_file = project::File::from_dyn(snapshot.file());
247 let repo_name = project_file
248 .map(|file| file.worktree.read(cx).root_name_str())
249 .unwrap_or("untitled")
250 .into();
251 let offset = position.to_offset(&snapshot);
252
253 let project_state = self.get_or_init_sweep_ai_project(project, cx);
254 let events = project_state.events.clone();
255 let http_client = cx.http_client();
256
257 let Some(recent_buffers) = workspace
258 .read_with(cx, |workspace, cx| {
259 workspace
260 .recent_navigation_history_iter(cx)
261 .filter_map(|(project_path, _)| {
262 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
263
264 if active_buffer == &buffer {
265 None
266 } else {
267 Some(buffer.read(cx).snapshot())
268 }
269 })
270 .take(3)
271 .collect::<Vec<_>>()
272 })
273 .log_err()
274 else {
275 return Task::ready(Ok(None));
276 };
277
278 let result = cx.background_spawn({
279 let full_path = full_path.clone();
280 async move {
281 let text = snapshot.text();
282
283 let mut recent_changes = String::new();
284
285 for event in events {
286 writeln!(&mut recent_changes, "{event}")?;
287 }
288
289 let file_chunks = recent_buffers
290 .into_iter()
291 .map(|snapshot| {
292 let end_point = language::Point::new(30, 0).min(snapshot.max_point());
293 FileChunk {
294 content: snapshot
295 .text_for_range(language::Point::zero()..end_point)
296 .collect(),
297 file_path: snapshot
298 .file()
299 .map(|f| f.path().as_unix_str())
300 .unwrap_or("untitled")
301 .to_string(),
302 start_line: 0,
303 end_line: end_point.row as usize,
304 timestamp: snapshot.file().and_then(|file| {
305 Some(
306 file.disk_state()
307 .mtime()?
308 .to_seconds_and_nanos_for_persistence()?
309 .0,
310 )
311 }),
312 }
313 })
314 .collect();
315
316 let request_body = AutocompleteRequest {
317 debug_info,
318 repo_name,
319 file_path: full_path.clone(),
320 file_contents: text.clone(),
321 original_file_contents: text,
322 cursor_position: offset,
323 recent_changes: recent_changes.clone(),
324 changes_above_cursor: true,
325 multiple_suggestions: false,
326 branch: None,
327 file_chunks,
328 retrieval_chunks: vec![],
329 recent_user_actions: vec![],
330 // TODO
331 privacy_mode_enabled: false,
332 };
333
334 let mut buf: Vec<u8> = Vec::new();
335 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
336 serde_json::to_writer(writer, &request_body)?;
337 let body: AsyncBody = buf.into();
338
339 let request = http_client::Request::builder()
340 .uri(SWEEP_API_URL)
341 .header("Content-Type", "application/json")
342 .header(
343 "Authorization",
344 format!("Bearer {}", std::env::var("SWEEP_TOKEN").unwrap()),
345 )
346 .header("Connection", "keep-alive")
347 .header("Content-Encoding", "br")
348 .method(Method::POST)
349 .body(body)?;
350
351 let mut response = http_client.send(request).await?;
352
353 let mut body: Vec<u8> = Vec::new();
354 response.body_mut().read_to_end(&mut body).await?;
355
356 if !response.status().is_success() {
357 anyhow::bail!(
358 "Request failed with status: {:?}\nBody: {}",
359 response.status(),
360 String::from_utf8_lossy(&body),
361 );
362 };
363
364 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
365
366 let old_text = snapshot
367 .text_for_range(response.start_index..response.end_index)
368 .collect::<String>();
369 let edits = text_diff(&old_text, &response.completion)
370 .into_iter()
371 .map(|(range, text)| {
372 (
373 snapshot.anchor_after(response.start_index + range.start)
374 ..snapshot.anchor_before(response.start_index + range.end),
375 text,
376 )
377 })
378 .collect::<Vec<_>>();
379
380 anyhow::Ok((response.autocomplete_id, edits, snapshot))
381 }
382 });
383
384 let buffer = active_buffer.clone();
385
386 cx.spawn(async move |_, cx| {
387 let (id, edits, old_snapshot) = result.await?;
388
389 if edits.is_empty() {
390 return anyhow::Ok(None);
391 }
392
393 let Some((edits, new_snapshot, preview_task)) =
394 buffer.read_with(cx, |buffer, cx| {
395 let new_snapshot = buffer.snapshot();
396
397 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
398 edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
399 .into();
400 let preview_task = buffer.preview_edits(edits.clone(), cx);
401
402 Some((edits, new_snapshot, preview_task))
403 })?
404 else {
405 return anyhow::Ok(None);
406 };
407
408 let prediction = EditPrediction {
409 id: EditPredictionId(id),
410 path: full_path,
411 edits,
412 snapshot: new_snapshot,
413 edit_preview: preview_task.await,
414 };
415
416 anyhow::Ok(Some(prediction))
417 })
418 }
419
420 fn report_changes_for_buffer(
421 &mut self,
422 buffer: &Entity<Buffer>,
423 project: &Entity<Project>,
424 cx: &mut Context<Self>,
425 ) -> BufferSnapshot {
426 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
427 let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
428
429 let new_snapshot = buffer.read(cx).snapshot();
430 if new_snapshot.version != registered_buffer.snapshot.version {
431 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
432 Self::push_event(
433 sweep_ai_project,
434 Event::BufferChange {
435 old_snapshot,
436 new_snapshot: new_snapshot.clone(),
437 timestamp: Instant::now(),
438 },
439 );
440 }
441
442 new_snapshot
443 }
444}
445
/// A buffer observed by [`SweepAi`], together with the snapshot of its last recorded state.
struct RegisteredBuffer {
    // Snapshot as of the last recorded change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    // Edit subscription and release observer; held only to keep them alive (presumably
    // dropping a gpui `Subscription` ends it — confirm).
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history; its `Display` output is sent as `recent_changes`.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents went from `old_snapshot` to `new_snapshot`, last updated at
    /// `timestamp`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
459
460impl Display for Event {
461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 match self {
463 Event::BufferChange {
464 old_snapshot,
465 new_snapshot,
466 ..
467 } => {
468 let old_path = old_snapshot
469 .file()
470 .map(|f| f.path().as_ref())
471 .unwrap_or(RelPath::unix("untitled").unwrap());
472 let new_path = new_snapshot
473 .file()
474 .map(|f| f.path().as_ref())
475 .unwrap_or(RelPath::unix("untitled").unwrap());
476 if old_path != new_path {
477 // TODO confirm how to do this for sweep
478 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
479 }
480
481 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
482 if !diff.is_empty() {
483 write!(
484 f,
485 "File: {}:\n{}\n",
486 new_path.display(util::paths::PathStyle::Posix),
487 diff
488 )?
489 }
490
491 fmt::Result::Ok(())
492 }
493 }
494 }
495}
496
497#[derive(Debug, Clone)]
498struct CurrentEditPrediction {
499 buffer_id: EntityId,
500 completion: EditPrediction,
501}
502
503impl CurrentEditPrediction {
504 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
505 if self.buffer_id != old_completion.buffer_id {
506 return true;
507 }
508
509 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
510 return true;
511 };
512 let Some(new_edits) = self.completion.interpolate(snapshot) else {
513 return false;
514 };
515
516 if old_edits.len() == 1 && new_edits.len() == 1 {
517 let (old_range, old_text) = &old_edits[0];
518 let (new_range, new_text) = &new_edits[0];
519 new_range == old_range && new_text.starts_with(old_text.as_ref())
520 } else {
521 true
522 }
523 }
524}
525
/// A prediction request started by `refresh` that has not finished yet.
struct PendingCompletion {
    // Value of `next_pending_completion_id` when the request was started; the finished task
    // uses it to find its own entry.
    id: usize,
    // Handle to the spawned request task, owned here so it stays alive while pending.
    _task: Task<()>,
}

/// Edit-prediction provider that obtains predictions from Sweep through [`SweepAi`].
pub struct SweepAiEditPredictionProvider {
    workspace: WeakEntity<Workspace>,
    sweep_ai: Entity<SweepAi>,
    // At most two requests are tracked at once (see `refresh`).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    next_pending_completion_id: usize,
    // The prediction currently offered to the editor, if any.
    current_completion: Option<CurrentEditPrediction>,
    // When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
540
541impl SweepAiEditPredictionProvider {
542 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
543
544 pub fn new(
545 sweep_ai: Entity<SweepAi>,
546 workspace: WeakEntity<Workspace>,
547 project: Entity<Project>,
548 ) -> Self {
549 Self {
550 sweep_ai,
551 pending_completions: ArrayVec::new(),
552 next_pending_completion_id: 0,
553 current_completion: None,
554 last_request_timestamp: Instant::now(),
555 project,
556 workspace,
557 }
558 }
559}
560
561impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
562 fn name() -> &'static str {
563 "zed-predict"
564 }
565
566 fn display_name() -> &'static str {
567 "Zed's Edit Predictions"
568 }
569
570 fn show_completions_in_menu() -> bool {
571 true
572 }
573
574 fn show_tab_accept_marker() -> bool {
575 true
576 }
577
578 fn is_enabled(
579 &self,
580 _buffer: &Entity<Buffer>,
581 _cursor_position: language::Anchor,
582 _cx: &App,
583 ) -> bool {
584 true
585 }
586
587 fn is_refreshing(&self) -> bool {
588 !self.pending_completions.is_empty()
589 }
590
591 fn refresh(
592 &mut self,
593 buffer: Entity<Buffer>,
594 position: language::Anchor,
595 _debounce: bool,
596 cx: &mut Context<Self>,
597 ) {
598 if let Some(current_completion) = self.current_completion.as_ref() {
599 let snapshot = buffer.read(cx).snapshot();
600 if current_completion
601 .completion
602 .interpolate(&snapshot)
603 .is_some()
604 {
605 return;
606 }
607 }
608
609 let pending_completion_id = self.next_pending_completion_id;
610 self.next_pending_completion_id += 1;
611 let last_request_timestamp = self.last_request_timestamp;
612
613 let project = self.project.clone();
614 let workspace = self.workspace.clone();
615 let task = cx.spawn(async move |this, cx| {
616 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
617 .checked_duration_since(Instant::now())
618 {
619 cx.background_executor().timer(timeout).await;
620 }
621
622 let completion_request = this.update(cx, |this, cx| {
623 this.last_request_timestamp = Instant::now();
624 this.sweep_ai.update(cx, |sweep_ai, cx| {
625 sweep_ai.request_completion(&workspace, &project, &buffer, position, cx)
626 })
627 });
628
629 let completion = match completion_request {
630 Ok(completion_request) => {
631 let completion_request = completion_request.await;
632 completion_request.map(|c| {
633 c.map(|completion| CurrentEditPrediction {
634 buffer_id: buffer.entity_id(),
635 completion,
636 })
637 })
638 }
639 Err(error) => Err(error),
640 };
641
642 let Some(new_completion) = completion
643 .context("edit prediction failed")
644 .log_err()
645 .flatten()
646 else {
647 this.update(cx, |this, cx| {
648 if this.pending_completions[0].id == pending_completion_id {
649 this.pending_completions.remove(0);
650 } else {
651 this.pending_completions.clear();
652 }
653
654 cx.notify();
655 })
656 .ok();
657 return;
658 };
659
660 this.update(cx, |this, cx| {
661 if this.pending_completions[0].id == pending_completion_id {
662 this.pending_completions.remove(0);
663 } else {
664 this.pending_completions.clear();
665 }
666
667 if let Some(old_completion) = this.current_completion.as_ref() {
668 let snapshot = buffer.read(cx).snapshot();
669 if new_completion.should_replace_completion(old_completion, &snapshot) {
670 this.current_completion = Some(new_completion);
671 }
672 } else {
673 this.current_completion = Some(new_completion);
674 }
675
676 cx.notify();
677 })
678 .ok();
679 });
680
681 // We always maintain at most two pending completions. When we already
682 // have two, we replace the newest one.
683 if self.pending_completions.len() <= 1 {
684 self.pending_completions.push(PendingCompletion {
685 id: pending_completion_id,
686 _task: task,
687 });
688 } else if self.pending_completions.len() == 2 {
689 self.pending_completions.pop();
690 self.pending_completions.push(PendingCompletion {
691 id: pending_completion_id,
692 _task: task,
693 });
694 }
695 }
696
697 fn cycle(
698 &mut self,
699 _buffer: Entity<Buffer>,
700 _cursor_position: language::Anchor,
701 _direction: edit_prediction::Direction,
702 _cx: &mut Context<Self>,
703 ) {
704 // Right now we don't support cycling.
705 }
706
707 fn accept(&mut self, _cx: &mut Context<Self>) {
708 self.pending_completions.clear();
709 }
710
711 fn discard(&mut self, _cx: &mut Context<Self>) {
712 self.pending_completions.clear();
713 self.current_completion.take();
714 }
715
716 fn suggest(
717 &mut self,
718 buffer: &Entity<Buffer>,
719 cursor_position: language::Anchor,
720 cx: &mut Context<Self>,
721 ) -> Option<edit_prediction::EditPrediction> {
722 let CurrentEditPrediction {
723 buffer_id,
724 completion,
725 ..
726 } = self.current_completion.as_mut()?;
727
728 // Invalidate previous completion if it was generated for a different buffer.
729 if *buffer_id != buffer.entity_id() {
730 self.current_completion.take();
731 return None;
732 }
733
734 let buffer = buffer.read(cx);
735 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
736 self.current_completion.take();
737 return None;
738 };
739
740 let cursor_row = cursor_position.to_point(buffer).row;
741 let (closest_edit_ix, (closest_edit_range, _)) =
742 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
743 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
744 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
745 cmp::min(distance_from_start, distance_from_end)
746 })?;
747
748 let mut edit_start_ix = closest_edit_ix;
749 for (range, _) in edits[..edit_start_ix].iter().rev() {
750 let distance_from_closest_edit =
751 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
752 if distance_from_closest_edit <= 1 {
753 edit_start_ix -= 1;
754 } else {
755 break;
756 }
757 }
758
759 let mut edit_end_ix = closest_edit_ix + 1;
760 for (range, _) in &edits[edit_end_ix..] {
761 let distance_from_closest_edit =
762 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
763 if distance_from_closest_edit <= 1 {
764 edit_end_ix += 1;
765 } else {
766 break;
767 }
768 }
769
770 Some(edit_prediction::EditPrediction::Local {
771 id: Some(completion.id.to_string().into()),
772 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
773 edit_preview: Some(completion.edit_preview.clone()),
774 })
775 }
776}