1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use gpui::http_client::Method;
15use gpui::{
16 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
17 prelude::*,
18};
19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
20use language::{BufferSnapshot, EditPreview};
21use language_model::{LlmApiToken, RefreshLlmTokenListener};
22use project::Project;
23use release_channel::AppVersion;
24use std::cmp;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::fmt::Write;
27use std::path::{Path, PathBuf};
28use std::str::FromStr as _;
29use std::time::{Duration, Instant};
30use std::{ops::Range, sync::Arc};
31use thiserror::Error;
32use util::ResultExt as _;
33use uuid::Uuid;
34use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
35
/// Window within which consecutive edits to the same buffer are merged into a
/// single [`Event::BufferChange`] instead of being recorded separately.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1000);

/// Maximum number of events kept per project; once reached, the oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;
40
/// gpui global holding the single shared [`Zeta`] entity; see `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
45
/// Edit-prediction state: the cloud client and LLM token used for requests,
/// plus per-project syntax indexes and edit history. A single instance is
/// shared app-wide via `Zeta::global`.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token sent with prediction requests; refreshed in the background whenever
    /// the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Excerpt sizing used when gathering context around the cursor.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set once the server reports that this Zed version is too old.
    update_required: bool,
}
55
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index used to gather declarations referenced near the cursor.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
61
/// A buffer whose edits are being observed.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Keeps alive the buffer-edit subscription and the buffer-release observer.
    _subscriptions: [gpui::Subscription; 2],
}
66
/// A user action recorded as context for edit predictions.
#[derive(Clone)]
pub enum Event {
    /// The buffer changed from `old_snapshot` to `new_snapshot`. Back-to-back changes
    /// to the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL` are merged into one
    /// event, in which case `timestamp` is the time of the most recent merged change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
75
76impl Event {
77 //TODO: Actually use the events this in the prompt
78 fn to_prompt(&self) -> String {
79 match self {
80 Event::BufferChange {
81 old_snapshot,
82 new_snapshot,
83 ..
84 } => {
85 let mut prompt = String::new();
86
87 let old_path = old_snapshot
88 .file()
89 .map(|f| f.path().as_ref())
90 .unwrap_or(Path::new("untitled"));
91 let new_path = new_snapshot
92 .file()
93 .map(|f| f.path().as_ref())
94 .unwrap_or(Path::new("untitled"));
95 if old_path != new_path {
96 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
97 }
98
99 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
100 if !diff.is_empty() {
101 write!(
102 prompt,
103 "User edited {:?}:\n```diff\n{}\n```",
104 new_path, diff
105 )
106 .unwrap();
107 }
108
109 prompt
110 }
111 }
112 }
113}
114
115impl Zeta {
116 pub fn global(
117 client: &Arc<Client>,
118 user_store: &Entity<UserStore>,
119 cx: &mut App,
120 ) -> Entity<Self> {
121 cx.try_global::<ZetaGlobal>()
122 .map(|global| global.0.clone())
123 .unwrap_or_else(|| {
124 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
125 cx.set_global(ZetaGlobal(zeta.clone()));
126 zeta
127 })
128 }
129
130 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
131 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
132
133 Self {
134 projects: HashMap::new(),
135 client,
136 user_store,
137 excerpt_options: EditPredictionExcerptOptions {
138 max_bytes: 512,
139 min_bytes: 128,
140 target_before_cursor_over_total_bytes: 0.5,
141 },
142 llm_token: LlmApiToken::default(),
143 _llm_token_subscription: cx.subscribe(
144 &refresh_llm_token_listener,
145 |this, _listener, _event, cx| {
146 let client = this.client.clone();
147 let llm_token = this.llm_token.clone();
148 cx.spawn(async move |_this, _cx| {
149 llm_token.refresh(&client).await?;
150 anyhow::Ok(())
151 })
152 .detach_and_log_err(cx);
153 },
154 ),
155 update_required: false,
156 }
157 }
158
159 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
160 self.user_store.read(cx).edit_prediction_usage()
161 }
162
163 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
164 self.get_or_init_zeta_project(project, cx);
165 }
166
167 pub fn register_buffer(
168 &mut self,
169 buffer: &Entity<Buffer>,
170 project: &Entity<Project>,
171 cx: &mut Context<Self>,
172 ) {
173 let zeta_project = self.get_or_init_zeta_project(project, cx);
174 Self::register_buffer_impl(zeta_project, buffer, project, cx);
175 }
176
177 fn get_or_init_zeta_project(
178 &mut self,
179 project: &Entity<Project>,
180 cx: &mut App,
181 ) -> &mut ZetaProject {
182 self.projects
183 .entry(project.entity_id())
184 .or_insert_with(|| ZetaProject {
185 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
186 events: VecDeque::new(),
187 registered_buffers: HashMap::new(),
188 })
189 }
190
191 fn register_buffer_impl<'a>(
192 zeta_project: &'a mut ZetaProject,
193 buffer: &Entity<Buffer>,
194 project: &Entity<Project>,
195 cx: &mut Context<Self>,
196 ) -> &'a mut RegisteredBuffer {
197 let buffer_id = buffer.entity_id();
198 match zeta_project.registered_buffers.entry(buffer_id) {
199 hash_map::Entry::Occupied(entry) => entry.into_mut(),
200 hash_map::Entry::Vacant(entry) => {
201 let snapshot = buffer.read(cx).snapshot();
202 let project_entity_id = project.entity_id();
203 entry.insert(RegisteredBuffer {
204 snapshot,
205 _subscriptions: [
206 cx.subscribe(buffer, {
207 let project = project.downgrade();
208 move |this, buffer, event, cx| {
209 if let language::BufferEvent::Edited = event
210 && let Some(project) = project.upgrade()
211 {
212 this.report_changes_for_buffer(&buffer, &project, cx);
213 }
214 }
215 }),
216 cx.observe_release(buffer, move |this, _buffer, _cx| {
217 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
218 else {
219 return;
220 };
221 zeta_project.registered_buffers.remove(&buffer_id);
222 }),
223 ],
224 })
225 }
226 }
227 }
228
229 fn report_changes_for_buffer(
230 &mut self,
231 buffer: &Entity<Buffer>,
232 project: &Entity<Project>,
233 cx: &mut Context<Self>,
234 ) -> BufferSnapshot {
235 let zeta_project = self.get_or_init_zeta_project(project, cx);
236 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
237
238 let new_snapshot = buffer.read(cx).snapshot();
239 if new_snapshot.version != registered_buffer.snapshot.version {
240 let old_snapshot =
241 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
242 Self::push_event(
243 zeta_project,
244 Event::BufferChange {
245 old_snapshot,
246 new_snapshot: new_snapshot.clone(),
247 timestamp: Instant::now(),
248 },
249 );
250 }
251
252 new_snapshot
253 }
254
255 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
256 let events = &mut zeta_project.events;
257
258 if let Some(Event::BufferChange {
259 new_snapshot: last_new_snapshot,
260 timestamp: last_timestamp,
261 ..
262 }) = events.back_mut()
263 {
264 // Coalesce edits for the same buffer when they happen one after the other.
265 let Event::BufferChange {
266 old_snapshot,
267 new_snapshot,
268 timestamp,
269 } = &event;
270
271 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
272 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
273 && old_snapshot.version == last_new_snapshot.version
274 {
275 *last_new_snapshot = new_snapshot.clone();
276 *last_timestamp = *timestamp;
277 return;
278 }
279 }
280
281 if events.len() >= MAX_EVENT_COUNT {
282 // These are halved instead of popping to improve prompt caching.
283 events.drain(..MAX_EVENT_COUNT / 2);
284 }
285
286 events.push_back(event);
287 }
288
289 pub fn request_prediction(
290 &mut self,
291 project: &Entity<Project>,
292 buffer: &Entity<Buffer>,
293 position: language::Anchor,
294 cx: &mut Context<Self>,
295 ) -> Task<Result<Option<EditPrediction>>> {
296 let project_state = self.projects.get(&project.entity_id());
297
298 let index_state = project_state.map(|state| {
299 state
300 .syntax_index
301 .read_with(cx, |index, _cx| index.state().clone())
302 });
303 let excerpt_options = self.excerpt_options.clone();
304 let snapshot = buffer.read(cx).snapshot();
305 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
306 return Task::ready(Err(anyhow!("No file path for excerpt")));
307 };
308 let client = self.client.clone();
309 let llm_token = self.llm_token.clone();
310 let app_version = AppVersion::global(cx);
311 let worktree_snapshots = project
312 .read(cx)
313 .worktrees(cx)
314 .map(|worktree| worktree.read(cx).snapshot())
315 .collect::<Vec<_>>();
316
317 let request_task = cx.background_spawn({
318 let snapshot = snapshot.clone();
319 async move {
320 let index_state = if let Some(index_state) = index_state {
321 Some(index_state.lock_owned().await)
322 } else {
323 None
324 };
325
326 let cursor_point = position.to_point(&snapshot);
327
328 // TODO: make this only true if debug view is open
329 let debug_info = true;
330
331 let Some(request) = EditPredictionContext::gather_context(
332 cursor_point,
333 &snapshot,
334 &excerpt_options,
335 index_state.as_deref(),
336 )
337 .map(|context| {
338 make_cloud_request(
339 excerpt_path.clone(),
340 context,
341 // TODO pass everything
342 Vec::new(),
343 false,
344 Vec::new(),
345 None,
346 debug_info,
347 &worktree_snapshots,
348 index_state.as_deref(),
349 )
350 }) else {
351 return Ok(None);
352 };
353
354 anyhow::Ok(Some(
355 Self::perform_request(client, llm_token, app_version, request).await?,
356 ))
357 }
358 });
359
360 let buffer = buffer.clone();
361
362 cx.spawn(async move |this, cx| {
363 match request_task.await {
364 Ok(Some((response, usage))) => {
365 log::debug!("predicted edits: {:?}", &response.edits);
366
367 if let Some(usage) = usage {
368 this.update(cx, |this, cx| {
369 this.user_store.update(cx, |user_store, cx| {
370 user_store.update_edit_prediction_usage(usage, cx);
371 });
372 })
373 .ok();
374 }
375
376 // TODO telemetry: duration, etc
377
378 // TODO produce smaller edits by diffing against snapshot first
379 //
380 // Cloud returns entire snippets/excerpts ranges as they were included
381 // in the request, but we should display smaller edits to the user.
382 //
383 // We can do this by computing a diff of each one against the snapshot.
384 // Similar to zeta::Zeta::compute_edits, but per edit.
385 let edits = response
386 .edits
387 .into_iter()
388 .map(|edit| {
389 // TODO edits to different files
390 (
391 snapshot.anchor_before(edit.range.start)
392 ..snapshot.anchor_before(edit.range.end),
393 edit.content,
394 )
395 })
396 .collect::<Vec<_>>()
397 .into();
398
399 let Some((edits, snapshot, edit_preview_task)) =
400 buffer.read_with(cx, |buffer, cx| {
401 let new_snapshot = buffer.snapshot();
402 let edits: Arc<[_]> =
403 interpolate(&snapshot, &new_snapshot, edits)?.into();
404 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
405 })?
406 else {
407 return Ok(None);
408 };
409
410 Ok(Some(EditPrediction {
411 id: EditPredictionId(response.request_id),
412 edits,
413 snapshot,
414 edit_preview: edit_preview_task.await,
415 }))
416 }
417 Ok(None) => Ok(None),
418 Err(err) => {
419 if err.is::<ZedUpdateRequiredError>() {
420 cx.update(|cx| {
421 this.update(cx, |this, _cx| {
422 this.update_required = true;
423 })
424 .ok();
425
426 let error_message: SharedString = err.to_string().into();
427 show_app_notification(
428 NotificationId::unique::<ZedUpdateRequiredError>(),
429 cx,
430 move |cx| {
431 cx.new(|cx| {
432 ErrorMessagePrompt::new(error_message.clone(), cx)
433 .with_link_button(
434 "Update Zed",
435 "https://zed.dev/releases",
436 )
437 })
438 },
439 );
440 })
441 .ok();
442 }
443
444 Err(err)
445 }
446 }
447 })
448 }
449
450 async fn perform_request(
451 client: Arc<Client>,
452 llm_token: LlmApiToken,
453 app_version: SemanticVersion,
454 request: predict_edits_v3::PredictEditsRequest,
455 ) -> Result<(
456 predict_edits_v3::PredictEditsResponse,
457 Option<EditPredictionUsage>,
458 )> {
459 let http_client = client.http_client();
460 let mut token = llm_token.acquire(&client).await?;
461 let mut did_retry = false;
462
463 loop {
464 let request_builder = http_client::Request::builder().method(Method::POST);
465 let request_builder =
466 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
467 request_builder.uri(predict_edits_url)
468 } else {
469 request_builder.uri(
470 http_client
471 .build_zed_llm_url("/predict_edits/v3", &[])?
472 .as_ref(),
473 )
474 };
475 let request = request_builder
476 .header("Content-Type", "application/json")
477 .header("Authorization", format!("Bearer {}", token))
478 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
479 .body(serde_json::to_string(&request)?.into())?;
480
481 let mut response = http_client.send(request).await?;
482
483 if let Some(minimum_required_version) = response
484 .headers()
485 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
486 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
487 {
488 anyhow::ensure!(
489 app_version >= minimum_required_version,
490 ZedUpdateRequiredError {
491 minimum_version: minimum_required_version
492 }
493 );
494 }
495
496 if response.status().is_success() {
497 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
498
499 let mut body = Vec::new();
500 response.body_mut().read_to_end(&mut body).await?;
501 return Ok((serde_json::from_slice(&body)?, usage));
502 } else if !did_retry
503 && response
504 .headers()
505 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
506 .is_some()
507 {
508 did_retry = true;
509 token = llm_token.refresh(&client).await?;
510 } else {
511 let mut body = String::new();
512 response.body_mut().read_to_string(&mut body).await?;
513 anyhow::bail!(
514 "error predicting edits.\nStatus: {:?}\nBody: {}",
515 response.status(),
516 body
517 );
518 }
519 }
520 }
521}
522
/// Returned by `Zeta::perform_request` when the server's minimum-required-version
/// header is newer than this Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
530
/// [`EditPredictionProvider`] implementation backed by the global [`Zeta`] instance.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the editor, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next request started by `refresh`; incremented per request.
    next_pending_prediction_id: usize,
    /// In-flight requests, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; new requests wait until
    /// `THROTTLE_TIMEOUT` has elapsed since then.
    last_request_timestamp: Instant,
}
538
539impl ZetaEditPredictionProvider {
540 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
541
542 pub fn new(
543 project: Option<&Entity<Project>>,
544 client: &Arc<Client>,
545 user_store: &Entity<UserStore>,
546 cx: &mut App,
547 ) -> Self {
548 let zeta = Zeta::global(client, user_store, cx);
549 if let Some(project) = project {
550 zeta.update(cx, |zeta, cx| {
551 zeta.register_project(project, cx);
552 });
553 }
554
555 Self {
556 zeta,
557 current_prediction: None,
558 next_pending_prediction_id: 0,
559 pending_predictions: ArrayVec::new(),
560 last_request_timestamp: Instant::now(),
561 }
562 }
563}
564
/// The prediction currently offered by [`ZetaEditPredictionProvider`], together
/// with the buffer it was requested for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    prediction: EditPrediction,
}
570
571impl CurrentEditPrediction {
572 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
573 if self.buffer_id != old_prediction.buffer_id {
574 return true;
575 }
576
577 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
578 return true;
579 };
580 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
581 return false;
582 };
583
584 if old_edits.len() == 1 && new_edits.len() == 1 {
585 let (old_range, old_text) = &old_edits[0];
586 let (new_range, new_text) = &new_edits[0];
587 new_range == old_range && new_text.starts_with(old_text)
588 } else {
589 true
590 }
591 }
592}
593
/// Identifier of a prediction; set from the server response's `request_id`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
596
597impl From<EditPredictionId> for gpui::ElementId {
598 fn from(value: EditPredictionId) -> Self {
599 gpui::ElementId::Uuid(value.0)
600 }
601}
602
603impl std::fmt::Display for EditPredictionId {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 write!(f, "{}", self.0)
606 }
607}
608
/// A prediction received from the server, plus the buffer state it applies to.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// Edits to apply, as anchored ranges paired with replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were re-based onto when the prediction was built;
    /// `interpolate` diffs newer snapshots against it.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
}
616
617impl EditPrediction {
618 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
619 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
620 }
621}
622
/// An in-flight prediction request started by `refresh`.
struct PendingPrediction {
    /// Id taken from `next_pending_prediction_id`; matched when the request finishes.
    id: usize,
    /// Task driving the request.
    /// NOTE(review): presumably dropping it cancels the request (gpui `Task`
    /// semantics); the provider relies on this to abandon stale requests.
    _task: Task<()>,
}
627
628impl EditPredictionProvider for ZetaEditPredictionProvider {
629 fn name() -> &'static str {
630 "zed-predict2"
631 }
632
633 fn display_name() -> &'static str {
634 "Zed's Edit Predictions 2"
635 }
636
637 fn show_completions_in_menu() -> bool {
638 true
639 }
640
641 fn show_tab_accept_marker() -> bool {
642 true
643 }
644
645 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
646 // TODO [zeta2]
647 DataCollectionState::Unsupported
648 }
649
650 fn toggle_data_collection(&mut self, _cx: &mut App) {
651 // TODO [zeta2]
652 }
653
654 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
655 self.zeta.read(cx).usage(cx)
656 }
657
658 fn is_enabled(
659 &self,
660 _buffer: &Entity<language::Buffer>,
661 _cursor_position: language::Anchor,
662 _cx: &App,
663 ) -> bool {
664 true
665 }
666
667 fn is_refreshing(&self) -> bool {
668 !self.pending_predictions.is_empty()
669 }
670
671 fn refresh(
672 &mut self,
673 project: Option<Entity<project::Project>>,
674 buffer: Entity<language::Buffer>,
675 cursor_position: language::Anchor,
676 _debounce: bool,
677 cx: &mut Context<Self>,
678 ) {
679 let Some(project) = project else {
680 return;
681 };
682
683 if self
684 .zeta
685 .read(cx)
686 .user_store
687 .read_with(cx, |user_store, _cx| {
688 user_store.account_too_young() || user_store.has_overdue_invoices()
689 })
690 {
691 return;
692 }
693
694 if let Some(current_prediction) = self.current_prediction.as_ref() {
695 let snapshot = buffer.read(cx).snapshot();
696 if current_prediction
697 .prediction
698 .interpolate(&snapshot)
699 .is_some()
700 {
701 return;
702 }
703 }
704
705 let pending_prediction_id = self.next_pending_prediction_id;
706 self.next_pending_prediction_id += 1;
707 let last_request_timestamp = self.last_request_timestamp;
708
709 let task = cx.spawn(async move |this, cx| {
710 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
711 .checked_duration_since(Instant::now())
712 {
713 cx.background_executor().timer(timeout).await;
714 }
715
716 let prediction_request = this.update(cx, |this, cx| {
717 this.last_request_timestamp = Instant::now();
718 this.zeta.update(cx, |zeta, cx| {
719 zeta.request_prediction(&project, &buffer, cursor_position, cx)
720 })
721 });
722
723 let prediction = match prediction_request {
724 Ok(prediction_request) => {
725 let prediction_request = prediction_request.await;
726 prediction_request.map(|c| {
727 c.map(|prediction| CurrentEditPrediction {
728 buffer_id: buffer.entity_id(),
729 prediction,
730 })
731 })
732 }
733 Err(error) => Err(error),
734 };
735
736 this.update(cx, |this, cx| {
737 if this.pending_predictions[0].id == pending_prediction_id {
738 this.pending_predictions.remove(0);
739 } else {
740 this.pending_predictions.clear();
741 }
742
743 let Some(new_prediction) = prediction
744 .context("edit prediction failed")
745 .log_err()
746 .flatten()
747 else {
748 cx.notify();
749 return;
750 };
751
752 if let Some(old_prediction) = this.current_prediction.as_ref() {
753 let snapshot = buffer.read(cx).snapshot();
754 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
755 this.current_prediction = Some(new_prediction);
756 }
757 } else {
758 this.current_prediction = Some(new_prediction);
759 }
760
761 cx.notify();
762 })
763 .ok();
764 });
765
766 // We always maintain at most two pending predictions. When we already
767 // have two, we replace the newest one.
768 if self.pending_predictions.len() <= 1 {
769 self.pending_predictions.push(PendingPrediction {
770 id: pending_prediction_id,
771 _task: task,
772 });
773 } else if self.pending_predictions.len() == 2 {
774 self.pending_predictions.pop();
775 self.pending_predictions.push(PendingPrediction {
776 id: pending_prediction_id,
777 _task: task,
778 });
779 }
780
781 cx.notify();
782 }
783
784 fn cycle(
785 &mut self,
786 _buffer: Entity<language::Buffer>,
787 _cursor_position: language::Anchor,
788 _direction: Direction,
789 _cx: &mut Context<Self>,
790 ) {
791 }
792
793 fn accept(&mut self, _cx: &mut Context<Self>) {
794 // TODO [zeta2] report accept
795 self.current_prediction.take();
796 self.pending_predictions.clear();
797 }
798
799 fn discard(&mut self, _cx: &mut Context<Self>) {
800 self.pending_predictions.clear();
801 self.current_prediction.take();
802 }
803
804 fn suggest(
805 &mut self,
806 buffer: &Entity<language::Buffer>,
807 cursor_position: language::Anchor,
808 cx: &mut Context<Self>,
809 ) -> Option<edit_prediction::EditPrediction> {
810 let CurrentEditPrediction {
811 buffer_id,
812 prediction,
813 ..
814 } = self.current_prediction.as_mut()?;
815
816 // Invalidate previous prediction if it was generated for a different buffer.
817 if *buffer_id != buffer.entity_id() {
818 self.current_prediction.take();
819 return None;
820 }
821
822 let buffer = buffer.read(cx);
823 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
824 self.current_prediction.take();
825 return None;
826 };
827
828 let cursor_row = cursor_position.to_point(buffer).row;
829 let (closest_edit_ix, (closest_edit_range, _)) =
830 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
831 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
832 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
833 cmp::min(distance_from_start, distance_from_end)
834 })?;
835
836 let mut edit_start_ix = closest_edit_ix;
837 for (range, _) in edits[..edit_start_ix].iter().rev() {
838 let distance_from_closest_edit =
839 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
840 if distance_from_closest_edit <= 1 {
841 edit_start_ix -= 1;
842 } else {
843 break;
844 }
845 }
846
847 let mut edit_end_ix = closest_edit_ix + 1;
848 for (range, _) in &edits[edit_end_ix..] {
849 let distance_from_closest_edit =
850 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
851 if distance_from_closest_edit <= 1 {
852 edit_end_ix += 1;
853 } else {
854 break;
855 }
856 }
857
858 Some(edit_prediction::EditPrediction {
859 id: Some(prediction.id.to_string().into()),
860 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
861 edit_preview: Some(prediction.edit_preview.clone()),
862 })
863 }
864}
865
866fn make_cloud_request(
867 excerpt_path: PathBuf,
868 context: EditPredictionContext,
869 events: Vec<predict_edits_v3::Event>,
870 can_collect_data: bool,
871 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
872 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
873 debug_info: bool,
874 worktrees: &Vec<worktree::Snapshot>,
875 index_state: Option<&SyntaxIndexState>,
876) -> predict_edits_v3::PredictEditsRequest {
877 let mut signatures = Vec::new();
878 let mut declaration_to_signature_index = HashMap::default();
879 let mut referenced_declarations = Vec::new();
880
881 for snippet in context.snippets {
882 let project_entry_id = snippet.declaration.project_entry_id();
883 // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
884 // Note that currently full_path is currently being used for excerpt_path.
885 let Some(path) = worktrees.iter().find_map(|worktree| {
886 let abs_path = worktree.abs_path();
887 worktree
888 .entry_for_id(project_entry_id)
889 .map(|e| abs_path.join(&e.path))
890 }) else {
891 continue;
892 };
893
894 let parent_index = index_state.and_then(|index_state| {
895 snippet.declaration.parent().and_then(|parent| {
896 add_signature(
897 parent,
898 &mut declaration_to_signature_index,
899 &mut signatures,
900 index_state,
901 )
902 })
903 });
904
905 let (text, text_is_truncated) = snippet.declaration.item_text();
906 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
907 path,
908 text: text.into(),
909 range: snippet.declaration.item_range(),
910 text_is_truncated,
911 signature_range: snippet.declaration.signature_range_in_item_text(),
912 parent_index,
913 score_components: snippet.score_components,
914 signature_score: snippet.scores.signature,
915 declaration_score: snippet.scores.declaration,
916 });
917 }
918
919 let excerpt_parent = index_state.and_then(|index_state| {
920 context
921 .excerpt
922 .parent_declarations
923 .last()
924 .and_then(|(parent, _)| {
925 add_signature(
926 *parent,
927 &mut declaration_to_signature_index,
928 &mut signatures,
929 index_state,
930 )
931 })
932 });
933
934 predict_edits_v3::PredictEditsRequest {
935 excerpt_path,
936 excerpt: context.excerpt_text.body,
937 excerpt_range: context.excerpt.range,
938 cursor_offset: context.cursor_offset_in_excerpt,
939 referenced_declarations,
940 signatures,
941 excerpt_parent,
942 // todo!
943 events,
944 can_collect_data,
945 diagnostic_groups,
946 git_info,
947 debug_info,
948 }
949}
950
951fn add_signature(
952 declaration_id: DeclarationId,
953 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
954 signatures: &mut Vec<Signature>,
955 index: &SyntaxIndexState,
956) -> Option<usize> {
957 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
958 return Some(*signature_index);
959 }
960 let Some(parent_declaration) = index.declaration(declaration_id) else {
961 log::error!("bug: missing parent declaration");
962 return None;
963 };
964 let parent_index = parent_declaration.parent().and_then(|parent| {
965 add_signature(parent, declaration_to_signature_index, signatures, index)
966 });
967 let (text, text_is_truncated) = parent_declaration.signature_text();
968 let signature_index = signatures.len();
969 signatures.push(Signature {
970 text: text.into(),
971 text_is_truncated,
972 parent_index,
973 });
974 declaration_to_signature_index.insert(declaration_id, signature_index);
975 Some(signature_index)
976}
977
/// Re-bases a prediction's `current_edits` (anchored in `old_snapshot`) onto
/// `new_snapshot`, taking into account the edits the user made in between.
///
/// The user's edits are walked in order:
/// - Model edits that end strictly before a user edit starts are kept unchanged.
/// - If a user edit replaced exactly the range of the next model edit, and the
///   user's new text is a prefix of the model's text, the remaining suffix (if
///   any) is kept as an insertion right after the user's edit.
/// - Any other user edit invalidates the whole prediction. This includes a user
///   edit that falls after every model edit.
///
/// Model edits that follow the last user edit are kept unchanged. Returns `None`
/// when the prediction is invalidated or when no edits remain.
///
/// NOTE(review): the walk assumes `current_edits` are ordered by position in
/// `old_snapshot`; confirm callers guarantee this.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user may be typing the predicted text: accept the edit if it covers the
        // same range and its text is a prefix of the prediction.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // Anchor after the user's edit so the remaining suffix is inserted
                        // following the text already typed.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user's edit diverges from the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
1023
// Unit tests for re-basing predictions onto newer buffer snapshots.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Simulates a user typing, deleting and undoing around a two-edit prediction,
    /// and checks how `EditPrediction::interpolate` re-bases the edits after each step.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Unchanged buffer: edits come back as-is.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the replaced range turns the replacement into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the prediction leaves only the remaining suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Fully typing the replacement consumes that edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Backspacing into the typed text brings the suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand consumes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchored edits against `buffer`'s current state.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchored edits back to offsets in `buffer`'s current state for assertions.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}