1mod api;
2
3use anyhow::{Context as _, Result};
4use arrayvec::ArrayVec;
5use client::telemetry;
6use collections::HashMap;
7use feature_flags::FeatureFlag;
8use futures::AsyncReadExt as _;
9use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
10use http_client::{AsyncBody, Method};
11use language::{
12 Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
13};
14use project::{Project, ProjectPath};
15use release_channel::{AppCommitSha, AppVersion};
16use std::collections::{VecDeque, hash_map};
17use std::fmt::{self, Display};
18use std::mem;
19use std::{
20 cmp,
21 fmt::Write,
22 ops::Range,
23 path::Path,
24 sync::Arc,
25 time::{Duration, Instant},
26};
27use util::ResultExt;
28use util::rel_path::RelPath;
29use workspace::Workspace;
30
31use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
32
/// Two back-to-back edits to the same buffer are merged into one recorded
/// event when the rows of their edit end positions differ by at most this many.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Maximum number of edit events kept per project; the oldest is evicted first.
const MAX_EVENT_COUNT: usize = 6;

/// Endpoint that receives next-edit autocomplete requests.
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
37
/// Feature flag gating the Sweep AI edit-prediction integration.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
43
/// App-global handle to the single [`SweepAi`] entity created by `SweepAi::register`.
#[derive(Clone)]
struct SweepAiGlobal(Entity<SweepAi>);

impl Global for SweepAiGlobal {}
48
/// An edit prediction returned by Sweep, anchored to the buffer it was made for.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier returned by the Sweep API (`autocomplete_id`).
    pub id: EditPredictionId,
    /// Full path of the buffer's file, or `"untitled"` when it has none.
    pub path: Arc<Path>,
    /// Proposed edits, expressed as anchored ranges and their replacement text.
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot the `edits` were interpolated onto.
    pub snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    pub edit_preview: EditPreview,
}
57
58impl EditPrediction {
59 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
60 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
61 }
62}
63
64impl fmt::Debug for EditPrediction {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("EditPrediction")
67 .field("path", &self.path)
68 .field("edits", &self.edits)
69 .finish_non_exhaustive()
70 }
71}
72
/// Opaque identifier of a Sweep prediction, as returned by the API.
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(String);

impl Display for EditPredictionId {
    /// Writes the raw identifier verbatim (outer width/fill flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        f.write_str(raw)
    }
}
81
/// Client-side state for Sweep AI edit predictions.
pub struct SweepAi {
    /// Per-project edit history and buffer registrations, keyed by the
    /// project's entity id.
    projects: HashMap<EntityId, SweepAiProject>,
    /// Client description sent as `debug_info` with every request.
    debug_info: Arc<str>,
    /// Bearer token from `SWEEP_AI_TOKEN`; predictions are disabled when `None`.
    api_token: Option<String>,
}
87
/// Tracking state for one project.
struct SweepAiProject {
    /// Most recent edit events, oldest first, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being recorded, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
92
93impl SweepAi {
94 pub fn global(cx: &mut App) -> Option<Entity<Self>> {
95 cx.try_global::<SweepAiGlobal>()
96 .map(|global| global.0.clone())
97 }
98
99 pub fn register(cx: &mut App) -> Entity<Self> {
100 Self::global(cx).unwrap_or_else(|| {
101 let entity = cx.new(|cx| Self::new(cx));
102 cx.set_global(SweepAiGlobal(entity.clone()));
103 entity
104 })
105 }
106
107 pub fn clear_history(&mut self) {
108 for sweep_ai_project in self.projects.values_mut() {
109 sweep_ai_project.events.clear();
110 }
111 }
112
113 pub fn new(cx: &mut Context<Self>) -> Self {
114 Self {
115 api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
116 projects: HashMap::default(),
117 debug_info: format!(
118 "Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
119 version = AppVersion::global(cx),
120 sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
121 os = telemetry::os_name(),
122 )
123 .into(),
124 }
125 }
126
127 fn get_or_init_sweep_ai_project(
128 &mut self,
129 project: &Entity<Project>,
130 cx: &mut Context<Self>,
131 ) -> &mut SweepAiProject {
132 let project_id = project.entity_id();
133 match self.projects.entry(project_id) {
134 hash_map::Entry::Occupied(entry) => entry.into_mut(),
135 hash_map::Entry::Vacant(entry) => {
136 cx.observe_release(project, move |this, _, _cx| {
137 this.projects.remove(&project_id);
138 })
139 .detach();
140 entry.insert(SweepAiProject {
141 events: VecDeque::with_capacity(MAX_EVENT_COUNT),
142 registered_buffers: HashMap::default(),
143 })
144 }
145 }
146 }
147
148 pub fn register_buffer(
149 &mut self,
150 buffer: &Entity<Buffer>,
151 project: &Entity<Project>,
152 cx: &mut Context<Self>,
153 ) {
154 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
155 Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
156 }
157
158 fn register_buffer_impl<'a>(
159 sweep_ai_project: &'a mut SweepAiProject,
160 buffer: &Entity<Buffer>,
161 project: &Entity<Project>,
162 cx: &mut Context<Self>,
163 ) -> &'a mut RegisteredBuffer {
164 let buffer_id = buffer.entity_id();
165 match sweep_ai_project.registered_buffers.entry(buffer_id) {
166 hash_map::Entry::Occupied(entry) => entry.into_mut(),
167 hash_map::Entry::Vacant(entry) => {
168 let snapshot = buffer.read(cx).snapshot();
169 let project_entity_id = project.entity_id();
170 entry.insert(RegisteredBuffer {
171 snapshot,
172 _subscriptions: [
173 cx.subscribe(buffer, {
174 let project = project.downgrade();
175 move |this, buffer, event, cx| {
176 if let language::BufferEvent::Edited = event
177 && let Some(project) = project.upgrade()
178 {
179 this.report_changes_for_buffer(&buffer, &project, cx);
180 }
181 }
182 }),
183 cx.observe_release(buffer, move |this, _buffer, _cx| {
184 let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
185 else {
186 return;
187 };
188 sweep_ai_project.registered_buffers.remove(&buffer_id);
189 }),
190 ],
191 })
192 }
193 }
194 }
195
196 pub fn request_completion(
197 &mut self,
198 project: &Entity<Project>,
199 recent_buffers: impl Iterator<Item = ProjectPath>,
200 active_buffer: &Entity<Buffer>,
201 position: language::Anchor,
202 cx: &mut Context<Self>,
203 ) -> Task<Result<Option<EditPrediction>>> {
204 let snapshot = active_buffer.read(cx).snapshot();
205 let debug_info = self.debug_info.clone();
206 let Some(api_token) = self.api_token.clone() else {
207 return Task::ready(Ok(None));
208 };
209 let full_path: Arc<Path> = snapshot
210 .file()
211 .map(|file| file.full_path(cx))
212 .unwrap_or_else(|| "untitled".into())
213 .into();
214
215 let project_file = project::File::from_dyn(snapshot.file());
216 let repo_name = project_file
217 .map(|file| file.worktree.read(cx).root_name_str())
218 .unwrap_or("untitled")
219 .into();
220 let offset = position.to_offset(&snapshot);
221
222 let project_state = self.get_or_init_sweep_ai_project(project, cx);
223 let events = project_state.events.clone();
224 let http_client = cx.http_client();
225
226 let recent_buffer_snapshots = recent_buffers
227 .filter_map(|project_path| {
228 let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
229 if active_buffer == &buffer {
230 None
231 } else {
232 Some(buffer.read(cx).snapshot())
233 }
234 })
235 .take(3)
236 .collect::<Vec<_>>();
237
238 let result = cx.background_spawn({
239 let full_path = full_path.clone();
240 async move {
241 let text = snapshot.text();
242
243 let mut recent_changes = String::new();
244
245 for event in events {
246 writeln!(&mut recent_changes, "{event}")?;
247 }
248
249 let file_chunks = recent_buffer_snapshots
250 .into_iter()
251 .map(|snapshot| {
252 let end_point = language::Point::new(30, 0).min(snapshot.max_point());
253 FileChunk {
254 content: snapshot
255 .text_for_range(language::Point::zero()..end_point)
256 .collect(),
257 file_path: snapshot
258 .file()
259 .map(|f| f.path().as_unix_str())
260 .unwrap_or("untitled")
261 .to_string(),
262 start_line: 0,
263 end_line: end_point.row as usize,
264 timestamp: snapshot.file().and_then(|file| {
265 Some(
266 file.disk_state()
267 .mtime()?
268 .to_seconds_and_nanos_for_persistence()?
269 .0,
270 )
271 }),
272 }
273 })
274 .collect();
275
276 eprintln!("{recent_changes}");
277
278 let request_body = AutocompleteRequest {
279 debug_info,
280 repo_name,
281 file_path: full_path.clone(),
282 file_contents: text.clone(),
283 original_file_contents: text,
284 cursor_position: offset,
285 recent_changes: recent_changes.clone(),
286 changes_above_cursor: true,
287 multiple_suggestions: false,
288 branch: None,
289 file_chunks,
290 retrieval_chunks: vec![],
291 recent_user_actions: vec![],
292 // TODO
293 privacy_mode_enabled: false,
294 };
295
296 let mut buf: Vec<u8> = Vec::new();
297 let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
298 serde_json::to_writer(writer, &request_body)?;
299 let body: AsyncBody = buf.into();
300
301 let request = http_client::Request::builder()
302 .uri(SWEEP_API_URL)
303 .header("Content-Type", "application/json")
304 .header("Authorization", format!("Bearer {}", api_token))
305 .header("Connection", "keep-alive")
306 .header("Content-Encoding", "br")
307 .method(Method::POST)
308 .body(body)?;
309
310 let mut response = http_client.send(request).await?;
311
312 let mut body: Vec<u8> = Vec::new();
313 response.body_mut().read_to_end(&mut body).await?;
314
315 if !response.status().is_success() {
316 anyhow::bail!(
317 "Request failed with status: {:?}\nBody: {}",
318 response.status(),
319 String::from_utf8_lossy(&body),
320 );
321 };
322
323 let response: AutocompleteResponse = serde_json::from_slice(&body)?;
324
325 let old_text = snapshot
326 .text_for_range(response.start_index..response.end_index)
327 .collect::<String>();
328 let edits = text_diff(&old_text, &response.completion)
329 .into_iter()
330 .map(|(range, text)| {
331 (
332 snapshot.anchor_after(response.start_index + range.start)
333 ..snapshot.anchor_before(response.start_index + range.end),
334 text,
335 )
336 })
337 .collect::<Vec<_>>();
338
339 anyhow::Ok((response.autocomplete_id, edits, snapshot))
340 }
341 });
342
343 let buffer = active_buffer.clone();
344
345 cx.spawn(async move |_, cx| {
346 let (id, edits, old_snapshot) = result.await?;
347
348 if edits.is_empty() {
349 return anyhow::Ok(None);
350 }
351
352 let Some((edits, new_snapshot, preview_task)) =
353 buffer.read_with(cx, |buffer, cx| {
354 let new_snapshot = buffer.snapshot();
355
356 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
357 edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
358 .into();
359 let preview_task = buffer.preview_edits(edits.clone(), cx);
360
361 Some((edits, new_snapshot, preview_task))
362 })?
363 else {
364 return anyhow::Ok(None);
365 };
366
367 let prediction = EditPrediction {
368 id: EditPredictionId(id),
369 path: full_path,
370 edits,
371 snapshot: new_snapshot,
372 edit_preview: preview_task.await,
373 };
374
375 anyhow::Ok(Some(prediction))
376 })
377 }
378
379 fn report_changes_for_buffer(
380 &mut self,
381 buffer: &Entity<Buffer>,
382 project: &Entity<Project>,
383 cx: &mut Context<Self>,
384 ) {
385 let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
386 let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
387
388 let new_snapshot = buffer.read(cx).snapshot();
389 if new_snapshot.version == registered_buffer.snapshot.version {
390 return;
391 }
392
393 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
394 let end_edit_anchor = new_snapshot
395 .anchored_edits_since::<Point>(&old_snapshot.version)
396 .last()
397 .map(|(_, range)| range.end);
398 let events = &mut sweep_ai_project.events;
399
400 if let Some(Event::BufferChange {
401 new_snapshot: last_new_snapshot,
402 end_edit_anchor: last_end_edit_anchor,
403 ..
404 }) = events.back_mut()
405 {
406 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
407 == last_new_snapshot.remote_id()
408 && old_snapshot.version == last_new_snapshot.version;
409
410 let should_coalesce = is_next_snapshot_of_same_buffer
411 && end_edit_anchor
412 .as_ref()
413 .zip(last_end_edit_anchor.as_ref())
414 .is_some_and(|(a, b)| {
415 let a = a.to_point(&new_snapshot);
416 let b = b.to_point(&new_snapshot);
417 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
418 });
419
420 if should_coalesce {
421 *last_end_edit_anchor = end_edit_anchor;
422 *last_new_snapshot = new_snapshot;
423 return;
424 }
425 }
426
427 if events.len() >= MAX_EVENT_COUNT {
428 events.pop_front();
429 }
430
431 events.push_back(Event::BufferChange {
432 old_snapshot,
433 new_snapshot,
434 end_edit_anchor,
435 });
436 }
437}
438
/// A buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Snapshot at the time edits were last recorded; diffed against on the next edit.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer; held so they stay active.
    _subscriptions: [gpui::Subscription; 2],
}
443
/// A recorded edit, rendered through `Display` into the `recent_changes`
/// text of a Sweep request.
#[derive(Clone)]
pub enum Event {
    /// The buffer changed from `old_snapshot` to `new_snapshot`. Consecutive
    /// nearby edits to the same buffer may be merged into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// End of the last edit between the two snapshots, if any; compared
        /// with the next edit's end to decide whether to merge them.
        end_edit_anchor: Option<Anchor>,
    },
}
452
453impl Display for Event {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 match self {
456 Event::BufferChange {
457 old_snapshot,
458 new_snapshot,
459 ..
460 } => {
461 let old_path = old_snapshot
462 .file()
463 .map(|f| f.path().as_ref())
464 .unwrap_or(RelPath::unix("untitled").unwrap());
465 let new_path = new_snapshot
466 .file()
467 .map(|f| f.path().as_ref())
468 .unwrap_or(RelPath::unix("untitled").unwrap());
469 if old_path != new_path {
470 // TODO confirm how to do this for sweep
471 // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
472 }
473
474 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
475 if !diff.is_empty() {
476 write!(
477 f,
478 "File: {}:\n{}\n",
479 new_path.display(util::paths::PathStyle::Posix),
480 diff
481 )?
482 }
483
484 fmt::Result::Ok(())
485 }
486 }
487 }
488}
489
/// The prediction currently offered by the provider, tagged with the buffer
/// it was requested for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    completion: EditPrediction,
}
495
496impl CurrentEditPrediction {
497 fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
498 if self.buffer_id != old_completion.buffer_id {
499 return true;
500 }
501
502 let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
503 return true;
504 };
505 let Some(new_edits) = self.completion.interpolate(snapshot) else {
506 return false;
507 };
508
509 if old_edits.len() == 1 && new_edits.len() == 1 {
510 let (old_range, old_text) = &old_edits[0];
511 let (new_range, new_text) = &new_edits[0];
512 new_range == old_range && new_text.starts_with(old_text.as_ref())
513 } else {
514 true
515 }
516 }
517}
518
/// An in-flight prediction request started by `refresh`.
struct PendingCompletion {
    /// Sequence number assigned from `next_pending_completion_id`.
    id: usize,
    /// Handle to the request task. NOTE(review): presumably dropping it
    /// cancels the request (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
523
/// Editor-facing edit-prediction provider backed by [`SweepAi`].
pub struct SweepAiEditPredictionProvider {
    /// Workspace queried for the recent navigation history.
    workspace: WeakEntity<Workspace>,
    sweep_ai: Entity<SweepAi>,
    /// At most two in-flight requests (oldest first).
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// Prediction currently shown to the user, if any.
    current_completion: Option<CurrentEditPrediction>,
    /// When the last request was sent; used for throttling.
    last_request_timestamp: Instant,
    project: Entity<Project>,
}
533
534impl SweepAiEditPredictionProvider {
535 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
536
537 pub fn new(
538 sweep_ai: Entity<SweepAi>,
539 workspace: WeakEntity<Workspace>,
540 project: Entity<Project>,
541 ) -> Self {
542 Self {
543 sweep_ai,
544 pending_completions: ArrayVec::new(),
545 next_pending_completion_id: 0,
546 current_completion: None,
547 last_request_timestamp: Instant::now(),
548 project,
549 workspace,
550 }
551 }
552}
553
554impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
555 fn name() -> &'static str {
556 "zed-predict"
557 }
558
559 fn display_name() -> &'static str {
560 "Zed's Edit Predictions"
561 }
562
563 fn show_completions_in_menu() -> bool {
564 true
565 }
566
567 fn show_tab_accept_marker() -> bool {
568 true
569 }
570
571 fn is_enabled(
572 &self,
573 _buffer: &Entity<Buffer>,
574 _cursor_position: language::Anchor,
575 cx: &App,
576 ) -> bool {
577 self.sweep_ai.read(cx).api_token.is_some()
578 }
579
580 fn is_refreshing(&self) -> bool {
581 !self.pending_completions.is_empty()
582 }
583
584 fn refresh(
585 &mut self,
586 buffer: Entity<Buffer>,
587 position: language::Anchor,
588 _debounce: bool,
589 cx: &mut Context<Self>,
590 ) {
591 if let Some(current_completion) = self.current_completion.as_ref() {
592 let snapshot = buffer.read(cx).snapshot();
593 if current_completion
594 .completion
595 .interpolate(&snapshot)
596 .is_some()
597 {
598 return;
599 }
600 }
601
602 let pending_completion_id = self.next_pending_completion_id;
603 self.next_pending_completion_id += 1;
604 let last_request_timestamp = self.last_request_timestamp;
605
606 let project = self.project.clone();
607 let workspace = self.workspace.clone();
608 let task = cx.spawn(async move |this, cx| {
609 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
610 .checked_duration_since(Instant::now())
611 {
612 cx.background_executor().timer(timeout).await;
613 }
614
615 let completion_request = this.update(cx, |this, cx| {
616 this.last_request_timestamp = Instant::now();
617
618 this.sweep_ai.update(cx, |sweep_ai, cx| {
619 let Some(recent_buffers) = workspace
620 .read_with(cx, |workspace, cx| {
621 workspace.recent_navigation_history_iter(cx)
622 })
623 .log_err()
624 else {
625 return Task::ready(Ok(None));
626 };
627 sweep_ai.request_completion(
628 &project,
629 recent_buffers.map(move |(project_path, _)| project_path),
630 &buffer,
631 position,
632 cx,
633 )
634 })
635 });
636
637 let completion = match completion_request {
638 Ok(completion_request) => {
639 let completion_request = completion_request.await;
640 completion_request.map(|c| {
641 c.map(|completion| CurrentEditPrediction {
642 buffer_id: buffer.entity_id(),
643 completion,
644 })
645 })
646 }
647 Err(error) => Err(error),
648 };
649
650 let Some(new_completion) = completion
651 .context("edit prediction failed")
652 .log_err()
653 .flatten()
654 else {
655 this.update(cx, |this, cx| {
656 if this.pending_completions[0].id == pending_completion_id {
657 this.pending_completions.remove(0);
658 } else {
659 this.pending_completions.clear();
660 }
661
662 cx.notify();
663 })
664 .ok();
665 return;
666 };
667
668 this.update(cx, |this, cx| {
669 if this.pending_completions[0].id == pending_completion_id {
670 this.pending_completions.remove(0);
671 } else {
672 this.pending_completions.clear();
673 }
674
675 if let Some(old_completion) = this.current_completion.as_ref() {
676 let snapshot = buffer.read(cx).snapshot();
677 if new_completion.should_replace_completion(old_completion, &snapshot) {
678 this.current_completion = Some(new_completion);
679 }
680 } else {
681 this.current_completion = Some(new_completion);
682 }
683
684 cx.notify();
685 })
686 .ok();
687 });
688
689 // We always maintain at most two pending completions. When we already
690 // have two, we replace the newest one.
691 if self.pending_completions.len() <= 1 {
692 self.pending_completions.push(PendingCompletion {
693 id: pending_completion_id,
694 _task: task,
695 });
696 } else if self.pending_completions.len() == 2 {
697 self.pending_completions.pop();
698 self.pending_completions.push(PendingCompletion {
699 id: pending_completion_id,
700 _task: task,
701 });
702 }
703 }
704
705 fn cycle(
706 &mut self,
707 _buffer: Entity<Buffer>,
708 _cursor_position: language::Anchor,
709 _direction: edit_prediction::Direction,
710 _cx: &mut Context<Self>,
711 ) {
712 // Right now we don't support cycling.
713 }
714
715 fn accept(&mut self, _cx: &mut Context<Self>) {
716 self.pending_completions.clear();
717 }
718
719 fn discard(&mut self, _cx: &mut Context<Self>) {
720 self.pending_completions.clear();
721 self.current_completion.take();
722 }
723
724 fn suggest(
725 &mut self,
726 buffer: &Entity<Buffer>,
727 cursor_position: language::Anchor,
728 cx: &mut Context<Self>,
729 ) -> Option<edit_prediction::EditPrediction> {
730 let CurrentEditPrediction {
731 buffer_id,
732 completion,
733 ..
734 } = self.current_completion.as_mut()?;
735
736 // Invalidate previous completion if it was generated for a different buffer.
737 if *buffer_id != buffer.entity_id() {
738 self.current_completion.take();
739 return None;
740 }
741
742 let buffer = buffer.read(cx);
743 let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
744 self.current_completion.take();
745 return None;
746 };
747
748 let cursor_row = cursor_position.to_point(buffer).row;
749 let (closest_edit_ix, (closest_edit_range, _)) =
750 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
751 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
752 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
753 cmp::min(distance_from_start, distance_from_end)
754 })?;
755
756 let mut edit_start_ix = closest_edit_ix;
757 for (range, _) in edits[..edit_start_ix].iter().rev() {
758 let distance_from_closest_edit =
759 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
760 if distance_from_closest_edit <= 1 {
761 edit_start_ix -= 1;
762 } else {
763 break;
764 }
765 }
766
767 let mut edit_end_ix = closest_edit_ix + 1;
768 for (range, _) in &edits[edit_end_ix..] {
769 let distance_from_closest_edit =
770 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
771 if distance_from_closest_edit <= 1 {
772 edit_end_ix += 1;
773 } else {
774 break;
775 }
776 }
777
778 Some(edit_prediction::EditPrediction::Local {
779 id: Some(completion.id.to_string().into()),
780 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
781 edit_preview: Some(completion.edit_preview.clone()),
782 })
783 }
784}