1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use gpui::http_client::Method;
15use gpui::{
16 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
17 prelude::*,
18};
19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
20use language::{BufferSnapshot, EditPreview};
21use language_model::{LlmApiToken, RefreshLlmTokenListener};
22use project::Project;
23use release_channel::AppVersion;
24use std::cmp;
25use std::collections::HashMap;
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::time::{Duration, Instant};
29use std::{ops::Range, sync::Arc};
30use thiserror::Error;
31use util::ResultExt as _;
32use uuid::Uuid;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35#[derive(Clone)]
36struct ZetaGlobal(Entity<Zeta>);
37
38impl Global for ZetaGlobal {}
39
40pub struct Zeta {
41 client: Arc<Client>,
42 user_store: Entity<UserStore>,
43 llm_token: LlmApiToken,
44 _llm_token_subscription: Subscription,
45 projects: HashMap<EntityId, RegisteredProject>,
46 excerpt_options: EditPredictionExcerptOptions,
47 update_required: bool,
48}
49
50struct RegisteredProject {
51 syntax_index: Entity<SyntaxIndex>,
52}
53
impl Zeta {
    /// Returns the app-wide `Zeta` entity, creating it and installing it as a
    /// gpui global on first use.
    ///
    /// `client` and `user_store` are only used when the entity does not exist
    /// yet; later calls return the existing entity unchanged.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    /// Builds a `Zeta` with no registered projects and default excerpt options.
    ///
    /// Subscribes to the global `RefreshLlmTokenListener`. Whenever it emits,
    /// the LLM token is refreshed on a detached task, and a failure is logged.
    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::new(),
            client,
            user_store,
            // Excerpt around the cursor is between 128 and 512 bytes; 0.5 is
            // presumably the target share of bytes placed before the cursor.
            excerpt_options: EditPredictionExcerptOptions {
                max_bytes: 512,
                min_bytes: 128,
                target_before_cursor_over_total_bytes: 0.5,
            },
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
        }
    }

    /// Edit-prediction usage as last recorded in the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Starts tracking `project` by creating a `SyntaxIndex` for it.
    ///
    /// Registering a project that is already registered is a no-op.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| RegisteredProject {
                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
            });
    }

    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// Context gathering and the HTTP request run on the background executor.
    /// The context is the excerpt around the cursor plus, if `project` was
    /// registered, declarations from its syntax index. The returned task
    /// resolves to:
    /// - `Ok(None)` when no context could be gathered, or when the predicted
    ///   edits are empty or no longer apply to the buffer's current contents;
    /// - `Ok(Some(_))` with the prediction otherwise;
    /// - `Err(_)` on failure. A `ZedUpdateRequiredError` also sets
    ///   `update_required` and shows an "Update Zed" notification.
    ///
    /// Returns an already-failed task when the buffer has no file.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Grab the index state now (only for registered projects); it is
        // locked via `lock_owned` inside the background task.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                // TODO: make this only true if debug view is open
                let debug_info = true;

                // No gathered context means there is nothing to predict from.
                let Some(request) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                )
                .map(|context| {
                    make_cloud_request(
                        excerpt_path.clone(),
                        context,
                        // TODO pass everything
                        Vec::new(),
                        false,
                        Vec::new(),
                        None,
                        debug_info,
                        &worktree_snapshots,
                        index_state.as_deref(),
                    )
                }) else {
                    return Ok(None);
                };

                anyhow::Ok(Some(
                    Self::perform_request(client, llm_token, app_version, request).await?,
                ))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    //
                    // Anchors are created against the snapshot the request was
                    // built from, in the order the response lists the edits.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // The user may have edited the buffer while the request was
                    // in flight: re-base the edits onto the current snapshot and
                    // drop the prediction if they no longer apply.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    // Remember that the server rejected this version and tell
                    // the user; the error itself is still returned below.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }

    /// POSTs `request` as JSON to the predict-edits v3 endpoint and decodes the
    /// JSON response.
    ///
    /// The URL is `ZED_PREDICT_EDITS_URL` when that environment variable is set,
    /// otherwise the client's zed LLM URL for `/predict_edits/v3`. If a
    /// non-success response carries the expired-token header, the token is
    /// refreshed and the request is retried once. Returns the decoded response
    /// together with any usage info parsed from the response headers.
    ///
    /// # Errors
    /// Returns `ZedUpdateRequiredError` when the response advertises a minimum
    /// required version newer than `app_version`. This is checked before the
    /// status code. Any other non-success response yields an error containing
    /// the status and body. Transport, serialization and token errors are
    /// propagated.
    async fn perform_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: predict_edits_v3::PredictEditsRequest,
    ) -> Result<(
        predict_edits_v3::PredictEditsResponse,
        Option<EditPredictionUsage>,
    )> {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v3", &[])?
                            .as_ref(),
                    )
                };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&request)?.into())?;

            let mut response = http_client.send(request).await?;

            // A missing or unparsable version header is ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and loop to resend the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
}
343
344#[derive(Error, Debug)]
345#[error(
346 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
347)]
348pub struct ZedUpdateRequiredError {
349 minimum_version: SemanticVersion,
350}
351
352pub struct ZetaEditPredictionProvider {
353 zeta: Entity<Zeta>,
354 current_prediction: Option<CurrentEditPrediction>,
355 next_pending_prediction_id: usize,
356 pending_predictions: ArrayVec<PendingPrediction, 2>,
357 last_request_timestamp: Instant,
358}
359
360impl ZetaEditPredictionProvider {
361 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
362
363 pub fn new(
364 project: Option<&Entity<Project>>,
365 client: &Arc<Client>,
366 user_store: &Entity<UserStore>,
367 cx: &mut App,
368 ) -> Self {
369 let zeta = Zeta::global(client, user_store, cx);
370 if let Some(project) = project {
371 zeta.update(cx, |zeta, cx| {
372 zeta.register_project(project, cx);
373 });
374 }
375
376 Self {
377 zeta,
378 current_prediction: None,
379 next_pending_prediction_id: 0,
380 pending_predictions: ArrayVec::new(),
381 last_request_timestamp: Instant::now(),
382 }
383 }
384}
385
386#[derive(Clone)]
387struct CurrentEditPrediction {
388 buffer_id: EntityId,
389 prediction: EditPrediction,
390}
391
392impl CurrentEditPrediction {
393 fn should_replace_prediction(
394 &self,
395 _old_completion: &Self,
396 _snapshot: &BufferSnapshot,
397 ) -> bool {
398 true
399 // TODO
400 // if self.buffer_id != old_completion.buffer_id {
401 // return true;
402 // }
403
404 // let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
405 // return true;
406 // };
407 // let Some(new_edits) = self.completion.interpolate(snapshot) else {
408 // return false;
409 // };
410
411 // if old_edits.len() == 1 && new_edits.len() == 1 {
412 // let (old_range, old_text) = &old_edits[0];
413 // let (new_range, new_text) = &new_edits[0];
414 // new_range == old_range && new_text.starts_with(old_text)
415 // } else {
416 // true
417 // }
418 }
419}
420
421#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
422pub struct EditPredictionId(Uuid);
423
424impl From<EditPredictionId> for gpui::ElementId {
425 fn from(value: EditPredictionId) -> Self {
426 gpui::ElementId::Uuid(value.0)
427 }
428}
429
430impl std::fmt::Display for EditPredictionId {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 write!(f, "{}", self.0)
433 }
434}
435
436#[derive(Clone)]
437pub struct EditPrediction {
438 id: EditPredictionId,
439 edits: Arc<[(Range<Anchor>, String)]>,
440 snapshot: BufferSnapshot,
441 edit_preview: EditPreview,
442}
443
444impl EditPrediction {
445 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
446 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
447 }
448}
449
450struct PendingPrediction {
451 id: usize,
452 _task: Task<()>,
453}
454
455impl EditPredictionProvider for ZetaEditPredictionProvider {
456 fn name() -> &'static str {
457 // TODO [zeta2]
458 "zed-predict2"
459 }
460
461 fn display_name() -> &'static str {
462 "Zed's Edit Predictions 2"
463 }
464
465 fn show_completions_in_menu() -> bool {
466 true
467 }
468
469 fn show_tab_accept_marker() -> bool {
470 true
471 }
472
473 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
474 // TODO [zeta2]
475 DataCollectionState::Unsupported
476 }
477
478 fn toggle_data_collection(&mut self, _cx: &mut App) {
479 // TODO [zeta2]
480 }
481
482 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
483 self.zeta.read(cx).usage(cx)
484 }
485
486 fn is_enabled(
487 &self,
488 _buffer: &Entity<language::Buffer>,
489 _cursor_position: language::Anchor,
490 _cx: &App,
491 ) -> bool {
492 true
493 }
494
495 fn is_refreshing(&self) -> bool {
496 !self.pending_predictions.is_empty()
497 }
498
499 fn refresh(
500 &mut self,
501 project: Option<Entity<project::Project>>,
502 buffer: Entity<language::Buffer>,
503 cursor_position: language::Anchor,
504 _debounce: bool,
505 cx: &mut Context<Self>,
506 ) {
507 let Some(project) = project else {
508 return;
509 };
510
511 if self
512 .zeta
513 .read(cx)
514 .user_store
515 .read_with(cx, |user_store, _cx| {
516 user_store.account_too_young() || user_store.has_overdue_invoices()
517 })
518 {
519 return;
520 }
521
522 if let Some(current_prediction) = self.current_prediction.as_ref() {
523 let snapshot = buffer.read(cx).snapshot();
524 if current_prediction
525 .prediction
526 .interpolate(&snapshot)
527 .is_some()
528 {
529 return;
530 }
531 }
532
533 let pending_prediction_id = self.next_pending_prediction_id;
534 self.next_pending_prediction_id += 1;
535 let last_request_timestamp = self.last_request_timestamp;
536
537 let task = cx.spawn(async move |this, cx| {
538 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
539 .checked_duration_since(Instant::now())
540 {
541 cx.background_executor().timer(timeout).await;
542 }
543
544 let prediction_request = this.update(cx, |this, cx| {
545 this.last_request_timestamp = Instant::now();
546 this.zeta.update(cx, |zeta, cx| {
547 zeta.request_prediction(&project, &buffer, cursor_position, cx)
548 })
549 });
550
551 let prediction = match prediction_request {
552 Ok(prediction_request) => {
553 let prediction_request = prediction_request.await;
554 prediction_request.map(|c| {
555 c.map(|prediction| CurrentEditPrediction {
556 buffer_id: buffer.entity_id(),
557 prediction,
558 })
559 })
560 }
561 Err(error) => Err(error),
562 };
563
564 this.update(cx, |this, cx| {
565 if this.pending_predictions[0].id == pending_prediction_id {
566 this.pending_predictions.remove(0);
567 } else {
568 this.pending_predictions.clear();
569 }
570
571 let Some(new_prediction) = prediction
572 .context("edit prediction failed")
573 .log_err()
574 .flatten()
575 else {
576 cx.notify();
577 return;
578 };
579
580 if let Some(old_prediction) = this.current_prediction.as_ref() {
581 let snapshot = buffer.read(cx).snapshot();
582 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
583 this.current_prediction = Some(new_prediction);
584 }
585 } else {
586 this.current_prediction = Some(new_prediction);
587 }
588
589 cx.notify();
590 })
591 .ok();
592 });
593
594 // We always maintain at most two pending predictions. When we already
595 // have two, we replace the newest one.
596 if self.pending_predictions.len() <= 1 {
597 self.pending_predictions.push(PendingPrediction {
598 id: pending_prediction_id,
599 _task: task,
600 });
601 } else if self.pending_predictions.len() == 2 {
602 self.pending_predictions.pop();
603 self.pending_predictions.push(PendingPrediction {
604 id: pending_prediction_id,
605 _task: task,
606 });
607 }
608
609 cx.notify();
610 }
611
612 fn cycle(
613 &mut self,
614 _buffer: Entity<language::Buffer>,
615 _cursor_position: language::Anchor,
616 _direction: Direction,
617 _cx: &mut Context<Self>,
618 ) {
619 }
620
621 fn accept(&mut self, _cx: &mut Context<Self>) {
622 // TODO [zeta2] report accept
623 self.current_prediction.take();
624 self.pending_predictions.clear();
625 }
626
627 fn discard(&mut self, _cx: &mut Context<Self>) {
628 self.pending_predictions.clear();
629 self.current_prediction.take();
630 }
631
632 fn suggest(
633 &mut self,
634 buffer: &Entity<language::Buffer>,
635 cursor_position: language::Anchor,
636 cx: &mut Context<Self>,
637 ) -> Option<edit_prediction::EditPrediction> {
638 let CurrentEditPrediction {
639 buffer_id,
640 prediction,
641 ..
642 } = self.current_prediction.as_mut()?;
643
644 // Invalidate previous prediction if it was generated for a different buffer.
645 if *buffer_id != buffer.entity_id() {
646 self.current_prediction.take();
647 return None;
648 }
649
650 let buffer = buffer.read(cx);
651 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
652 self.current_prediction.take();
653 return None;
654 };
655
656 let cursor_row = cursor_position.to_point(buffer).row;
657 let (closest_edit_ix, (closest_edit_range, _)) =
658 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
659 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
660 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
661 cmp::min(distance_from_start, distance_from_end)
662 })?;
663
664 let mut edit_start_ix = closest_edit_ix;
665 for (range, _) in edits[..edit_start_ix].iter().rev() {
666 let distance_from_closest_edit =
667 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
668 if distance_from_closest_edit <= 1 {
669 edit_start_ix -= 1;
670 } else {
671 break;
672 }
673 }
674
675 let mut edit_end_ix = closest_edit_ix + 1;
676 for (range, _) in &edits[edit_end_ix..] {
677 let distance_from_closest_edit =
678 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
679 if distance_from_closest_edit <= 1 {
680 edit_end_ix += 1;
681 } else {
682 break;
683 }
684 }
685
686 Some(edit_prediction::EditPrediction {
687 id: Some(prediction.id.to_string().into()),
688 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
689 edit_preview: Some(prediction.edit_preview.clone()),
690 })
691 }
692}
693
694fn make_cloud_request(
695 excerpt_path: PathBuf,
696 context: EditPredictionContext,
697 events: Vec<predict_edits_v3::Event>,
698 can_collect_data: bool,
699 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
700 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
701 debug_info: bool,
702 worktrees: &Vec<worktree::Snapshot>,
703 index_state: Option<&SyntaxIndexState>,
704) -> predict_edits_v3::PredictEditsRequest {
705 let mut signatures = Vec::new();
706 let mut declaration_to_signature_index = HashMap::default();
707 let mut referenced_declarations = Vec::new();
708
709 for snippet in context.snippets {
710 let project_entry_id = snippet.declaration.project_entry_id();
711 // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
712 // Note that currently full_path is currently being used for excerpt_path.
713 let Some(path) = worktrees.iter().find_map(|worktree| {
714 let abs_path = worktree.abs_path();
715 worktree
716 .entry_for_id(project_entry_id)
717 .map(|e| abs_path.join(&e.path))
718 }) else {
719 continue;
720 };
721
722 let parent_index = index_state.and_then(|index_state| {
723 snippet.declaration.parent().and_then(|parent| {
724 add_signature(
725 parent,
726 &mut declaration_to_signature_index,
727 &mut signatures,
728 index_state,
729 )
730 })
731 });
732
733 let (text, text_is_truncated) = snippet.declaration.item_text();
734 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
735 path,
736 text: text.into(),
737 range: snippet.declaration.item_range(),
738 text_is_truncated,
739 signature_range: snippet.declaration.signature_range_in_item_text(),
740 parent_index,
741 score_components: snippet.score_components,
742 signature_score: snippet.scores.signature,
743 declaration_score: snippet.scores.declaration,
744 });
745 }
746
747 let excerpt_parent = index_state.and_then(|index_state| {
748 context
749 .excerpt
750 .parent_declarations
751 .last()
752 .and_then(|(parent, _)| {
753 add_signature(
754 *parent,
755 &mut declaration_to_signature_index,
756 &mut signatures,
757 index_state,
758 )
759 })
760 });
761
762 predict_edits_v3::PredictEditsRequest {
763 excerpt_path,
764 excerpt: context.excerpt_text.body,
765 excerpt_range: context.excerpt.range,
766 cursor_offset: context.cursor_offset_in_excerpt,
767 referenced_declarations,
768 signatures,
769 excerpt_parent,
770 // todo!
771 events,
772 can_collect_data,
773 diagnostic_groups,
774 git_info,
775 debug_info,
776 }
777}
778
779fn add_signature(
780 declaration_id: DeclarationId,
781 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
782 signatures: &mut Vec<Signature>,
783 index: &SyntaxIndexState,
784) -> Option<usize> {
785 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
786 return Some(*signature_index);
787 }
788 let Some(parent_declaration) = index.declaration(declaration_id) else {
789 log::error!("bug: missing parent declaration");
790 return None;
791 };
792 let parent_index = parent_declaration.parent().and_then(|parent| {
793 add_signature(parent, declaration_to_signature_index, signatures, index)
794 });
795 let (text, text_is_truncated) = parent_declaration.signature_text();
796 let signature_index = signatures.len();
797 signatures.push(Signature {
798 text: text.into(),
799 text_is_truncated,
800 parent_index,
801 });
802 declaration_to_signature_index.insert(declaration_id, signature_index);
803 Some(signature_index)
804}
805
/// Re-bases `current_edits`, which were computed against `old_snapshot`, onto
/// `new_snapshot` by accounting for the edits the user made in between.
///
/// Rules for each user edit:
/// - Model edits that end strictly before it are kept as-is.
/// - If it replaces exactly the range of the next model edit, and its
///   inserted text is a prefix of that model edit's text, the remaining suffix
///   (if any) is kept as an insertion at the end of the user's edit.
/// - Any other user edit invalidates the whole prediction.
///
/// Model edits after the last user edit are kept unchanged.
///
/// Returns `None` if the prediction was invalidated or no edits remain. The
/// merge walks both edit lists in order, so `current_edits` is presumably
/// expected to be sorted by position — TODO confirm callers guarantee this.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over model edits that lie entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is consistent with the prediction only if it targets the
        // same range and typed a prefix of the predicted text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Whatever the user has not typed yet remains to be inserted.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit diverged from the prediction.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
851
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Checks `EditPrediction::interpolate` while the user types over, undoes,
    /// and finally diverges from the predicted edits.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let predicted = [(2..5, "REM".to_string()), (9..11, "".to_string())];
            to_prediction_edits(predicted, &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let base_snapshot = cx.read(|cx| buffer.read(cx).snapshot());
        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: base_snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: both predicted edits apply verbatim.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the predicted range turns the replacement into an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the prediction leaves only the rest to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Completing the first edit leaves only the second one.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Removing the typed character brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Applying the second edit by hand leaves just the pending insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not match the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and returns
    /// the result as offset ranges; panics if the prediction no longer applies.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let current = buffer.read(cx).snapshot();
        let rebased = prediction
            .interpolate(&current)
            .expect("prediction should still apply to the buffer");
        from_prediction_edits(&rebased, buffer, cx)
    }

    /// Converts offset edits into anchor edits on `buffer`'s current contents.
    /// Each start is anchored after its offset and each end before it.
    fn to_prediction_edits(
        offset_edits: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in offset_edits {
            let anchors = buf.anchor_after(range.start)..buf.anchor_before(range.end);
            anchored.push((anchors, text));
        }
        anchored
    }

    /// Converts anchor edits back into offset edits on `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buf = buffer.read(cx);
        let offset_of = |anchor: &Anchor| anchor.to_offset(buf);
        editor_edits
            .iter()
            .map(|(range, text)| (offset_of(&range.start)..offset_of(&range.end), text.clone()))
            .collect()
    }
}