1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use gpui::http_client::Method;
15use gpui::{
16 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
17 prelude::*,
18};
19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
20use language::{BufferSnapshot, EditPreview};
21use language_model::{LlmApiToken, RefreshLlmTokenListener};
22use project::Project;
23use release_channel::AppVersion;
24use std::cmp;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::time::{Duration, Instant};
29use std::{ops::Range, sync::Arc};
30use thiserror::Error;
31use util::ResultExt as _;
32use uuid::Uuid;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
/// Consecutive edits to the same buffer that arrive no more than this far apart are
/// merged into a single `Event::BufferChange`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Maximum number of events retained per project; once reached, the oldest half is
/// discarded before a new event is appended.
const MAX_EVENT_COUNT: usize = 16;
39
/// Newtype that stores the shared [`Zeta`] entity as a gpui [`Global`], so that
/// [`Zeta::global`] can hand out the same instance on every call.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
44
/// Edit-prediction engine (normally obtained through [`Zeta::global`]).
///
/// Owns the cloud client and LLM token used to request predictions, and tracks
/// per-project state: a syntax index, recent buffer-change events and the set of
/// buffers being observed.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the global `RefreshLlmTokenListener` emits.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Controls how the excerpt around the cursor is selected for a request.
    pub excerpt_options: EditPredictionExcerptOptions,
    /// Set once the server reports that this Zed version is below its minimum.
    update_required: bool,
}
54
55struct ZetaProject {
56 syntax_index: Entity<SyntaxIndex>,
57 events: VecDeque<Event>,
58 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
59}
60
61struct RegisteredBuffer {
62 snapshot: BufferSnapshot,
63 _subscriptions: [gpui::Subscription; 2],
64}
65
/// A change observed in a buffer registered with [`Zeta`].
#[derive(Clone)]
pub enum Event {
    /// The buffer went from `old_snapshot` to `new_snapshot`.
    ///
    /// Back-to-back edits to the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL`
    /// are coalesced into one event, in which case `timestamp` is the time of the most
    /// recent edit folded in.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
74
75impl Zeta {
76 pub fn global(
77 client: &Arc<Client>,
78 user_store: &Entity<UserStore>,
79 cx: &mut App,
80 ) -> Entity<Self> {
81 cx.try_global::<ZetaGlobal>()
82 .map(|global| global.0.clone())
83 .unwrap_or_else(|| {
84 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
85 cx.set_global(ZetaGlobal(zeta.clone()));
86 zeta
87 })
88 }
89
90 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
91 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
92
93 Self {
94 projects: HashMap::new(),
95 client,
96 user_store,
97 excerpt_options: EditPredictionExcerptOptions {
98 max_bytes: 512,
99 min_bytes: 128,
100 target_before_cursor_over_total_bytes: 0.5,
101 },
102 llm_token: LlmApiToken::default(),
103 _llm_token_subscription: cx.subscribe(
104 &refresh_llm_token_listener,
105 |this, _listener, _event, cx| {
106 let client = this.client.clone();
107 let llm_token = this.llm_token.clone();
108 cx.spawn(async move |_this, _cx| {
109 llm_token.refresh(&client).await?;
110 anyhow::Ok(())
111 })
112 .detach_and_log_err(cx);
113 },
114 ),
115 update_required: false,
116 }
117 }
118
119 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
120 self.user_store.read(cx).edit_prediction_usage()
121 }
122
123 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
124 self.get_or_init_zeta_project(project, cx);
125 }
126
127 pub fn register_buffer(
128 &mut self,
129 buffer: &Entity<Buffer>,
130 project: &Entity<Project>,
131 cx: &mut Context<Self>,
132 ) {
133 let zeta_project = self.get_or_init_zeta_project(project, cx);
134 Self::register_buffer_impl(zeta_project, buffer, project, cx);
135 }
136
137 fn get_or_init_zeta_project(
138 &mut self,
139 project: &Entity<Project>,
140 cx: &mut App,
141 ) -> &mut ZetaProject {
142 self.projects
143 .entry(project.entity_id())
144 .or_insert_with(|| ZetaProject {
145 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
146 events: VecDeque::new(),
147 registered_buffers: HashMap::new(),
148 })
149 }
150
151 fn register_buffer_impl<'a>(
152 zeta_project: &'a mut ZetaProject,
153 buffer: &Entity<Buffer>,
154 project: &Entity<Project>,
155 cx: &mut Context<Self>,
156 ) -> &'a mut RegisteredBuffer {
157 let buffer_id = buffer.entity_id();
158 match zeta_project.registered_buffers.entry(buffer_id) {
159 hash_map::Entry::Occupied(entry) => entry.into_mut(),
160 hash_map::Entry::Vacant(entry) => {
161 let snapshot = buffer.read(cx).snapshot();
162 let project_entity_id = project.entity_id();
163 entry.insert(RegisteredBuffer {
164 snapshot,
165 _subscriptions: [
166 cx.subscribe(buffer, {
167 let project = project.downgrade();
168 move |this, buffer, event, cx| {
169 if let language::BufferEvent::Edited = event
170 && let Some(project) = project.upgrade()
171 {
172 this.report_changes_for_buffer(&buffer, &project, cx);
173 }
174 }
175 }),
176 cx.observe_release(buffer, move |this, _buffer, _cx| {
177 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
178 else {
179 return;
180 };
181 zeta_project.registered_buffers.remove(&buffer_id);
182 }),
183 ],
184 })
185 }
186 }
187 }
188
189 fn report_changes_for_buffer(
190 &mut self,
191 buffer: &Entity<Buffer>,
192 project: &Entity<Project>,
193 cx: &mut Context<Self>,
194 ) -> BufferSnapshot {
195 let zeta_project = self.get_or_init_zeta_project(project, cx);
196 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
197
198 let new_snapshot = buffer.read(cx).snapshot();
199 if new_snapshot.version != registered_buffer.snapshot.version {
200 let old_snapshot =
201 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
202 Self::push_event(
203 zeta_project,
204 Event::BufferChange {
205 old_snapshot,
206 new_snapshot: new_snapshot.clone(),
207 timestamp: Instant::now(),
208 },
209 );
210 }
211
212 new_snapshot
213 }
214
215 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
216 let events = &mut zeta_project.events;
217
218 if let Some(Event::BufferChange {
219 new_snapshot: last_new_snapshot,
220 timestamp: last_timestamp,
221 ..
222 }) = events.back_mut()
223 {
224 // Coalesce edits for the same buffer when they happen one after the other.
225 let Event::BufferChange {
226 old_snapshot,
227 new_snapshot,
228 timestamp,
229 } = &event;
230
231 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
232 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
233 && old_snapshot.version == last_new_snapshot.version
234 {
235 *last_new_snapshot = new_snapshot.clone();
236 *last_timestamp = *timestamp;
237 return;
238 }
239 }
240
241 if events.len() >= MAX_EVENT_COUNT {
242 // These are halved instead of popping to improve prompt caching.
243 events.drain(..MAX_EVENT_COUNT / 2);
244 }
245
246 events.push_back(event);
247 }
248
249 pub fn request_prediction(
250 &mut self,
251 project: &Entity<Project>,
252 buffer: &Entity<Buffer>,
253 position: language::Anchor,
254 cx: &mut Context<Self>,
255 ) -> Task<Result<Option<EditPrediction>>> {
256 let project_state = self.projects.get(&project.entity_id());
257
258 let index_state = project_state.map(|state| {
259 state
260 .syntax_index
261 .read_with(cx, |index, _cx| index.state().clone())
262 });
263 let excerpt_options = self.excerpt_options.clone();
264 let snapshot = buffer.read(cx).snapshot();
265 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
266 return Task::ready(Err(anyhow!("No file path for excerpt")));
267 };
268 let client = self.client.clone();
269 let llm_token = self.llm_token.clone();
270 let app_version = AppVersion::global(cx);
271 let worktree_snapshots = project
272 .read(cx)
273 .worktrees(cx)
274 .map(|worktree| worktree.read(cx).snapshot())
275 .collect::<Vec<_>>();
276
277 let request_task = cx.background_spawn({
278 let snapshot = snapshot.clone();
279 async move {
280 let index_state = if let Some(index_state) = index_state {
281 Some(index_state.lock_owned().await)
282 } else {
283 None
284 };
285
286 let cursor_point = position.to_point(&snapshot);
287
288 // TODO: make this only true if debug view is open
289 let debug_info = true;
290
291 let Some(request) = EditPredictionContext::gather_context(
292 cursor_point,
293 &snapshot,
294 &excerpt_options,
295 index_state.as_deref(),
296 )
297 .map(|context| {
298 make_cloud_request(
299 excerpt_path.clone(),
300 context,
301 // TODO pass everything
302 Vec::new(),
303 false,
304 Vec::new(),
305 None,
306 debug_info,
307 &worktree_snapshots,
308 index_state.as_deref(),
309 )
310 }) else {
311 return Ok(None);
312 };
313
314 anyhow::Ok(Some(
315 Self::perform_request(client, llm_token, app_version, request).await?,
316 ))
317 }
318 });
319
320 let buffer = buffer.clone();
321
322 cx.spawn(async move |this, cx| {
323 match request_task.await {
324 Ok(Some((response, usage))) => {
325 log::debug!("predicted edits: {:?}", &response.edits);
326
327 if let Some(usage) = usage {
328 this.update(cx, |this, cx| {
329 this.user_store.update(cx, |user_store, cx| {
330 user_store.update_edit_prediction_usage(usage, cx);
331 });
332 })
333 .ok();
334 }
335
336 // TODO telemetry: duration, etc
337
338 // TODO produce smaller edits by diffing against snapshot first
339 //
340 // Cloud returns entire snippets/excerpts ranges as they were included
341 // in the request, but we should display smaller edits to the user.
342 //
343 // We can do this by computing a diff of each one against the snapshot.
344 // Similar to zeta::Zeta::compute_edits, but per edit.
345 let edits = response
346 .edits
347 .into_iter()
348 .map(|edit| {
349 // TODO edits to different files
350 (
351 snapshot.anchor_before(edit.range.start)
352 ..snapshot.anchor_before(edit.range.end),
353 edit.content,
354 )
355 })
356 .collect::<Vec<_>>()
357 .into();
358
359 let Some((edits, snapshot, edit_preview_task)) =
360 buffer.read_with(cx, |buffer, cx| {
361 let new_snapshot = buffer.snapshot();
362 let edits: Arc<[_]> =
363 interpolate(&snapshot, &new_snapshot, edits)?.into();
364 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
365 })?
366 else {
367 return Ok(None);
368 };
369
370 Ok(Some(EditPrediction {
371 id: EditPredictionId(response.request_id),
372 edits,
373 snapshot,
374 edit_preview: edit_preview_task.await,
375 }))
376 }
377 Ok(None) => Ok(None),
378 Err(err) => {
379 if err.is::<ZedUpdateRequiredError>() {
380 cx.update(|cx| {
381 this.update(cx, |this, _cx| {
382 this.update_required = true;
383 })
384 .ok();
385
386 let error_message: SharedString = err.to_string().into();
387 show_app_notification(
388 NotificationId::unique::<ZedUpdateRequiredError>(),
389 cx,
390 move |cx| {
391 cx.new(|cx| {
392 ErrorMessagePrompt::new(error_message.clone(), cx)
393 .with_link_button(
394 "Update Zed",
395 "https://zed.dev/releases",
396 )
397 })
398 },
399 );
400 })
401 .ok();
402 }
403
404 Err(err)
405 }
406 }
407 })
408 }
409
410 async fn perform_request(
411 client: Arc<Client>,
412 llm_token: LlmApiToken,
413 app_version: SemanticVersion,
414 request: predict_edits_v3::PredictEditsRequest,
415 ) -> Result<(
416 predict_edits_v3::PredictEditsResponse,
417 Option<EditPredictionUsage>,
418 )> {
419 let http_client = client.http_client();
420 let mut token = llm_token.acquire(&client).await?;
421 let mut did_retry = false;
422
423 loop {
424 let request_builder = http_client::Request::builder().method(Method::POST);
425 let request_builder =
426 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
427 request_builder.uri(predict_edits_url)
428 } else {
429 request_builder.uri(
430 http_client
431 .build_zed_llm_url("/predict_edits/v3", &[])?
432 .as_ref(),
433 )
434 };
435 let request = request_builder
436 .header("Content-Type", "application/json")
437 .header("Authorization", format!("Bearer {}", token))
438 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
439 .body(serde_json::to_string(&request)?.into())?;
440
441 let mut response = http_client.send(request).await?;
442
443 if let Some(minimum_required_version) = response
444 .headers()
445 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
446 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
447 {
448 anyhow::ensure!(
449 app_version >= minimum_required_version,
450 ZedUpdateRequiredError {
451 minimum_version: minimum_required_version
452 }
453 );
454 }
455
456 if response.status().is_success() {
457 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
458
459 let mut body = Vec::new();
460 response.body_mut().read_to_end(&mut body).await?;
461 return Ok((serde_json::from_slice(&body)?, usage));
462 } else if !did_retry
463 && response
464 .headers()
465 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
466 .is_some()
467 {
468 did_retry = true;
469 token = llm_token.refresh(&client).await?;
470 } else {
471 let mut body = String::new();
472 response.body_mut().read_to_string(&mut body).await?;
473 anyhow::bail!(
474 "error predicting edits.\nStatus: {:?}\nBody: {}",
475 response.status(),
476 body
477 );
478 }
479 }
480 }
481
482 // TODO: Dedupe with similar code in request_prediction?
483 pub fn cloud_request_for_zeta_cli(
484 &mut self,
485 project: &Entity<Project>,
486 buffer: &Entity<Buffer>,
487 position: language::Anchor,
488 cx: &mut Context<Self>,
489 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
490 let project_state = self.projects.get(&project.entity_id());
491
492 let index_state = project_state.map(|state| {
493 state
494 .syntax_index
495 .read_with(cx, |index, _cx| index.state().clone())
496 });
497 let excerpt_options = self.excerpt_options.clone();
498 let snapshot = buffer.read(cx).snapshot();
499 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
500 return Task::ready(Err(anyhow!("No file path for excerpt")));
501 };
502 let worktree_snapshots = project
503 .read(cx)
504 .worktrees(cx)
505 .map(|worktree| worktree.read(cx).snapshot())
506 .collect::<Vec<_>>();
507
508 cx.background_spawn(async move {
509 let index_state = if let Some(index_state) = index_state {
510 Some(index_state.lock_owned().await)
511 } else {
512 None
513 };
514
515 let cursor_point = position.to_point(&snapshot);
516
517 let debug_info = true;
518 EditPredictionContext::gather_context(
519 cursor_point,
520 &snapshot,
521 &excerpt_options,
522 index_state.as_deref(),
523 )
524 .context("Failed to select excerpt")
525 .map(|context| {
526 make_cloud_request(
527 excerpt_path.clone(),
528 context,
529 // TODO pass everything
530 Vec::new(),
531 false,
532 Vec::new(),
533 None,
534 debug_info,
535 &worktree_snapshots,
536 index_state.as_deref(),
537 )
538 })
539 })
540 }
541}
542
/// Error produced when the prediction server's `MINIMUM_REQUIRED_VERSION_HEADER_NAME`
/// header names a version newer than the running Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
550
/// [`EditPredictionProvider`] backed by the shared [`Zeta`] engine.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id handed to the next refresh task; incremented on every refresh.
    next_pending_prediction_id: usize,
    /// In-flight refresh tasks, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the most recent request was sent; used to throttle new requests.
    last_request_timestamp: Instant,
}
558
559impl ZetaEditPredictionProvider {
560 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
561
562 pub fn new(
563 project: Option<&Entity<Project>>,
564 client: &Arc<Client>,
565 user_store: &Entity<UserStore>,
566 cx: &mut App,
567 ) -> Self {
568 let zeta = Zeta::global(client, user_store, cx);
569 if let Some(project) = project {
570 zeta.update(cx, |zeta, cx| {
571 zeta.register_project(project, cx);
572 });
573 }
574
575 Self {
576 zeta,
577 current_prediction: None,
578 next_pending_prediction_id: 0,
579 pending_predictions: ArrayVec::new(),
580 last_request_timestamp: Instant::now(),
581 }
582 }
583}
584
/// A prediction together with the buffer it was produced for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction targets.
    buffer_id: EntityId,
    prediction: EditPrediction,
}
590
591impl CurrentEditPrediction {
592 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
593 if self.buffer_id != old_prediction.buffer_id {
594 return true;
595 }
596
597 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
598 return true;
599 };
600 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
601 return false;
602 };
603
604 if old_edits.len() == 1 && new_edits.len() == 1 {
605 let (old_range, old_text) = &old_edits[0];
606 let (new_range, new_text) = &new_edits[0];
607 new_range == old_range && new_text.starts_with(old_text)
608 } else {
609 true
610 }
611 }
612}
613
614#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
615pub struct EditPredictionId(Uuid);
616
617impl From<EditPredictionId> for gpui::ElementId {
618 fn from(value: EditPredictionId) -> Self {
619 gpui::ElementId::Uuid(value.0)
620 }
621}
622
623impl std::fmt::Display for EditPredictionId {
624 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
625 write!(f, "{}", self.0)
626 }
627}
628
/// A prediction received from the server.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// Edits to apply, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were last re-based onto; later snapshots are
    /// compared against it by `EditPrediction::interpolate`.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
}
636
637impl EditPrediction {
638 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
639 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
640 }
641}
642
/// An in-flight refresh started by `ZetaEditPredictionProvider::refresh`.
struct PendingPrediction {
    /// Id used to tell whether this refresh is the oldest one when it completes.
    id: usize,
    /// Never read; held so the task lives only as long as this entry.
    /// NOTE(review): relies on gpui tasks being cancelled when dropped — confirm.
    _task: Task<()>,
}
647
648impl EditPredictionProvider for ZetaEditPredictionProvider {
649 fn name() -> &'static str {
650 "zed-predict2"
651 }
652
653 fn display_name() -> &'static str {
654 "Zed's Edit Predictions 2"
655 }
656
657 fn show_completions_in_menu() -> bool {
658 true
659 }
660
661 fn show_tab_accept_marker() -> bool {
662 true
663 }
664
665 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
666 // TODO [zeta2]
667 DataCollectionState::Unsupported
668 }
669
670 fn toggle_data_collection(&mut self, _cx: &mut App) {
671 // TODO [zeta2]
672 }
673
674 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
675 self.zeta.read(cx).usage(cx)
676 }
677
678 fn is_enabled(
679 &self,
680 _buffer: &Entity<language::Buffer>,
681 _cursor_position: language::Anchor,
682 _cx: &App,
683 ) -> bool {
684 true
685 }
686
687 fn is_refreshing(&self) -> bool {
688 !self.pending_predictions.is_empty()
689 }
690
691 fn refresh(
692 &mut self,
693 project: Option<Entity<project::Project>>,
694 buffer: Entity<language::Buffer>,
695 cursor_position: language::Anchor,
696 _debounce: bool,
697 cx: &mut Context<Self>,
698 ) {
699 let Some(project) = project else {
700 return;
701 };
702
703 if self
704 .zeta
705 .read(cx)
706 .user_store
707 .read_with(cx, |user_store, _cx| {
708 user_store.account_too_young() || user_store.has_overdue_invoices()
709 })
710 {
711 return;
712 }
713
714 if let Some(current_prediction) = self.current_prediction.as_ref() {
715 let snapshot = buffer.read(cx).snapshot();
716 if current_prediction
717 .prediction
718 .interpolate(&snapshot)
719 .is_some()
720 {
721 return;
722 }
723 }
724
725 let pending_prediction_id = self.next_pending_prediction_id;
726 self.next_pending_prediction_id += 1;
727 let last_request_timestamp = self.last_request_timestamp;
728
729 let task = cx.spawn(async move |this, cx| {
730 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
731 .checked_duration_since(Instant::now())
732 {
733 cx.background_executor().timer(timeout).await;
734 }
735
736 let prediction_request = this.update(cx, |this, cx| {
737 this.last_request_timestamp = Instant::now();
738 this.zeta.update(cx, |zeta, cx| {
739 zeta.request_prediction(&project, &buffer, cursor_position, cx)
740 })
741 });
742
743 let prediction = match prediction_request {
744 Ok(prediction_request) => {
745 let prediction_request = prediction_request.await;
746 prediction_request.map(|c| {
747 c.map(|prediction| CurrentEditPrediction {
748 buffer_id: buffer.entity_id(),
749 prediction,
750 })
751 })
752 }
753 Err(error) => Err(error),
754 };
755
756 this.update(cx, |this, cx| {
757 if this.pending_predictions[0].id == pending_prediction_id {
758 this.pending_predictions.remove(0);
759 } else {
760 this.pending_predictions.clear();
761 }
762
763 let Some(new_prediction) = prediction
764 .context("edit prediction failed")
765 .log_err()
766 .flatten()
767 else {
768 cx.notify();
769 return;
770 };
771
772 if let Some(old_prediction) = this.current_prediction.as_ref() {
773 let snapshot = buffer.read(cx).snapshot();
774 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
775 this.current_prediction = Some(new_prediction);
776 }
777 } else {
778 this.current_prediction = Some(new_prediction);
779 }
780
781 cx.notify();
782 })
783 .ok();
784 });
785
786 // We always maintain at most two pending predictions. When we already
787 // have two, we replace the newest one.
788 if self.pending_predictions.len() <= 1 {
789 self.pending_predictions.push(PendingPrediction {
790 id: pending_prediction_id,
791 _task: task,
792 });
793 } else if self.pending_predictions.len() == 2 {
794 self.pending_predictions.pop();
795 self.pending_predictions.push(PendingPrediction {
796 id: pending_prediction_id,
797 _task: task,
798 });
799 }
800
801 cx.notify();
802 }
803
804 fn cycle(
805 &mut self,
806 _buffer: Entity<language::Buffer>,
807 _cursor_position: language::Anchor,
808 _direction: Direction,
809 _cx: &mut Context<Self>,
810 ) {
811 }
812
813 fn accept(&mut self, _cx: &mut Context<Self>) {
814 // TODO [zeta2] report accept
815 self.current_prediction.take();
816 self.pending_predictions.clear();
817 }
818
819 fn discard(&mut self, _cx: &mut Context<Self>) {
820 self.pending_predictions.clear();
821 self.current_prediction.take();
822 }
823
824 fn suggest(
825 &mut self,
826 buffer: &Entity<language::Buffer>,
827 cursor_position: language::Anchor,
828 cx: &mut Context<Self>,
829 ) -> Option<edit_prediction::EditPrediction> {
830 let CurrentEditPrediction {
831 buffer_id,
832 prediction,
833 ..
834 } = self.current_prediction.as_mut()?;
835
836 // Invalidate previous prediction if it was generated for a different buffer.
837 if *buffer_id != buffer.entity_id() {
838 self.current_prediction.take();
839 return None;
840 }
841
842 let buffer = buffer.read(cx);
843 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
844 self.current_prediction.take();
845 return None;
846 };
847
848 let cursor_row = cursor_position.to_point(buffer).row;
849 let (closest_edit_ix, (closest_edit_range, _)) =
850 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
851 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
852 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
853 cmp::min(distance_from_start, distance_from_end)
854 })?;
855
856 let mut edit_start_ix = closest_edit_ix;
857 for (range, _) in edits[..edit_start_ix].iter().rev() {
858 let distance_from_closest_edit =
859 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
860 if distance_from_closest_edit <= 1 {
861 edit_start_ix -= 1;
862 } else {
863 break;
864 }
865 }
866
867 let mut edit_end_ix = closest_edit_ix + 1;
868 for (range, _) in &edits[edit_end_ix..] {
869 let distance_from_closest_edit =
870 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
871 if distance_from_closest_edit <= 1 {
872 edit_end_ix += 1;
873 } else {
874 break;
875 }
876 }
877
878 Some(edit_prediction::EditPrediction {
879 id: Some(prediction.id.to_string().into()),
880 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
881 edit_preview: Some(prediction.edit_preview.clone()),
882 })
883 }
884}
885
886fn make_cloud_request(
887 excerpt_path: PathBuf,
888 context: EditPredictionContext,
889 events: Vec<predict_edits_v3::Event>,
890 can_collect_data: bool,
891 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
892 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
893 debug_info: bool,
894 worktrees: &Vec<worktree::Snapshot>,
895 index_state: Option<&SyntaxIndexState>,
896) -> predict_edits_v3::PredictEditsRequest {
897 let mut signatures = Vec::new();
898 let mut declaration_to_signature_index = HashMap::default();
899 let mut referenced_declarations = Vec::new();
900
901 for snippet in context.snippets {
902 let project_entry_id = snippet.declaration.project_entry_id();
903 let Some(path) = worktrees.iter().find_map(|worktree| {
904 worktree.entry_for_id(project_entry_id).map(|entry| {
905 let mut full_path = PathBuf::new();
906 full_path.push(worktree.root_name());
907 full_path.push(&entry.path);
908 full_path
909 })
910 }) else {
911 continue;
912 };
913
914 let parent_index = index_state.and_then(|index_state| {
915 snippet.declaration.parent().and_then(|parent| {
916 add_signature(
917 parent,
918 &mut declaration_to_signature_index,
919 &mut signatures,
920 index_state,
921 )
922 })
923 });
924
925 let (text, text_is_truncated) = snippet.declaration.item_text();
926 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
927 path,
928 text: text.into(),
929 range: snippet.declaration.item_range(),
930 text_is_truncated,
931 signature_range: snippet.declaration.signature_range_in_item_text(),
932 parent_index,
933 score_components: snippet.score_components,
934 signature_score: snippet.scores.signature,
935 declaration_score: snippet.scores.declaration,
936 });
937 }
938
939 let excerpt_parent = index_state.and_then(|index_state| {
940 context
941 .excerpt
942 .parent_declarations
943 .last()
944 .and_then(|(parent, _)| {
945 add_signature(
946 *parent,
947 &mut declaration_to_signature_index,
948 &mut signatures,
949 index_state,
950 )
951 })
952 });
953
954 predict_edits_v3::PredictEditsRequest {
955 excerpt_path,
956 excerpt: context.excerpt_text.body,
957 excerpt_range: context.excerpt.range,
958 cursor_offset: context.cursor_offset_in_excerpt,
959 referenced_declarations,
960 signatures,
961 excerpt_parent,
962 events,
963 can_collect_data,
964 diagnostic_groups,
965 git_info,
966 debug_info,
967 }
968}
969
970fn add_signature(
971 declaration_id: DeclarationId,
972 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
973 signatures: &mut Vec<Signature>,
974 index: &SyntaxIndexState,
975) -> Option<usize> {
976 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
977 return Some(*signature_index);
978 }
979 let Some(parent_declaration) = index.declaration(declaration_id) else {
980 log::error!("bug: missing parent declaration");
981 return None;
982 };
983 let parent_index = parent_declaration.parent().and_then(|parent| {
984 add_signature(parent, declaration_to_signature_index, signatures, index)
985 });
986 let (text, text_is_truncated) = parent_declaration.signature_text();
987 let signature_index = signatures.len();
988 signatures.push(Signature {
989 text: text.into(),
990 text_is_truncated,
991 parent_index,
992 range: parent_declaration.signature_range(),
993 });
994 declaration_to_signature_index.insert(declaration_id, signature_index);
995 Some(signature_index)
996}
997
/// Re-bases `current_edits` (anchored in `old_snapshot`) onto `new_snapshot`, taking
/// into account the edits the user made in between.
///
/// The user's edits since `old_snapshot` are walked in order alongside the model's
/// edits:
/// - model edits that end strictly before the next user edit starts are kept as-is;
/// - a user edit whose old range equals the next model edit's old range, and whose
///   inserted text is a prefix of the model's text, consumes that model edit; the
///   remaining suffix (if non-empty) becomes an insertion right after the user's edit;
/// - any other user edit means the user diverged from the prediction, and `None` is
///   returned.
///
/// Model edits after the last user edit are kept. Returns `None` if no edits remain.
/// NOTE(review): the peek-based walk assumes `current_edits` are sorted by position
/// and non-overlapping — confirm this holds for server-provided edits.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is only acceptable if it types a prefix of the next model edit
        // over exactly the same range.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Offer only the part of the model's text the user hasn't typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user's edit doesn't match the prediction: it no longer applies.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
1043
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Asserts that re-basing `prediction` onto `buffer`'s current contents yields
    /// exactly the offset edits in `expected`.
    #[track_caller]
    fn assert_interpolated(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
        expected: &[(Range<usize>, &str)],
    ) {
        let interpolated = prediction
            .interpolate(&buffer.read(cx).snapshot())
            .unwrap();
        let expected = expected
            .iter()
            .map(|(range, text)| (range.clone(), text.to_string()))
            .collect::<Vec<_>>();
        assert_eq!(from_prediction_edits(&interpolated, buffer, cx), expected);
    }

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is reported unchanged.
            assert_interpolated(&prediction, &buffer, cx, &[(2..5, "REM"), (9..11, "")]);

            // Deleting the predicted range turns the replacement into an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(2..2, "REM"), (6..8, "")]);

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_interpolated(&prediction, &buffer, cx, &[(2..5, "REM"), (9..11, "")]);

            // Typing a prefix of the predicted text leaves only the remaining suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(3..3, "EM"), (7..9, "")]);

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(4..4, "M"), (8..10, "")]);

            // Typing the whole predicted text consumes that edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(9..11, "")]);

            // Deleting the last typed character brings the suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(4..4, "M"), (8..10, "")]);

            // Performing a predicted deletion by hand removes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, &[(4..4, "M")]);

            // Diverging from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back into offsets within `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}