1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use edit_prediction_context::{
11 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
12 SyntaxIndexState,
13};
14use futures::AsyncReadExt as _;
15use futures::channel::mpsc;
16use gpui::http_client::Method;
17use gpui::{
18 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
19 prelude::*,
20};
21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
22use language::{BufferSnapshot, EditPreview};
23use language_model::{LlmApiToken, RefreshLlmTokenListener};
24use project::Project;
25use release_channel::AppVersion;
26use std::cmp;
27use std::collections::{HashMap, VecDeque, hash_map};
28use std::path::PathBuf;
29use std::str::FromStr as _;
30use std::time::{Duration, Instant};
31use std::{ops::Range, sync::Arc};
32use thiserror::Error;
33use util::{ResultExt as _, some_or_debug_panic};
34use uuid::Uuid;
35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
36
/// Consecutive edits to the same buffer that arrive within this window of
/// each other are merged into a single `Event::BufferChange`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Maximum number of events retained per project; once this many are stored,
/// the oldest half is discarded before a new event is appended.
const MAX_EVENT_COUNT: usize = 16;
41
/// App-global slot holding the shared `Zeta` entity; installed by `Zeta::global`
/// the first time it is called.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
46
/// Shared edit-prediction engine: tracks per-project state (syntax index,
/// buffer-change history) and issues prediction requests to the cloud.
pub struct Zeta {
    /// Client used for HTTP requests and LLM token acquisition/refresh.
    client: Arc<Client>,
    /// User store that reports and records edit prediction usage.
    user_store: Entity<UserStore>,
    /// Bearer token sent with prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options passed to context gathering for each request.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set once the server reports that a newer Zed version is required.
    /// NOTE(review): written but not read anywhere in this file — confirm intended use.
    update_required: bool,
    /// Optional sink for per-request debug information (see `Zeta::debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
57
/// Debug payload sent on the channel returned by `Zeta::debug_info`, one per
/// prediction request made while that channel is installed.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for the request.
    pub context: EditPredictionContext,
    /// Time from just before context gathering until the request body was built.
    pub retrieval_time: TimeDelta,
    /// Debug info taken from the server's response.
    pub request: RequestDebugInfo,
}

/// Debug information returned by the `predict_edits_v3` endpoint.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
65
/// State `Zeta` keeps for each registered project.
struct ZetaProject {
    /// Syntax index whose state is used when gathering prediction context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are being observed.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; its version is compared with
    /// the buffer's next snapshot to detect new changes.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's `Edited` events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}
76
/// A user action recorded in a project's event history.
///
/// NOTE: events are recorded but not yet sent with prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed from `old_snapshot` to `new_snapshot`.
    /// Rapid consecutive changes to the same buffer may be coalesced into
    /// one event, in which case `timestamp` is that of the latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
85
impl Event {
    // TODO: Actually use these events in the prompt. The commented-out
    // `to_prompt` below sketches how a `BufferChange` could be rendered:
    // a rename notice (if the path changed) followed by a unified diff.
    // fn to_prompt(&self) -> String {
    //     match self {
    //         Event::BufferChange {
    //             old_snapshot,
    //             new_snapshot,
    //             ..
    //         } => {
    //             let mut prompt = String::new();

    //             let old_path = old_snapshot
    //                 .file()
    //                 .map(|f| f.path().as_ref())
    //                 .unwrap_or(Path::new("untitled"));
    //             let new_path = new_snapshot
    //                 .file()
    //                 .map(|f| f.path().as_ref())
    //                 .unwrap_or(Path::new("untitled"));
    //             if old_path != new_path {
    //                 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
    //             }

    //             let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
    //             if !diff.is_empty() {
    //                 write!(
    //                     prompt,
    //                     "User edited {:?}:\n```diff\n{}\n```",
    //                     new_path, diff
    //                 )
    //                 .unwrap();
    //             }

    //             prompt
    //         }
    //     }
    // }
}
124
impl Zeta {
    /// Returns the app-wide `Zeta` entity, creating it and storing it as a
    /// global on first use.
    ///
    /// `client` and `user_store` are only used when the entity is created;
    /// later calls return the existing entity.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    /// Creates a `Zeta` with no tracked projects and fixed excerpt options.
    ///
    /// Subscribes to the global `RefreshLlmTokenListener` so the cached LLM
    /// token is refreshed whenever the listener emits an event.
    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::new(),
            client,
            user_store,
            excerpt_options: EditPredictionExcerptOptions {
                max_bytes: 512,
                min_bytes: 128,
                target_before_cursor_over_total_bytes: 0.5,
            },
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    // The refresh runs as a detached task; failures are only logged.
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }

    /// Installs a fresh debug channel and returns its receiving end.
    ///
    /// Any previously installed sender is replaced. While a sender is
    /// installed, `request_prediction` asks the server for debug info and
    /// sends either a `PredictionDebugInfo` or an error string per request.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Returns the edit prediction usage currently recorded in the user store.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Starts tracking `project` (creating its syntax index) if it isn't tracked yet.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts observing edits to `buffer`, registering `project` first if needed.
    /// Registering an already-registered buffer has no effect.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    /// Returns the tracked state for `project`, creating it (with a new
    /// `SyntaxIndex` and empty event/buffer collections) on first access.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
                events: VecDeque::new(),
                registered_buffers: HashMap::new(),
            })
    }

    /// Returns the registration for `buffer` in `zeta_project`, creating it if absent.
    ///
    /// A new registration stores the buffer's current snapshot and subscribes to:
    /// - `BufferEvent::Edited`, which records the change via `report_changes_for_buffer`;
    /// - the buffer's release, which removes the registration again.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture only a weak handle to the project; edits are
                            // ignored once it can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    /// Records a `BufferChange` event if `buffer`'s version differs from its
    /// last recorded snapshot, then returns the buffer's current snapshot.
    ///
    /// Registers `project` and `buffer` first if needed; a newly registered
    /// buffer uses its current snapshot as the baseline, so no event is recorded.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

    /// Appends `event` to the project's history.
    ///
    /// If `event` directly continues the most recent change to the same buffer
    /// (its old snapshot has the same buffer id and version as that change's new
    /// snapshot) within `BUFFER_CHANGE_GROUPING_INTERVAL`, the existing event is
    /// extended instead. When the history holds `MAX_EVENT_COUNT` events, the
    /// oldest half is removed before appending.
    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            timestamp: last_timestamp,
            ..
        }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    /// Requests an edit prediction for `buffer` at `position` within `project`.
    ///
    /// Context is gathered on a background thread (using the project's syntax
    /// index when the project is registered) and sent via `perform_request`.
    /// The returned edits are anchored in the buffer snapshot taken here and
    /// then interpolated onto the buffer's latest snapshot.
    ///
    /// Resolves to `Ok(None)` when no context could be gathered or the edits
    /// no longer apply. Fails immediately if the buffer has no file. On a
    /// `ZedUpdateRequiredError`, sets `update_required` and shows an
    /// "Update Zed" notification before returning the error.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // `None` when the project hasn't been registered with `Zeta`.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Keep a copy of the context for the debug channel, if one is installed.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let request = make_cloud_request(
                    excerpt_path.clone(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                );

                // Note: this also covers building the request, not just gathering context.
                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                // Send failures (e.g. the receiver was dropped) are ignored.
                if let Some((debug_tx, context)) = debug_context {
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                })
                            },
                        ))
                        .ok();
                }

                anyhow::Ok(Some(response?))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // The buffer may have changed while the request was in flight, so
                    // rebase the edits onto its current snapshot before previewing.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }

    /// Sends `request` to the edit prediction endpoint and parses the response.
    ///
    /// The URL comes from the `ZED_PREDICT_EDITS_URL` environment variable when
    /// set, otherwise from the Zed LLM URL for `/predict_edits/v3`. Usage info is
    /// parsed from the response headers on success, when present.
    ///
    /// If the server signals an expired token, the token is refreshed and the
    /// request is retried once.
    ///
    /// # Errors
    /// - `ZedUpdateRequiredError` if the server's minimum required version is
    ///   newer than `app_version`;
    /// - token, URL, (de)serialization, or transport failures;
    /// - a non-success status (after at most one retry), including the body.
    async fn perform_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: predict_edits_v3::PredictEditsRequest,
    ) -> Result<(
        predict_edits_v3::PredictEditsResponse,
        Option<EditPredictionUsage>,
    )> {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v3", &[])?
                            .as_ref(),
                    )
                };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&request)?.into())?;

            let mut response = http_client.send(request).await?;

            // Checked regardless of the response status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh it and loop to resend the request once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
}
564
/// Returned when the server's minimum-required-version header names a Zed
/// version newer than the one making the request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
572
/// `EditPredictionProvider` backed by the shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending request.
    next_pending_prediction_id: usize,
    /// In-flight requests, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; used to throttle requests.
    last_request_timestamp: Instant,
}
580
581impl ZetaEditPredictionProvider {
582 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
583
584 pub fn new(
585 project: Option<&Entity<Project>>,
586 client: &Arc<Client>,
587 user_store: &Entity<UserStore>,
588 cx: &mut App,
589 ) -> Self {
590 let zeta = Zeta::global(client, user_store, cx);
591 if let Some(project) = project {
592 zeta.update(cx, |zeta, cx| {
593 zeta.register_project(project, cx);
594 });
595 }
596
597 Self {
598 zeta,
599 current_prediction: None,
600 next_pending_prediction_id: 0,
601 pending_predictions: ArrayVec::new(),
602 last_request_timestamp: Instant::now(),
603 }
604 }
605}
606
607#[derive(Clone)]
608struct CurrentEditPrediction {
609 buffer_id: EntityId,
610 prediction: EditPrediction,
611}
612
613impl CurrentEditPrediction {
614 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
615 if self.buffer_id != old_prediction.buffer_id {
616 return true;
617 }
618
619 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
620 return true;
621 };
622 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
623 return false;
624 };
625
626 if old_edits.len() == 1 && new_edits.len() == 1 {
627 let (old_range, old_text) = &old_edits[0];
628 let (new_range, new_text) = &new_edits[0];
629 new_range == old_range && new_text.starts_with(old_text)
630 } else {
631 true
632 }
633 }
634}
635
636#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
637pub struct EditPredictionId(Uuid);
638
639impl From<EditPredictionId> for gpui::ElementId {
640 fn from(value: EditPredictionId) -> Self {
641 gpui::ElementId::Uuid(value.0)
642 }
643}
644
645impl std::fmt::Display for EditPredictionId {
646 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
647 write!(f, "{}", self.0)
648 }
649}
650
651#[derive(Clone)]
652pub struct EditPrediction {
653 id: EditPredictionId,
654 edits: Arc<[(Range<Anchor>, String)]>,
655 snapshot: BufferSnapshot,
656 edit_preview: EditPreview,
657}
658
659impl EditPrediction {
660 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
661 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
662 }
663}
664
/// An in-flight prediction request started by `refresh`.
struct PendingPrediction {
    /// Id used to recognize this request when its task completes.
    id: usize,
    /// Task driving the request; dropping this entry drops the task.
    _task: Task<()>,
}
669
670impl EditPredictionProvider for ZetaEditPredictionProvider {
671 fn name() -> &'static str {
672 "zed-predict2"
673 }
674
675 fn display_name() -> &'static str {
676 "Zed's Edit Predictions 2"
677 }
678
679 fn show_completions_in_menu() -> bool {
680 true
681 }
682
683 fn show_tab_accept_marker() -> bool {
684 true
685 }
686
687 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
688 // TODO [zeta2]
689 DataCollectionState::Unsupported
690 }
691
692 fn toggle_data_collection(&mut self, _cx: &mut App) {
693 // TODO [zeta2]
694 }
695
696 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
697 self.zeta.read(cx).usage(cx)
698 }
699
700 fn is_enabled(
701 &self,
702 _buffer: &Entity<language::Buffer>,
703 _cursor_position: language::Anchor,
704 _cx: &App,
705 ) -> bool {
706 true
707 }
708
709 fn is_refreshing(&self) -> bool {
710 !self.pending_predictions.is_empty()
711 }
712
713 fn refresh(
714 &mut self,
715 project: Option<Entity<project::Project>>,
716 buffer: Entity<language::Buffer>,
717 cursor_position: language::Anchor,
718 _debounce: bool,
719 cx: &mut Context<Self>,
720 ) {
721 let Some(project) = project else {
722 return;
723 };
724
725 if self
726 .zeta
727 .read(cx)
728 .user_store
729 .read_with(cx, |user_store, _cx| {
730 user_store.account_too_young() || user_store.has_overdue_invoices()
731 })
732 {
733 return;
734 }
735
736 if let Some(current_prediction) = self.current_prediction.as_ref() {
737 let snapshot = buffer.read(cx).snapshot();
738 if current_prediction
739 .prediction
740 .interpolate(&snapshot)
741 .is_some()
742 {
743 return;
744 }
745 }
746
747 let pending_prediction_id = self.next_pending_prediction_id;
748 self.next_pending_prediction_id += 1;
749 let last_request_timestamp = self.last_request_timestamp;
750
751 let task = cx.spawn(async move |this, cx| {
752 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
753 .checked_duration_since(Instant::now())
754 {
755 cx.background_executor().timer(timeout).await;
756 }
757
758 let prediction_request = this.update(cx, |this, cx| {
759 this.last_request_timestamp = Instant::now();
760 this.zeta.update(cx, |zeta, cx| {
761 zeta.request_prediction(&project, &buffer, cursor_position, cx)
762 })
763 });
764
765 let prediction = match prediction_request {
766 Ok(prediction_request) => {
767 let prediction_request = prediction_request.await;
768 prediction_request.map(|c| {
769 c.map(|prediction| CurrentEditPrediction {
770 buffer_id: buffer.entity_id(),
771 prediction,
772 })
773 })
774 }
775 Err(error) => Err(error),
776 };
777
778 this.update(cx, |this, cx| {
779 if this.pending_predictions[0].id == pending_prediction_id {
780 this.pending_predictions.remove(0);
781 } else {
782 this.pending_predictions.clear();
783 }
784
785 let Some(new_prediction) = prediction
786 .context("edit prediction failed")
787 .log_err()
788 .flatten()
789 else {
790 cx.notify();
791 return;
792 };
793
794 if let Some(old_prediction) = this.current_prediction.as_ref() {
795 let snapshot = buffer.read(cx).snapshot();
796 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
797 this.current_prediction = Some(new_prediction);
798 }
799 } else {
800 this.current_prediction = Some(new_prediction);
801 }
802
803 cx.notify();
804 })
805 .ok();
806 });
807
808 // We always maintain at most two pending predictions. When we already
809 // have two, we replace the newest one.
810 if self.pending_predictions.len() <= 1 {
811 self.pending_predictions.push(PendingPrediction {
812 id: pending_prediction_id,
813 _task: task,
814 });
815 } else if self.pending_predictions.len() == 2 {
816 self.pending_predictions.pop();
817 self.pending_predictions.push(PendingPrediction {
818 id: pending_prediction_id,
819 _task: task,
820 });
821 }
822
823 cx.notify();
824 }
825
826 fn cycle(
827 &mut self,
828 _buffer: Entity<language::Buffer>,
829 _cursor_position: language::Anchor,
830 _direction: Direction,
831 _cx: &mut Context<Self>,
832 ) {
833 }
834
835 fn accept(&mut self, _cx: &mut Context<Self>) {
836 // TODO [zeta2] report accept
837 self.current_prediction.take();
838 self.pending_predictions.clear();
839 }
840
841 fn discard(&mut self, _cx: &mut Context<Self>) {
842 self.pending_predictions.clear();
843 self.current_prediction.take();
844 }
845
846 fn suggest(
847 &mut self,
848 buffer: &Entity<language::Buffer>,
849 cursor_position: language::Anchor,
850 cx: &mut Context<Self>,
851 ) -> Option<edit_prediction::EditPrediction> {
852 let CurrentEditPrediction {
853 buffer_id,
854 prediction,
855 ..
856 } = self.current_prediction.as_mut()?;
857
858 // Invalidate previous prediction if it was generated for a different buffer.
859 if *buffer_id != buffer.entity_id() {
860 self.current_prediction.take();
861 return None;
862 }
863
864 let buffer = buffer.read(cx);
865 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
866 self.current_prediction.take();
867 return None;
868 };
869
870 let cursor_row = cursor_position.to_point(buffer).row;
871 let (closest_edit_ix, (closest_edit_range, _)) =
872 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
873 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
874 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
875 cmp::min(distance_from_start, distance_from_end)
876 })?;
877
878 let mut edit_start_ix = closest_edit_ix;
879 for (range, _) in edits[..edit_start_ix].iter().rev() {
880 let distance_from_closest_edit =
881 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
882 if distance_from_closest_edit <= 1 {
883 edit_start_ix -= 1;
884 } else {
885 break;
886 }
887 }
888
889 let mut edit_end_ix = closest_edit_ix + 1;
890 for (range, _) in &edits[edit_end_ix..] {
891 let distance_from_closest_edit =
892 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
893 if distance_from_closest_edit <= 1 {
894 edit_end_ix += 1;
895 } else {
896 break;
897 }
898 }
899
900 Some(edit_prediction::EditPrediction {
901 id: Some(prediction.id.to_string().into()),
902 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
903 edit_preview: Some(prediction.edit_preview.clone()),
904 })
905 }
906}
907
908fn make_cloud_request(
909 excerpt_path: PathBuf,
910 context: EditPredictionContext,
911 events: Vec<predict_edits_v3::Event>,
912 can_collect_data: bool,
913 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
914 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
915 debug_info: bool,
916 worktrees: &Vec<worktree::Snapshot>,
917 index_state: Option<&SyntaxIndexState>,
918) -> predict_edits_v3::PredictEditsRequest {
919 let mut signatures = Vec::new();
920 let mut declaration_to_signature_index = HashMap::default();
921 let mut referenced_declarations = Vec::new();
922
923 for snippet in context.snippets {
924 let project_entry_id = snippet.declaration.project_entry_id();
925 // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
926 // Note that currently full_path is currently being used for excerpt_path.
927 let Some(path) = worktrees.iter().find_map(|worktree| {
928 let abs_path = worktree.abs_path();
929 worktree
930 .entry_for_id(project_entry_id)
931 .map(|e| abs_path.join(&e.path))
932 }) else {
933 continue;
934 };
935
936 let parent_index = index_state.and_then(|index_state| {
937 snippet.declaration.parent().and_then(|parent| {
938 add_signature(
939 parent,
940 &mut declaration_to_signature_index,
941 &mut signatures,
942 index_state,
943 )
944 })
945 });
946
947 let (text, text_is_truncated) = snippet.declaration.item_text();
948 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
949 path,
950 text: text.into(),
951 range: snippet.declaration.item_range(),
952 text_is_truncated,
953 signature_range: snippet.declaration.signature_range_in_item_text(),
954 parent_index,
955 score_components: snippet.score_components,
956 signature_score: snippet.scores.signature,
957 declaration_score: snippet.scores.declaration,
958 });
959 }
960
961 let excerpt_parent = index_state.and_then(|index_state| {
962 context
963 .excerpt
964 .parent_declarations
965 .last()
966 .and_then(|(parent, _)| {
967 add_signature(
968 *parent,
969 &mut declaration_to_signature_index,
970 &mut signatures,
971 index_state,
972 )
973 })
974 });
975
976 predict_edits_v3::PredictEditsRequest {
977 excerpt_path,
978 excerpt: context.excerpt_text.body,
979 excerpt_range: context.excerpt.range,
980 cursor_offset: context.cursor_offset_in_excerpt,
981 referenced_declarations,
982 signatures,
983 excerpt_parent,
984 // todo!
985 events,
986 can_collect_data,
987 diagnostic_groups,
988 git_info,
989 debug_info,
990 }
991}
992
993fn add_signature(
994 declaration_id: DeclarationId,
995 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
996 signatures: &mut Vec<Signature>,
997 index: &SyntaxIndexState,
998) -> Option<usize> {
999 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1000 return Some(*signature_index);
1001 }
1002 let Some(parent_declaration) = index.declaration(declaration_id) else {
1003 log::error!("bug: missing parent declaration");
1004 return None;
1005 };
1006 let parent_index = parent_declaration.parent().and_then(|parent| {
1007 add_signature(parent, declaration_to_signature_index, signatures, index)
1008 });
1009 let (text, text_is_truncated) = parent_declaration.signature_text();
1010 let signature_index = signatures.len();
1011 signatures.push(Signature {
1012 text: text.into(),
1013 text_is_truncated,
1014 parent_index,
1015 });
1016 declaration_to_signature_index.insert(declaration_id, signature_index);
1017 Some(signature_index)
1018}
1019
/// Rebases the model's `current_edits` (anchored in `old_snapshot`) onto
/// `new_snapshot`, accounting for edits the user made in between.
///
/// User edits (from `edits_since`) and model edits are walked in order; the
/// model edits are presumably sorted by position — TODO confirm for
/// server-provided edits.
/// - Model edits that end before the next user edit starts are kept as-is.
/// - A user edit that replaced exactly a model edit's range with a prefix of
///   the model's text is treated as the user typing the prediction; the
///   remaining suffix (if any) is kept as an insertion.
/// - Any other user edit invalidates the whole prediction.
///
/// Returns `None` if the prediction was invalidated or no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep model edits that lie entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // Accept the user edit only if it replaced exactly the next model edit's
        // range with a prefix of the model's text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // Right-biased anchor at the end of the replaced range; it
                        // should resolve after the text the user inserted there.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit conflicts with the prediction.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Checks how `EditPrediction::interpolate` rebases a two-edit prediction
    /// on "Lorem ipsum dolor" (replace 2..5 with "REM", delete 9..11) as the
    /// user edits the buffer step by step.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Unchanged buffer: the edits come back as-is.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the first edit's range turns it into a pure insertion of
            // "REM"; the second edit shifts left accordingly.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing "R" (a prefix of "REM") leaves only "EM" to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // Typing "E" leaves only "M".
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing "M" completes the first edit; only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Removing the typed "M" brings the "M" insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion removes it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that no longer lines up with a predicted edit invalidates
            // the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offsets in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}