1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use edit_prediction_context::{
11 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
12 SyntaxIndexState,
13};
14use futures::AsyncReadExt as _;
15use futures::channel::mpsc;
16use gpui::http_client::Method;
17use gpui::{
18 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
19 http_client, prelude::*,
20};
21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
22use language::{BufferSnapshot, EditPreview};
23use language_model::{LlmApiToken, RefreshLlmTokenListener};
24use project::Project;
25use release_channel::AppVersion;
26use std::cmp;
27use std::collections::{HashMap, VecDeque, hash_map};
28use std::path::PathBuf;
29use std::str::FromStr as _;
30use std::time::{Duration, Instant};
31use std::{ops::Range, sync::Arc};
32use thiserror::Error;
33use util::{ResultExt as _, some_or_debug_panic};
34use uuid::Uuid;
35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
36
37const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
38
39/// Maximum number of events to track.
40const MAX_EVENT_COUNT: usize = 16;
41
42pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
43 max_bytes: 512,
44 min_bytes: 128,
45 target_before_cursor_over_total_bytes: 0.5,
46};
47
/// gpui global holding the app-wide [`Zeta`] entity; created lazily by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
52
/// Edit prediction engine: tracks per-project edit history and syntax indices, and requests
/// edit predictions from the cloud.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests. It is refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options controlling which excerpt around the cursor is sent with a request.
    pub excerpt_options: EditPredictionExcerptOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Installed by [`Zeta::debug_info`]; receives debug info (or an error string) for each
    /// prediction request.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
63
/// Debug information about a single prediction request, sent through the channel returned by
/// [`Zeta::debug_info`].
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Wall-clock time from just before context gathering until the cloud request was built.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the prediction response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug info attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
73
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index used to retrieve declarations related to the cursor excerpt.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent edit history, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being recorded, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Last snapshot seen for the buffer; used as the "old" side of the next recorded change.
    snapshot: BufferSnapshot,
    /// Subscription to the buffer's edit events, and an observer of its release.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// `timestamp` is when the change was recorded. It is updated when later edits are
    /// coalesced into this event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
93
94impl Zeta {
95 pub fn global(
96 client: &Arc<Client>,
97 user_store: &Entity<UserStore>,
98 cx: &mut App,
99 ) -> Entity<Self> {
100 cx.try_global::<ZetaGlobal>()
101 .map(|global| global.0.clone())
102 .unwrap_or_else(|| {
103 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
104 cx.set_global(ZetaGlobal(zeta.clone()));
105 zeta
106 })
107 }
108
109 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
110 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
111
112 Self {
113 projects: HashMap::new(),
114 client,
115 user_store,
116 excerpt_options: DEFAULT_EXCERPT_OPTIONS,
117 llm_token: LlmApiToken::default(),
118 _llm_token_subscription: cx.subscribe(
119 &refresh_llm_token_listener,
120 |this, _listener, _event, cx| {
121 let client = this.client.clone();
122 let llm_token = this.llm_token.clone();
123 cx.spawn(async move |_this, _cx| {
124 llm_token.refresh(&client).await?;
125 anyhow::Ok(())
126 })
127 .detach_and_log_err(cx);
128 },
129 ),
130 update_required: false,
131 debug_tx: None,
132 }
133 }
134
135 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
136 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
137 self.debug_tx = Some(debug_watch_tx);
138 debug_watch_rx
139 }
140
141 pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
142 &self.excerpt_options
143 }
144
145 pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
146 self.excerpt_options = options;
147 }
148
149 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
150 self.user_store.read(cx).edit_prediction_usage()
151 }
152
153 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
154 self.get_or_init_zeta_project(project, cx);
155 }
156
157 pub fn register_buffer(
158 &mut self,
159 buffer: &Entity<Buffer>,
160 project: &Entity<Project>,
161 cx: &mut Context<Self>,
162 ) {
163 let zeta_project = self.get_or_init_zeta_project(project, cx);
164 Self::register_buffer_impl(zeta_project, buffer, project, cx);
165 }
166
167 fn get_or_init_zeta_project(
168 &mut self,
169 project: &Entity<Project>,
170 cx: &mut App,
171 ) -> &mut ZetaProject {
172 self.projects
173 .entry(project.entity_id())
174 .or_insert_with(|| ZetaProject {
175 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
176 events: VecDeque::new(),
177 registered_buffers: HashMap::new(),
178 })
179 }
180
181 fn register_buffer_impl<'a>(
182 zeta_project: &'a mut ZetaProject,
183 buffer: &Entity<Buffer>,
184 project: &Entity<Project>,
185 cx: &mut Context<Self>,
186 ) -> &'a mut RegisteredBuffer {
187 let buffer_id = buffer.entity_id();
188 match zeta_project.registered_buffers.entry(buffer_id) {
189 hash_map::Entry::Occupied(entry) => entry.into_mut(),
190 hash_map::Entry::Vacant(entry) => {
191 let snapshot = buffer.read(cx).snapshot();
192 let project_entity_id = project.entity_id();
193 entry.insert(RegisteredBuffer {
194 snapshot,
195 _subscriptions: [
196 cx.subscribe(buffer, {
197 let project = project.downgrade();
198 move |this, buffer, event, cx| {
199 if let language::BufferEvent::Edited = event
200 && let Some(project) = project.upgrade()
201 {
202 this.report_changes_for_buffer(&buffer, &project, cx);
203 }
204 }
205 }),
206 cx.observe_release(buffer, move |this, _buffer, _cx| {
207 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
208 else {
209 return;
210 };
211 zeta_project.registered_buffers.remove(&buffer_id);
212 }),
213 ],
214 })
215 }
216 }
217 }
218
219 fn report_changes_for_buffer(
220 &mut self,
221 buffer: &Entity<Buffer>,
222 project: &Entity<Project>,
223 cx: &mut Context<Self>,
224 ) -> BufferSnapshot {
225 let zeta_project = self.get_or_init_zeta_project(project, cx);
226 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
227
228 let new_snapshot = buffer.read(cx).snapshot();
229 if new_snapshot.version != registered_buffer.snapshot.version {
230 let old_snapshot =
231 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
232 Self::push_event(
233 zeta_project,
234 Event::BufferChange {
235 old_snapshot,
236 new_snapshot: new_snapshot.clone(),
237 timestamp: Instant::now(),
238 },
239 );
240 }
241
242 new_snapshot
243 }
244
245 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
246 let events = &mut zeta_project.events;
247
248 if let Some(Event::BufferChange {
249 new_snapshot: last_new_snapshot,
250 timestamp: last_timestamp,
251 ..
252 }) = events.back_mut()
253 {
254 // Coalesce edits for the same buffer when they happen one after the other.
255 let Event::BufferChange {
256 old_snapshot,
257 new_snapshot,
258 timestamp,
259 } = &event;
260
261 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
262 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
263 && old_snapshot.version == last_new_snapshot.version
264 {
265 *last_new_snapshot = new_snapshot.clone();
266 *last_timestamp = *timestamp;
267 return;
268 }
269 }
270
271 if events.len() >= MAX_EVENT_COUNT {
272 // These are halved instead of popping to improve prompt caching.
273 events.drain(..MAX_EVENT_COUNT / 2);
274 }
275
276 events.push_back(event);
277 }
278
279 pub fn request_prediction(
280 &mut self,
281 project: &Entity<Project>,
282 buffer: &Entity<Buffer>,
283 position: language::Anchor,
284 cx: &mut Context<Self>,
285 ) -> Task<Result<Option<EditPrediction>>> {
286 let project_state = self.projects.get(&project.entity_id());
287
288 let index_state = project_state.map(|state| {
289 state
290 .syntax_index
291 .read_with(cx, |index, _cx| index.state().clone())
292 });
293 let excerpt_options = self.excerpt_options.clone();
294 let snapshot = buffer.read(cx).snapshot();
295 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
296 return Task::ready(Err(anyhow!("No file path for excerpt")));
297 };
298 let client = self.client.clone();
299 let llm_token = self.llm_token.clone();
300 let app_version = AppVersion::global(cx);
301 let worktree_snapshots = project
302 .read(cx)
303 .worktrees(cx)
304 .map(|worktree| worktree.read(cx).snapshot())
305 .collect::<Vec<_>>();
306 let debug_tx = self.debug_tx.clone();
307
308 let events = project_state
309 .map(|state| {
310 state
311 .events
312 .iter()
313 .map(|event| match event {
314 Event::BufferChange {
315 old_snapshot,
316 new_snapshot,
317 ..
318 } => {
319 let path = new_snapshot.file().map(|f| f.path().to_path_buf());
320
321 let old_path = old_snapshot.file().and_then(|f| {
322 let old_path = f.path().as_ref();
323 if Some(old_path) != path.as_deref() {
324 Some(old_path.to_path_buf())
325 } else {
326 None
327 }
328 });
329
330 predict_edits_v3::Event::BufferChange {
331 old_path,
332 path,
333 diff: language::unified_diff(
334 &old_snapshot.text(),
335 &new_snapshot.text(),
336 ),
337 //todo: Actually detect if this edit was predicted or not
338 predicted: false,
339 }
340 }
341 })
342 .collect::<Vec<_>>()
343 })
344 .unwrap_or_default();
345
346 let request_task = cx.background_spawn({
347 let snapshot = snapshot.clone();
348 let buffer = buffer.clone();
349 async move {
350 let index_state = if let Some(index_state) = index_state {
351 Some(index_state.lock_owned().await)
352 } else {
353 None
354 };
355
356 let cursor_point = position.to_point(&snapshot);
357
358 let before_retrieval = chrono::Utc::now();
359
360 let Some(context) = EditPredictionContext::gather_context(
361 cursor_point,
362 &snapshot,
363 &excerpt_options,
364 index_state.as_deref(),
365 ) else {
366 return Ok(None);
367 };
368
369 let debug_context = if let Some(debug_tx) = debug_tx {
370 Some((debug_tx, context.clone()))
371 } else {
372 None
373 };
374
375 let request = make_cloud_request(
376 excerpt_path.clone(),
377 context,
378 events,
379 // TODO data collection
380 false,
381 Vec::new(),
382 None,
383 debug_context.is_some(),
384 &worktree_snapshots,
385 index_state.as_deref(),
386 );
387
388 let retrieval_time = chrono::Utc::now() - before_retrieval;
389 let response = Self::perform_request(client, llm_token, app_version, request).await;
390
391 if let Some((debug_tx, context)) = debug_context {
392 debug_tx
393 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
394 |response| {
395 let Some(request) =
396 some_or_debug_panic(response.0.debug_info.clone())
397 else {
398 return Err("Missing debug info".to_string());
399 };
400 Ok(PredictionDebugInfo {
401 context,
402 request,
403 retrieval_time,
404 buffer: buffer.downgrade(),
405 position,
406 })
407 },
408 ))
409 .ok();
410 }
411
412 anyhow::Ok(Some(response?))
413 }
414 });
415
416 let buffer = buffer.clone();
417
418 cx.spawn(async move |this, cx| {
419 match request_task.await {
420 Ok(Some((response, usage))) => {
421 log::debug!("predicted edits: {:?}", &response.edits);
422
423 if let Some(usage) = usage {
424 this.update(cx, |this, cx| {
425 this.user_store.update(cx, |user_store, cx| {
426 user_store.update_edit_prediction_usage(usage, cx);
427 });
428 })
429 .ok();
430 }
431
432 // TODO telemetry: duration, etc
433
434 // TODO produce smaller edits by diffing against snapshot first
435 //
436 // Cloud returns entire snippets/excerpts ranges as they were included
437 // in the request, but we should display smaller edits to the user.
438 //
439 // We can do this by computing a diff of each one against the snapshot.
440 // Similar to zeta::Zeta::compute_edits, but per edit.
441 let edits = response
442 .edits
443 .into_iter()
444 .map(|edit| {
445 // TODO edits to different files
446 (
447 snapshot.anchor_before(edit.range.start)
448 ..snapshot.anchor_before(edit.range.end),
449 edit.content,
450 )
451 })
452 .collect::<Vec<_>>()
453 .into();
454
455 let Some((edits, snapshot, edit_preview_task)) =
456 buffer.read_with(cx, |buffer, cx| {
457 let new_snapshot = buffer.snapshot();
458 let edits: Arc<[_]> =
459 interpolate(&snapshot, &new_snapshot, edits)?.into();
460 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
461 })?
462 else {
463 return Ok(None);
464 };
465
466 Ok(Some(EditPrediction {
467 id: EditPredictionId(response.request_id),
468 edits,
469 snapshot,
470 edit_preview: edit_preview_task.await,
471 }))
472 }
473 Ok(None) => Ok(None),
474 Err(err) => {
475 if err.is::<ZedUpdateRequiredError>() {
476 cx.update(|cx| {
477 this.update(cx, |this, _cx| {
478 this.update_required = true;
479 })
480 .ok();
481
482 let error_message: SharedString = err.to_string().into();
483 show_app_notification(
484 NotificationId::unique::<ZedUpdateRequiredError>(),
485 cx,
486 move |cx| {
487 cx.new(|cx| {
488 ErrorMessagePrompt::new(error_message.clone(), cx)
489 .with_link_button(
490 "Update Zed",
491 "https://zed.dev/releases",
492 )
493 })
494 },
495 );
496 })
497 .ok();
498 }
499
500 Err(err)
501 }
502 }
503 })
504 }
505
506 async fn perform_request(
507 client: Arc<Client>,
508 llm_token: LlmApiToken,
509 app_version: SemanticVersion,
510 request: predict_edits_v3::PredictEditsRequest,
511 ) -> Result<(
512 predict_edits_v3::PredictEditsResponse,
513 Option<EditPredictionUsage>,
514 )> {
515 let http_client = client.http_client();
516 let mut token = llm_token.acquire(&client).await?;
517 let mut did_retry = false;
518
519 loop {
520 let request_builder = http_client::Request::builder().method(Method::POST);
521 let request_builder =
522 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
523 request_builder.uri(predict_edits_url)
524 } else {
525 request_builder.uri(
526 http_client
527 .build_zed_llm_url("/predict_edits/v3", &[])?
528 .as_ref(),
529 )
530 };
531 let request = request_builder
532 .header("Content-Type", "application/json")
533 .header("Authorization", format!("Bearer {}", token))
534 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
535 .body(serde_json::to_string(&request)?.into())?;
536
537 let mut response = http_client.send(request).await?;
538
539 if let Some(minimum_required_version) = response
540 .headers()
541 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
542 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
543 {
544 anyhow::ensure!(
545 app_version >= minimum_required_version,
546 ZedUpdateRequiredError {
547 minimum_version: minimum_required_version
548 }
549 );
550 }
551
552 if response.status().is_success() {
553 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
554
555 let mut body = Vec::new();
556 response.body_mut().read_to_end(&mut body).await?;
557 return Ok((serde_json::from_slice(&body)?, usage));
558 } else if !did_retry
559 && response
560 .headers()
561 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
562 .is_some()
563 {
564 did_retry = true;
565 token = llm_token.refresh(&client).await?;
566 } else {
567 let mut body = String::new();
568 response.body_mut().read_to_string(&mut body).await?;
569 anyhow::bail!(
570 "error predicting edits.\nStatus: {:?}\nBody: {}",
571 response.status(),
572 body
573 );
574 }
575 }
576 }
577
578 // TODO: Dedupe with similar code in request_prediction?
579 pub fn cloud_request_for_zeta_cli(
580 &mut self,
581 project: &Entity<Project>,
582 buffer: &Entity<Buffer>,
583 position: language::Anchor,
584 cx: &mut Context<Self>,
585 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
586 let project_state = self.projects.get(&project.entity_id());
587
588 let index_state = project_state.map(|state| {
589 state
590 .syntax_index
591 .read_with(cx, |index, _cx| index.state().clone())
592 });
593 let excerpt_options = self.excerpt_options.clone();
594 let snapshot = buffer.read(cx).snapshot();
595 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
596 return Task::ready(Err(anyhow!("No file path for excerpt")));
597 };
598 let worktree_snapshots = project
599 .read(cx)
600 .worktrees(cx)
601 .map(|worktree| worktree.read(cx).snapshot())
602 .collect::<Vec<_>>();
603
604 cx.background_spawn(async move {
605 let index_state = if let Some(index_state) = index_state {
606 Some(index_state.lock_owned().await)
607 } else {
608 None
609 };
610
611 let cursor_point = position.to_point(&snapshot);
612
613 let debug_info = true;
614 EditPredictionContext::gather_context(
615 cursor_point,
616 &snapshot,
617 &excerpt_options,
618 index_state.as_deref(),
619 )
620 .context("Failed to select excerpt")
621 .map(|context| {
622 make_cloud_request(
623 excerpt_path.clone(),
624 context,
625 // TODO pass everything
626 Vec::new(),
627 false,
628 Vec::new(),
629 None,
630 debug_info,
631 &worktree_snapshots,
632 index_state.as_deref(),
633 )
634 })
635 })
636 }
637}
638
/// Returned by a prediction request when the server's minimum-required-version header is newer
/// than this Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
646
/// [`EditPredictionProvider`] backed by the global [`Zeta`] instance.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending prediction; incremented on every refresh that spawns a
    /// request.
    next_pending_prediction_id: usize,
    /// In-flight requests; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the most recent request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
654
655impl ZetaEditPredictionProvider {
656 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
657
658 pub fn new(
659 project: Option<&Entity<Project>>,
660 client: &Arc<Client>,
661 user_store: &Entity<UserStore>,
662 cx: &mut App,
663 ) -> Self {
664 let zeta = Zeta::global(client, user_store, cx);
665 if let Some(project) = project {
666 zeta.update(cx, |zeta, cx| {
667 zeta.register_project(project, cx);
668 });
669 }
670
671 Self {
672 zeta,
673 current_prediction: None,
674 next_pending_prediction_id: 0,
675 pending_predictions: ArrayVec::new(),
676 last_request_timestamp: Instant::now(),
677 }
678 }
679}
680
681#[derive(Clone)]
682struct CurrentEditPrediction {
683 buffer_id: EntityId,
684 prediction: EditPrediction,
685}
686
687impl CurrentEditPrediction {
688 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
689 if self.buffer_id != old_prediction.buffer_id {
690 return true;
691 }
692
693 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
694 return true;
695 };
696 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
697 return false;
698 };
699
700 if old_edits.len() == 1 && new_edits.len() == 1 {
701 let (old_range, old_text) = &old_edits[0];
702 let (new_range, new_text) = &new_edits[0];
703 new_range == old_range && new_text.starts_with(old_text)
704 } else {
705 true
706 }
707 }
708}
709
710#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
711pub struct EditPredictionId(Uuid);
712
713impl From<EditPredictionId> for gpui::ElementId {
714 fn from(value: EditPredictionId) -> Self {
715 gpui::ElementId::Uuid(value.0)
716 }
717}
718
719impl std::fmt::Display for EditPredictionId {
720 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
721 write!(f, "{}", self.0)
722 }
723}
724
725#[derive(Clone)]
726pub struct EditPrediction {
727 id: EditPredictionId,
728 edits: Arc<[(Range<Anchor>, String)]>,
729 snapshot: BufferSnapshot,
730 edit_preview: EditPreview,
731}
732
733impl EditPrediction {
734 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
735 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
736 }
737}
738
/// A prediction request spawned by `ZetaEditPredictionProvider::refresh`.
struct PendingPrediction {
    /// Value of `next_pending_prediction_id` when the request was spawned.
    id: usize,
    /// Task driving the request. It is never read, only held; presumably dropping it cancels
    /// the request (gpui `Task` semantics), TODO confirm.
    _task: Task<()>,
}
743
744impl EditPredictionProvider for ZetaEditPredictionProvider {
745 fn name() -> &'static str {
746 "zed-predict2"
747 }
748
749 fn display_name() -> &'static str {
750 "Zed's Edit Predictions 2"
751 }
752
753 fn show_completions_in_menu() -> bool {
754 true
755 }
756
757 fn show_tab_accept_marker() -> bool {
758 true
759 }
760
761 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
762 // TODO [zeta2]
763 DataCollectionState::Unsupported
764 }
765
766 fn toggle_data_collection(&mut self, _cx: &mut App) {
767 // TODO [zeta2]
768 }
769
770 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
771 self.zeta.read(cx).usage(cx)
772 }
773
774 fn is_enabled(
775 &self,
776 _buffer: &Entity<language::Buffer>,
777 _cursor_position: language::Anchor,
778 _cx: &App,
779 ) -> bool {
780 true
781 }
782
783 fn is_refreshing(&self) -> bool {
784 !self.pending_predictions.is_empty()
785 }
786
787 fn refresh(
788 &mut self,
789 project: Option<Entity<project::Project>>,
790 buffer: Entity<language::Buffer>,
791 cursor_position: language::Anchor,
792 _debounce: bool,
793 cx: &mut Context<Self>,
794 ) {
795 let Some(project) = project else {
796 return;
797 };
798
799 if self
800 .zeta
801 .read(cx)
802 .user_store
803 .read_with(cx, |user_store, _cx| {
804 user_store.account_too_young() || user_store.has_overdue_invoices()
805 })
806 {
807 return;
808 }
809
810 if let Some(current_prediction) = self.current_prediction.as_ref() {
811 let snapshot = buffer.read(cx).snapshot();
812 if current_prediction
813 .prediction
814 .interpolate(&snapshot)
815 .is_some()
816 {
817 return;
818 }
819 }
820
821 let pending_prediction_id = self.next_pending_prediction_id;
822 self.next_pending_prediction_id += 1;
823 let last_request_timestamp = self.last_request_timestamp;
824
825 let task = cx.spawn(async move |this, cx| {
826 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
827 .checked_duration_since(Instant::now())
828 {
829 cx.background_executor().timer(timeout).await;
830 }
831
832 let prediction_request = this.update(cx, |this, cx| {
833 this.last_request_timestamp = Instant::now();
834 this.zeta.update(cx, |zeta, cx| {
835 zeta.request_prediction(&project, &buffer, cursor_position, cx)
836 })
837 });
838
839 let prediction = match prediction_request {
840 Ok(prediction_request) => {
841 let prediction_request = prediction_request.await;
842 prediction_request.map(|c| {
843 c.map(|prediction| CurrentEditPrediction {
844 buffer_id: buffer.entity_id(),
845 prediction,
846 })
847 })
848 }
849 Err(error) => Err(error),
850 };
851
852 this.update(cx, |this, cx| {
853 if this.pending_predictions[0].id == pending_prediction_id {
854 this.pending_predictions.remove(0);
855 } else {
856 this.pending_predictions.clear();
857 }
858
859 let Some(new_prediction) = prediction
860 .context("edit prediction failed")
861 .log_err()
862 .flatten()
863 else {
864 cx.notify();
865 return;
866 };
867
868 if let Some(old_prediction) = this.current_prediction.as_ref() {
869 let snapshot = buffer.read(cx).snapshot();
870 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
871 this.current_prediction = Some(new_prediction);
872 }
873 } else {
874 this.current_prediction = Some(new_prediction);
875 }
876
877 cx.notify();
878 })
879 .ok();
880 });
881
882 // We always maintain at most two pending predictions. When we already
883 // have two, we replace the newest one.
884 if self.pending_predictions.len() <= 1 {
885 self.pending_predictions.push(PendingPrediction {
886 id: pending_prediction_id,
887 _task: task,
888 });
889 } else if self.pending_predictions.len() == 2 {
890 self.pending_predictions.pop();
891 self.pending_predictions.push(PendingPrediction {
892 id: pending_prediction_id,
893 _task: task,
894 });
895 }
896
897 cx.notify();
898 }
899
900 fn cycle(
901 &mut self,
902 _buffer: Entity<language::Buffer>,
903 _cursor_position: language::Anchor,
904 _direction: Direction,
905 _cx: &mut Context<Self>,
906 ) {
907 }
908
909 fn accept(&mut self, _cx: &mut Context<Self>) {
910 // TODO [zeta2] report accept
911 self.current_prediction.take();
912 self.pending_predictions.clear();
913 }
914
915 fn discard(&mut self, _cx: &mut Context<Self>) {
916 self.pending_predictions.clear();
917 self.current_prediction.take();
918 }
919
920 fn suggest(
921 &mut self,
922 buffer: &Entity<language::Buffer>,
923 cursor_position: language::Anchor,
924 cx: &mut Context<Self>,
925 ) -> Option<edit_prediction::EditPrediction> {
926 let CurrentEditPrediction {
927 buffer_id,
928 prediction,
929 ..
930 } = self.current_prediction.as_mut()?;
931
932 // Invalidate previous prediction if it was generated for a different buffer.
933 if *buffer_id != buffer.entity_id() {
934 self.current_prediction.take();
935 return None;
936 }
937
938 let buffer = buffer.read(cx);
939 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
940 self.current_prediction.take();
941 return None;
942 };
943
944 let cursor_row = cursor_position.to_point(buffer).row;
945 let (closest_edit_ix, (closest_edit_range, _)) =
946 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
947 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
948 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
949 cmp::min(distance_from_start, distance_from_end)
950 })?;
951
952 let mut edit_start_ix = closest_edit_ix;
953 for (range, _) in edits[..edit_start_ix].iter().rev() {
954 let distance_from_closest_edit =
955 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
956 if distance_from_closest_edit <= 1 {
957 edit_start_ix -= 1;
958 } else {
959 break;
960 }
961 }
962
963 let mut edit_end_ix = closest_edit_ix + 1;
964 for (range, _) in &edits[edit_end_ix..] {
965 let distance_from_closest_edit =
966 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
967 if distance_from_closest_edit <= 1 {
968 edit_end_ix += 1;
969 } else {
970 break;
971 }
972 }
973
974 Some(edit_prediction::EditPrediction {
975 id: Some(prediction.id.to_string().into()),
976 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
977 edit_preview: Some(prediction.edit_preview.clone()),
978 })
979 }
980}
981
982fn make_cloud_request(
983 excerpt_path: PathBuf,
984 context: EditPredictionContext,
985 events: Vec<predict_edits_v3::Event>,
986 can_collect_data: bool,
987 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
988 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
989 debug_info: bool,
990 worktrees: &Vec<worktree::Snapshot>,
991 index_state: Option<&SyntaxIndexState>,
992) -> predict_edits_v3::PredictEditsRequest {
993 let mut signatures = Vec::new();
994 let mut declaration_to_signature_index = HashMap::default();
995 let mut referenced_declarations = Vec::new();
996
997 for snippet in context.snippets {
998 let project_entry_id = snippet.declaration.project_entry_id();
999 let Some(path) = worktrees.iter().find_map(|worktree| {
1000 worktree.entry_for_id(project_entry_id).map(|entry| {
1001 let mut full_path = PathBuf::new();
1002 full_path.push(worktree.root_name());
1003 full_path.push(&entry.path);
1004 full_path
1005 })
1006 }) else {
1007 continue;
1008 };
1009
1010 let parent_index = index_state.and_then(|index_state| {
1011 snippet.declaration.parent().and_then(|parent| {
1012 add_signature(
1013 parent,
1014 &mut declaration_to_signature_index,
1015 &mut signatures,
1016 index_state,
1017 )
1018 })
1019 });
1020
1021 let (text, text_is_truncated) = snippet.declaration.item_text();
1022 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1023 path,
1024 text: text.into(),
1025 range: snippet.declaration.item_range(),
1026 text_is_truncated,
1027 signature_range: snippet.declaration.signature_range_in_item_text(),
1028 parent_index,
1029 score_components: snippet.score_components,
1030 signature_score: snippet.scores.signature,
1031 declaration_score: snippet.scores.declaration,
1032 });
1033 }
1034
1035 let excerpt_parent = index_state.and_then(|index_state| {
1036 context
1037 .excerpt
1038 .parent_declarations
1039 .last()
1040 .and_then(|(parent, _)| {
1041 add_signature(
1042 *parent,
1043 &mut declaration_to_signature_index,
1044 &mut signatures,
1045 index_state,
1046 )
1047 })
1048 });
1049
1050 predict_edits_v3::PredictEditsRequest {
1051 excerpt_path,
1052 excerpt: context.excerpt_text.body,
1053 excerpt_range: context.excerpt.range,
1054 cursor_offset: context.cursor_offset_in_excerpt,
1055 referenced_declarations,
1056 signatures,
1057 excerpt_parent,
1058 events,
1059 can_collect_data,
1060 diagnostic_groups,
1061 git_info,
1062 debug_info,
1063 }
1064}
1065
1066fn add_signature(
1067 declaration_id: DeclarationId,
1068 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1069 signatures: &mut Vec<Signature>,
1070 index: &SyntaxIndexState,
1071) -> Option<usize> {
1072 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1073 return Some(*signature_index);
1074 }
1075 let Some(parent_declaration) = index.declaration(declaration_id) else {
1076 log::error!("bug: missing parent declaration");
1077 return None;
1078 };
1079 let parent_index = parent_declaration.parent().and_then(|parent| {
1080 add_signature(parent, declaration_to_signature_index, signatures, index)
1081 });
1082 let (text, text_is_truncated) = parent_declaration.signature_text();
1083 let signature_index = signatures.len();
1084 signatures.push(Signature {
1085 text: text.into(),
1086 text_is_truncated,
1087 parent_index,
1088 range: parent_declaration.signature_range(),
1089 });
1090 declaration_to_signature_index.insert(declaration_id, signature_index);
1091 Some(signature_index)
1092}
1093
/// Rebases `current_edits`, which were predicted against `old_snapshot`, onto `new_snapshot`.
///
/// The user's edits since `old_snapshot` are processed in order:
/// - Model edits whose old range ends strictly before the user edit starts are kept unchanged.
/// - A user edit is accepted only if all of the following hold:
///   - it covers exactly the same old range as the next model edit;
///   - the text it inserted is a prefix of that model edit's text;
///   - the remaining suffix, if non-empty, is kept as an insertion at the end of that range.
/// - Any other user edit means the user diverged from the prediction, and `None` is returned.
///
/// Model edits after the last user edit are kept unchanged. Returns `None` when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep model edits that lie entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit must replace the same range as the next model edit, with a prefix of
        // the predicted text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Whatever the user hasn't typed yet is still suggested, as an insertion
                    // at the end of the replaced range.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit doesn't match the prediction: it's no longer applicable.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
1139
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Steps through a series of user edits against a fixed prediction and checks how
    /// `EditPrediction::interpolate` rebases it after each one. Each step builds on the
    /// buffer state left by the previous one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The prediction: replace 2..5 with "REM" and delete 9..11.
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is returned unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes the predicted range: the replacement becomes an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User types a prefix of the predicted text: only the rest is still suggested.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The replacement is now fully typed; only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the just-typed "M" makes it a suggestion again.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that doesn't match the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits in `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Converts anchor-based edits back into offsets in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}