1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use edit_prediction_context::{
11 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
12 SyntaxIndexState,
13};
14use futures::AsyncReadExt as _;
15use futures::channel::mpsc;
16use gpui::http_client::Method;
17use gpui::{
18 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
19 http_client, prelude::*,
20};
21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
22use language::{BufferSnapshot, EditPreview};
23use language_model::{LlmApiToken, RefreshLlmTokenListener};
24use project::Project;
25use release_channel::AppVersion;
26use std::cmp;
27use std::collections::{HashMap, VecDeque, hash_map};
28use std::path::PathBuf;
29use std::str::FromStr as _;
30use std::time::{Duration, Instant};
31use std::{ops::Range, sync::Arc};
32use thiserror::Error;
33use util::{ResultExt as _, some_or_debug_panic};
34use uuid::Uuid;
35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
36
/// Two consecutive changes to the same buffer are merged into a single event
/// when they happen no further apart than this interval (see
/// `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Excerpt options a newly constructed `Zeta` starts out with.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
47
/// Wrapper that lets the shared `Zeta` entity be stored as a gpui global
/// (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
52
/// Shared edit-prediction state: per-project edit history and syntax index,
/// plus everything needed to send prediction requests to the cloud endpoint.
pub struct Zeta {
    /// Client used to obtain the HTTP client and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Source of, and sink for, edit prediction usage information.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options passed to context gathering when building a request.
    pub excerpt_options: EditPredictionExcerptOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When present, every prediction request reports its debug info (or an
    /// error string) through this channel.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
63
/// Diagnostic record emitted for a prediction request when a debug receiver
/// has been installed via `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context that was gathered for the request.
    pub context: EditPredictionContext,
    /// Time spent gathering context and building the request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Position the prediction was requested at.
    pub position: language::Anchor,
}

/// Debug payload type carried in the cloud response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
73
/// State that `Zeta` keeps for each registered project.
struct ZetaProject {
    /// Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// Tracking state for a single buffer registered with a `ZetaProject`.
struct RegisteredBuffer {
    /// Snapshot taken the last time a change was recorded for this buffer.
    snapshot: BufferSnapshot,
    /// Buffer-event subscription and release observer, kept alive here.
    _subscriptions: [gpui::Subscription; 2],
}
84
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer contents before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer contents after the change.
        new_snapshot: BufferSnapshot,
        /// When the change (or the latest change merged into it) was recorded.
        timestamp: Instant,
    },
}
93
94impl Zeta {
95 pub fn global(
96 client: &Arc<Client>,
97 user_store: &Entity<UserStore>,
98 cx: &mut App,
99 ) -> Entity<Self> {
100 cx.try_global::<ZetaGlobal>()
101 .map(|global| global.0.clone())
102 .unwrap_or_else(|| {
103 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
104 cx.set_global(ZetaGlobal(zeta.clone()));
105 zeta
106 })
107 }
108
109 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
110 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
111
112 Self {
113 projects: HashMap::new(),
114 client,
115 user_store,
116 excerpt_options: DEFAULT_EXCERPT_OPTIONS,
117 llm_token: LlmApiToken::default(),
118 _llm_token_subscription: cx.subscribe(
119 &refresh_llm_token_listener,
120 |this, _listener, _event, cx| {
121 let client = this.client.clone();
122 let llm_token = this.llm_token.clone();
123 cx.spawn(async move |_this, _cx| {
124 llm_token.refresh(&client).await?;
125 anyhow::Ok(())
126 })
127 .detach_and_log_err(cx);
128 },
129 ),
130 update_required: false,
131 debug_tx: None,
132 }
133 }
134
135 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
136 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
137 self.debug_tx = Some(debug_watch_tx);
138 debug_watch_rx
139 }
140
    /// Returns the excerpt options currently used when building requests.
    pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
        &self.excerpt_options
    }
144
    /// Replaces the excerpt options used for subsequent requests.
    pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
        self.excerpt_options = options;
    }
148
    /// Returns the edit prediction usage recorded in the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
152
    /// Ensures per-project state (including a syntax index) exists for
    /// `project`. Does nothing if the project is already registered.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
156
    /// Starts tracking edits to `buffer` as part of `project`, registering the
    /// project first if necessary. Registering the same buffer again is a no-op.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
166
167 fn get_or_init_zeta_project(
168 &mut self,
169 project: &Entity<Project>,
170 cx: &mut App,
171 ) -> &mut ZetaProject {
172 self.projects
173 .entry(project.entity_id())
174 .or_insert_with(|| ZetaProject {
175 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
176 events: VecDeque::new(),
177 registered_buffers: HashMap::new(),
178 })
179 }
180
181 fn register_buffer_impl<'a>(
182 zeta_project: &'a mut ZetaProject,
183 buffer: &Entity<Buffer>,
184 project: &Entity<Project>,
185 cx: &mut Context<Self>,
186 ) -> &'a mut RegisteredBuffer {
187 let buffer_id = buffer.entity_id();
188 match zeta_project.registered_buffers.entry(buffer_id) {
189 hash_map::Entry::Occupied(entry) => entry.into_mut(),
190 hash_map::Entry::Vacant(entry) => {
191 let snapshot = buffer.read(cx).snapshot();
192 let project_entity_id = project.entity_id();
193 entry.insert(RegisteredBuffer {
194 snapshot,
195 _subscriptions: [
196 cx.subscribe(buffer, {
197 let project = project.downgrade();
198 move |this, buffer, event, cx| {
199 if let language::BufferEvent::Edited = event
200 && let Some(project) = project.upgrade()
201 {
202 this.report_changes_for_buffer(&buffer, &project, cx);
203 }
204 }
205 }),
206 cx.observe_release(buffer, move |this, _buffer, _cx| {
207 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
208 else {
209 return;
210 };
211 zeta_project.registered_buffers.remove(&buffer_id);
212 }),
213 ],
214 })
215 }
216 }
217 }
218
219 fn report_changes_for_buffer(
220 &mut self,
221 buffer: &Entity<Buffer>,
222 project: &Entity<Project>,
223 cx: &mut Context<Self>,
224 ) -> BufferSnapshot {
225 let zeta_project = self.get_or_init_zeta_project(project, cx);
226 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
227
228 let new_snapshot = buffer.read(cx).snapshot();
229 if new_snapshot.version != registered_buffer.snapshot.version {
230 let old_snapshot =
231 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
232 Self::push_event(
233 zeta_project,
234 Event::BufferChange {
235 old_snapshot,
236 new_snapshot: new_snapshot.clone(),
237 timestamp: Instant::now(),
238 },
239 );
240 }
241
242 new_snapshot
243 }
244
245 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
246 let events = &mut zeta_project.events;
247
248 if let Some(Event::BufferChange {
249 new_snapshot: last_new_snapshot,
250 timestamp: last_timestamp,
251 ..
252 }) = events.back_mut()
253 {
254 // Coalesce edits for the same buffer when they happen one after the other.
255 let Event::BufferChange {
256 old_snapshot,
257 new_snapshot,
258 timestamp,
259 } = &event;
260
261 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
262 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
263 && old_snapshot.version == last_new_snapshot.version
264 {
265 *last_new_snapshot = new_snapshot.clone();
266 *last_timestamp = *timestamp;
267 return;
268 }
269 }
270
271 if events.len() >= MAX_EVENT_COUNT {
272 // These are halved instead of popping to improve prompt caching.
273 events.drain(..MAX_EVENT_COUNT / 2);
274 }
275
276 events.push_back(event);
277 }
278
    /// Requests an edit prediction for `position` in `buffer`.
    ///
    /// The work is split in two stages:
    /// 1. A background task gathers context around the cursor, builds the
    ///    cloud request, sends it, and (if a debug channel is installed)
    ///    reports debug info.
    /// 2. A foreground task records usage, converts the returned edits to
    ///    anchors, re-bases them onto the buffer's latest contents and builds
    ///    an edit preview.
    ///
    /// Resolves to `Ok(None)` when no context could be gathered or when the
    /// returned edits can no longer be interpolated onto the current buffer.
    ///
    /// # Errors
    ///
    /// Fails immediately if the buffer has no file, and otherwise propagates
    /// any request error. A `ZedUpdateRequiredError` additionally sets
    /// `update_required` and shows an app notification linking to the
    /// releases page.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Capture everything the background task needs while we still have
        // access to `self` and `cx`.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                // The index is only available for registered projects.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Keep a copy of the context only when someone is listening
                // for debug info.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let request = make_cloud_request(
                    excerpt_path.clone(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                );

                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                // Report either the debug info or the request error; a
                // dropped receiver is ignored.
                if let Some((debug_tx, context)) = debug_context {
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                    buffer: buffer.downgrade(),
                                    position,
                                })
                            },
                        ))
                        .ok();
                }

                anyhow::Ok(Some(response?))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // The buffer may have changed while the request was in
                    // flight, so re-base the edits onto its latest snapshot.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }
467
    /// Sends `request` to the predict-edits endpoint and returns the parsed
    /// response together with any usage information found in its headers.
    ///
    /// The endpoint is taken from the `ZED_PREDICT_EDITS_URL` environment
    /// variable when set, otherwise from the client's LLM URL for
    /// `/predict_edits/v3`. If the server reports an expired token, the token
    /// is refreshed and the request is retried once.
    ///
    /// # Errors
    ///
    /// Returns `ZedUpdateRequiredError` when the server advertises a minimum
    /// version newer than `app_version`, and a generic error containing the
    /// status and body for any other unsuccessful response. Token, transport
    /// and (de)serialization failures are propagated.
    async fn perform_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: predict_edits_v3::PredictEditsRequest,
    ) -> Result<(
        predict_edits_v3::PredictEditsResponse,
        Option<EditPredictionUsage>,
    )> {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Runs at most twice: the second pass only happens after a token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v3", &[])?
                            .as_ref(),
                    )
                };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&request)?.into())?;

            let mut response = http_client.send(request).await?;

            // The version check applies regardless of the response status; a
            // missing or unparsable header skips it.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; failing to parse them is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
539
540 // TODO: Dedupe with similar code in request_prediction?
541 pub fn cloud_request_for_zeta_cli(
542 &mut self,
543 project: &Entity<Project>,
544 buffer: &Entity<Buffer>,
545 position: language::Anchor,
546 cx: &mut Context<Self>,
547 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
548 let project_state = self.projects.get(&project.entity_id());
549
550 let index_state = project_state.map(|state| {
551 state
552 .syntax_index
553 .read_with(cx, |index, _cx| index.state().clone())
554 });
555 let excerpt_options = self.excerpt_options.clone();
556 let snapshot = buffer.read(cx).snapshot();
557 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
558 return Task::ready(Err(anyhow!("No file path for excerpt")));
559 };
560 let worktree_snapshots = project
561 .read(cx)
562 .worktrees(cx)
563 .map(|worktree| worktree.read(cx).snapshot())
564 .collect::<Vec<_>>();
565
566 cx.background_spawn(async move {
567 let index_state = if let Some(index_state) = index_state {
568 Some(index_state.lock_owned().await)
569 } else {
570 None
571 };
572
573 let cursor_point = position.to_point(&snapshot);
574
575 let debug_info = true;
576 EditPredictionContext::gather_context(
577 cursor_point,
578 &snapshot,
579 &excerpt_options,
580 index_state.as_deref(),
581 )
582 .context("Failed to select excerpt")
583 .map(|context| {
584 make_cloud_request(
585 excerpt_path.clone(),
586 context,
587 // TODO pass everything
588 Vec::new(),
589 false,
590 Vec::new(),
591 None,
592 debug_info,
593 &worktree_snapshots,
594 index_state.as_deref(),
595 )
596 })
597 })
598 }
599}
600
/// Error returned when the server requires a newer app version than the one
/// that sent the request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: SemanticVersion,
}
608
/// `EditPredictionProvider` implementation backed by the shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    /// Shared prediction state used to issue requests.
    zeta: Entity<Zeta>,
    /// Most recent prediction that is still being offered, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id that will be assigned to the next pending prediction.
    next_pending_prediction_id: usize,
    /// In-flight prediction tasks; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; used for throttling.
    last_request_timestamp: Instant,
}
616
617impl ZetaEditPredictionProvider {
618 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
619
620 pub fn new(
621 project: Option<&Entity<Project>>,
622 client: &Arc<Client>,
623 user_store: &Entity<UserStore>,
624 cx: &mut App,
625 ) -> Self {
626 let zeta = Zeta::global(client, user_store, cx);
627 if let Some(project) = project {
628 zeta.update(cx, |zeta, cx| {
629 zeta.register_project(project, cx);
630 });
631 }
632
633 Self {
634 zeta,
635 current_prediction: None,
636 next_pending_prediction_id: 0,
637 pending_predictions: ArrayVec::new(),
638 last_request_timestamp: Instant::now(),
639 }
640 }
641}
642
/// A prediction together with the id of the buffer it was produced for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    /// The prediction itself.
    prediction: EditPrediction,
}
648
649impl CurrentEditPrediction {
650 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
651 if self.buffer_id != old_prediction.buffer_id {
652 return true;
653 }
654
655 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
656 return true;
657 };
658 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
659 return false;
660 };
661
662 if old_edits.len() == 1 && new_edits.len() == 1 {
663 let (old_range, old_text) = &old_edits[0];
664 let (new_range, new_text) = &new_edits[0];
665 new_range == old_range && new_text.starts_with(old_text)
666 } else {
667 true
668 }
669 }
670}
671
672#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
673pub struct EditPredictionId(Uuid);
674
675impl From<EditPredictionId> for gpui::ElementId {
676 fn from(value: EditPredictionId) -> Self {
677 gpui::ElementId::Uuid(value.0)
678 }
679}
680
681impl std::fmt::Display for EditPredictionId {
682 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
683 write!(f, "{}", self.0)
684 }
685}
686
687#[derive(Clone)]
688pub struct EditPrediction {
689 id: EditPredictionId,
690 edits: Arc<[(Range<Anchor>, String)]>,
691 snapshot: BufferSnapshot,
692 edit_preview: EditPreview,
693}
694
695impl EditPrediction {
696 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
697 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
698 }
699}
700
/// An in-flight prediction request owned by the provider.
struct PendingPrediction {
    /// Sequence number assigned when the request was started.
    id: usize,
    /// Task driving the request; held here so it lives as long as this entry.
    _task: Task<()>,
}
705
706impl EditPredictionProvider for ZetaEditPredictionProvider {
    /// Machine-readable identifier of this provider.
    fn name() -> &'static str {
        "zed-predict2"
    }
710
    /// Human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions 2"
    }
714
    /// This provider always opts in to showing completions in the menu.
    fn show_completions_in_menu() -> bool {
        true
    }
718
    /// This provider always opts in to the tab-accept marker.
    fn show_tab_accept_marker() -> bool {
        true
    }
722
    /// Data collection is not implemented for this provider yet, so it is
    /// always reported as unsupported.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        // TODO [zeta2]
        DataCollectionState::Unsupported
    }
727
    /// Currently a no-op; data collection is not implemented for this provider.
    fn toggle_data_collection(&mut self, _cx: &mut App) {
        // TODO [zeta2]
    }
731
    /// Returns the edit prediction usage reported by the shared `Zeta` entity.
    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
735
    /// Always reports the provider as enabled, regardless of buffer or
    /// cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
744
    /// Returns `true` while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_predictions.is_empty()
    }
748
    /// Starts a new prediction request for `cursor_position` in `buffer`.
    ///
    /// Nothing happens when there is no project, when the user's account is
    /// too young or has overdue invoices, or when the current prediction still
    /// interpolates onto the buffer's present contents.
    ///
    /// Otherwise a task is spawned that waits out the remainder of
    /// `THROTTLE_TIMEOUT` since the last request, asks `Zeta` for a
    /// prediction, and then updates `current_prediction`. The task is stored
    /// in `pending_predictions`, which holds at most two entries.
    fn refresh(
        &mut self,
        project: Option<Entity<project::Project>>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        let Some(project) = project else {
            return;
        };

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep offering the existing prediction while it is still applicable.
        if let Some(current_prediction) = self.current_prediction.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_prediction
                .prediction
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_prediction_id = self.next_pending_prediction_id;
        self.next_pending_prediction_id += 1;
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: sleep only if the timeout since the last request has
            // not elapsed yet.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let prediction_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
                })
            });

            let prediction = match prediction_request {
                Ok(prediction_request) => {
                    let prediction_request = prediction_request.await;
                    prediction_request.map(|c| {
                        c.map(|prediction| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            prediction,
                        })
                    })
                }
                Err(error) => Err(error),
            };

            this.update(cx, |this, cx| {
                // If this is the oldest pending request, remove just its
                // entry; otherwise drop every pending entry.
                if this.pending_predictions[0].id == pending_prediction_id {
                    this.pending_predictions.remove(0);
                } else {
                    this.pending_predictions.clear();
                }

                // Errors are logged; both an error and an empty result leave
                // the current prediction untouched.
                let Some(new_prediction) = prediction
                    .context("edit prediction failed")
                    .log_err()
                    .flatten()
                else {
                    cx.notify();
                    return;
                };

                if let Some(old_prediction) = this.current_prediction.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
                        this.current_prediction = Some(new_prediction);
                    }
                } else {
                    this.current_prediction = Some(new_prediction);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending predictions. When we already
        // have two, we replace the newest one.
        if self.pending_predictions.len() <= 1 {
            self.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        } else if self.pending_predictions.len() == 2 {
            self.pending_predictions.pop();
            self.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        }

        cx.notify();
    }
861
    /// Intentionally a no-op: this implementation ignores cycle requests.
    fn cycle(
        &mut self,
        _buffer: Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _direction: Direction,
        _cx: &mut Context<Self>,
    ) {
    }
870
    /// Clears the current prediction and all pending requests after the user
    /// accepted a prediction.
    fn accept(&mut self, _cx: &mut Context<Self>) {
        // TODO [zeta2] report accept
        self.current_prediction.take();
        self.pending_predictions.clear();
    }
876
    /// Clears all pending requests and the current prediction after the user
    /// dismissed a prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_predictions.clear();
        self.current_prediction.take();
    }
881
882 fn suggest(
883 &mut self,
884 buffer: &Entity<language::Buffer>,
885 cursor_position: language::Anchor,
886 cx: &mut Context<Self>,
887 ) -> Option<edit_prediction::EditPrediction> {
888 let CurrentEditPrediction {
889 buffer_id,
890 prediction,
891 ..
892 } = self.current_prediction.as_mut()?;
893
894 // Invalidate previous prediction if it was generated for a different buffer.
895 if *buffer_id != buffer.entity_id() {
896 self.current_prediction.take();
897 return None;
898 }
899
900 let buffer = buffer.read(cx);
901 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
902 self.current_prediction.take();
903 return None;
904 };
905
906 let cursor_row = cursor_position.to_point(buffer).row;
907 let (closest_edit_ix, (closest_edit_range, _)) =
908 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
909 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
910 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
911 cmp::min(distance_from_start, distance_from_end)
912 })?;
913
914 let mut edit_start_ix = closest_edit_ix;
915 for (range, _) in edits[..edit_start_ix].iter().rev() {
916 let distance_from_closest_edit =
917 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
918 if distance_from_closest_edit <= 1 {
919 edit_start_ix -= 1;
920 } else {
921 break;
922 }
923 }
924
925 let mut edit_end_ix = closest_edit_ix + 1;
926 for (range, _) in &edits[edit_end_ix..] {
927 let distance_from_closest_edit =
928 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
929 if distance_from_closest_edit <= 1 {
930 edit_end_ix += 1;
931 } else {
932 break;
933 }
934 }
935
936 Some(edit_prediction::EditPrediction {
937 id: Some(prediction.id.to_string().into()),
938 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
939 edit_preview: Some(prediction.edit_preview.clone()),
940 })
941 }
942}
943
944fn make_cloud_request(
945 excerpt_path: PathBuf,
946 context: EditPredictionContext,
947 events: Vec<predict_edits_v3::Event>,
948 can_collect_data: bool,
949 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
950 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
951 debug_info: bool,
952 worktrees: &Vec<worktree::Snapshot>,
953 index_state: Option<&SyntaxIndexState>,
954) -> predict_edits_v3::PredictEditsRequest {
955 let mut signatures = Vec::new();
956 let mut declaration_to_signature_index = HashMap::default();
957 let mut referenced_declarations = Vec::new();
958
959 for snippet in context.snippets {
960 let project_entry_id = snippet.declaration.project_entry_id();
961 let Some(path) = worktrees.iter().find_map(|worktree| {
962 worktree.entry_for_id(project_entry_id).map(|entry| {
963 let mut full_path = PathBuf::new();
964 full_path.push(worktree.root_name());
965 full_path.push(&entry.path);
966 full_path
967 })
968 }) else {
969 continue;
970 };
971
972 let parent_index = index_state.and_then(|index_state| {
973 snippet.declaration.parent().and_then(|parent| {
974 add_signature(
975 parent,
976 &mut declaration_to_signature_index,
977 &mut signatures,
978 index_state,
979 )
980 })
981 });
982
983 let (text, text_is_truncated) = snippet.declaration.item_text();
984 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
985 path,
986 text: text.into(),
987 range: snippet.declaration.item_range(),
988 text_is_truncated,
989 signature_range: snippet.declaration.signature_range_in_item_text(),
990 parent_index,
991 score_components: snippet.score_components,
992 signature_score: snippet.scores.signature,
993 declaration_score: snippet.scores.declaration,
994 });
995 }
996
997 let excerpt_parent = index_state.and_then(|index_state| {
998 context
999 .excerpt
1000 .parent_declarations
1001 .last()
1002 .and_then(|(parent, _)| {
1003 add_signature(
1004 *parent,
1005 &mut declaration_to_signature_index,
1006 &mut signatures,
1007 index_state,
1008 )
1009 })
1010 });
1011
1012 predict_edits_v3::PredictEditsRequest {
1013 excerpt_path,
1014 excerpt: context.excerpt_text.body,
1015 excerpt_range: context.excerpt.range,
1016 cursor_offset: context.cursor_offset_in_excerpt,
1017 referenced_declarations,
1018 signatures,
1019 excerpt_parent,
1020 events,
1021 can_collect_data,
1022 diagnostic_groups,
1023 git_info,
1024 debug_info,
1025 }
1026}
1027
1028fn add_signature(
1029 declaration_id: DeclarationId,
1030 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1031 signatures: &mut Vec<Signature>,
1032 index: &SyntaxIndexState,
1033) -> Option<usize> {
1034 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1035 return Some(*signature_index);
1036 }
1037 let Some(parent_declaration) = index.declaration(declaration_id) else {
1038 log::error!("bug: missing parent declaration");
1039 return None;
1040 };
1041 let parent_index = parent_declaration.parent().and_then(|parent| {
1042 add_signature(parent, declaration_to_signature_index, signatures, index)
1043 });
1044 let (text, text_is_truncated) = parent_declaration.signature_text();
1045 let signature_index = signatures.len();
1046 signatures.push(Signature {
1047 text: text.into(),
1048 text_is_truncated,
1049 parent_index,
1050 range: parent_declaration.signature_range(),
1051 });
1052 declaration_to_signature_index.insert(declaration_id, signature_index);
1053 Some(signature_index)
1054}
1055
1056fn interpolate(
1057 old_snapshot: &BufferSnapshot,
1058 new_snapshot: &BufferSnapshot,
1059 current_edits: Arc<[(Range<Anchor>, String)]>,
1060) -> Option<Vec<(Range<Anchor>, String)>> {
1061 let mut edits = Vec::new();
1062
1063 let mut model_edits = current_edits.iter().peekable();
1064 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1065 while let Some((model_old_range, _)) = model_edits.peek() {
1066 let model_old_range = model_old_range.to_offset(old_snapshot);
1067 if model_old_range.end < user_edit.old.start {
1068 let (model_old_range, model_new_text) = model_edits.next().unwrap();
1069 edits.push((model_old_range.clone(), model_new_text.clone()));
1070 } else {
1071 break;
1072 }
1073 }
1074
1075 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1076 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1077 if user_edit.old == model_old_offset_range {
1078 let user_new_text = new_snapshot
1079 .text_for_range(user_edit.new.clone())
1080 .collect::<String>();
1081
1082 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1083 if !model_suffix.is_empty() {
1084 let anchor = old_snapshot.anchor_after(user_edit.old.end);
1085 edits.push((anchor..anchor, model_suffix.to_string()));
1086 }
1087
1088 model_edits.next();
1089 continue;
1090 }
1091 }
1092 }
1093
1094 return None;
1095 }
1096
1097 edits.extend(model_edits.cloned());
1098
1099 if edits.is_empty() { None } else { Some(edits) }
1100}
1101
1102#[cfg(test)]
1103mod tests {
1104 use super::*;
1105 use gpui::TestAppContext;
1106 use language::ToOffset as _;
1107
    /// Walks a two-edit prediction (replace `2..5` with "REM", delete `9..11`)
    /// through a sequence of user edits and checks what `interpolate` yields
    /// after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user deletes the range the first edit would replace: the
            // edit turns into a pure insertion and the second edit shifts.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing the deletion restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user types the first predicted character: only "EM" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // ...then the second character: only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // ...then the last one: the first edit is fully typed out.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the just-typed "M" brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The user performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not match the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }
1215
1216 fn to_prediction_edits(
1217 iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1218 buffer: &Entity<Buffer>,
1219 cx: &App,
1220 ) -> Vec<(Range<Anchor>, String)> {
1221 let buffer = buffer.read(cx);
1222 iterator
1223 .into_iter()
1224 .map(|(range, text)| {
1225 (
1226 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1227 text,
1228 )
1229 })
1230 .collect()
1231 }
1232
1233 fn from_prediction_edits(
1234 editor_edits: &[(Range<Anchor>, String)],
1235 buffer: &Entity<Buffer>,
1236 cx: &App,
1237 ) -> Vec<(Range<usize>, String)> {
1238 let buffer = buffer.read(cx);
1239 editor_edits
1240 .iter()
1241 .map(|(range, text)| {
1242 (
1243 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1244 text.clone(),
1245 )
1246 })
1247 .collect()
1248 }
1249}