1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use gpui::http_client::Method;
15use gpui::{
16 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
17 prelude::*,
18};
19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
20use language::{BufferSnapshot, EditPreview};
21use language_model::{LlmApiToken, RefreshLlmTokenListener};
22use project::Project;
23use release_channel::AppVersion;
24use std::cmp;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::time::{Duration, Instant};
29use std::{ops::Range, sync::Arc};
30use thiserror::Error;
31use util::ResultExt as _;
32use uuid::Uuid;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
/// Window within which a new edit to the same buffer is merged into the previous
/// `BufferChange` event instead of being recorded as a separate event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of events kept in a project's event history.
const MAX_EVENT_COUNT: usize = 16;
39
/// gpui global that holds the single shared [`Zeta`] entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
44
/// Edit-prediction engine: tracks buffer changes per project and requests
/// predictions from the cloud endpoint.
pub struct Zeta {
    /// Client used for HTTP requests and LLM token acquisition/refresh.
    client: Arc<Client>,
    /// Source of edit-prediction usage and account status.
    user_store: Entity<UserStore>,
    /// Bearer token sent with prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used when gathering the excerpt around the cursor.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set once the server reports that a newer Zed version is required.
    update_required: bool,
}
54
/// State tracked for a single registered project.
struct ZetaProject {
    /// Syntax index used to gather declaration context for requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
60
/// A buffer whose edits are being reported as events.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; compared against the buffer's
    /// current version to detect new changes.
    snapshot: BufferSnapshot,
    /// The buffer-edit subscription and the buffer-release observer.
    _subscriptions: [gpui::Subscription; 2],
}
65
/// An entry in a project's event history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes
    /// to the same buffer may be merged into a single event (see `Zeta::push_event`).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent change merged into this event was recorded.
        timestamp: Instant,
    },
}
74
75impl Zeta {
76 pub fn global(
77 client: &Arc<Client>,
78 user_store: &Entity<UserStore>,
79 cx: &mut App,
80 ) -> Entity<Self> {
81 cx.try_global::<ZetaGlobal>()
82 .map(|global| global.0.clone())
83 .unwrap_or_else(|| {
84 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
85 cx.set_global(ZetaGlobal(zeta.clone()));
86 zeta
87 })
88 }
89
90 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
91 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
92
93 Self {
94 projects: HashMap::new(),
95 client,
96 user_store,
97 excerpt_options: EditPredictionExcerptOptions {
98 max_bytes: 512,
99 min_bytes: 128,
100 target_before_cursor_over_total_bytes: 0.5,
101 },
102 llm_token: LlmApiToken::default(),
103 _llm_token_subscription: cx.subscribe(
104 &refresh_llm_token_listener,
105 |this, _listener, _event, cx| {
106 let client = this.client.clone();
107 let llm_token = this.llm_token.clone();
108 cx.spawn(async move |_this, _cx| {
109 llm_token.refresh(&client).await?;
110 anyhow::Ok(())
111 })
112 .detach_and_log_err(cx);
113 },
114 ),
115 update_required: false,
116 }
117 }
118
119 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
120 self.user_store.read(cx).edit_prediction_usage()
121 }
122
123 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
124 self.get_or_init_zeta_project(project, cx);
125 }
126
127 pub fn register_buffer(
128 &mut self,
129 buffer: &Entity<Buffer>,
130 project: &Entity<Project>,
131 cx: &mut Context<Self>,
132 ) {
133 let zeta_project = self.get_or_init_zeta_project(project, cx);
134 Self::register_buffer_impl(zeta_project, buffer, project, cx);
135 }
136
137 fn get_or_init_zeta_project(
138 &mut self,
139 project: &Entity<Project>,
140 cx: &mut App,
141 ) -> &mut ZetaProject {
142 self.projects
143 .entry(project.entity_id())
144 .or_insert_with(|| ZetaProject {
145 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
146 events: VecDeque::new(),
147 registered_buffers: HashMap::new(),
148 })
149 }
150
151 fn register_buffer_impl<'a>(
152 zeta_project: &'a mut ZetaProject,
153 buffer: &Entity<Buffer>,
154 project: &Entity<Project>,
155 cx: &mut Context<Self>,
156 ) -> &'a mut RegisteredBuffer {
157 let buffer_id = buffer.entity_id();
158 match zeta_project.registered_buffers.entry(buffer_id) {
159 hash_map::Entry::Occupied(entry) => entry.into_mut(),
160 hash_map::Entry::Vacant(entry) => {
161 let snapshot = buffer.read(cx).snapshot();
162 let project_entity_id = project.entity_id();
163 entry.insert(RegisteredBuffer {
164 snapshot,
165 _subscriptions: [
166 cx.subscribe(buffer, {
167 let project = project.downgrade();
168 move |this, buffer, event, cx| {
169 if let language::BufferEvent::Edited = event
170 && let Some(project) = project.upgrade()
171 {
172 this.report_changes_for_buffer(&buffer, &project, cx);
173 }
174 }
175 }),
176 cx.observe_release(buffer, move |this, _buffer, _cx| {
177 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
178 else {
179 return;
180 };
181 zeta_project.registered_buffers.remove(&buffer_id);
182 }),
183 ],
184 })
185 }
186 }
187 }
188
189 fn report_changes_for_buffer(
190 &mut self,
191 buffer: &Entity<Buffer>,
192 project: &Entity<Project>,
193 cx: &mut Context<Self>,
194 ) -> BufferSnapshot {
195 let zeta_project = self.get_or_init_zeta_project(project, cx);
196 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
197
198 let new_snapshot = buffer.read(cx).snapshot();
199 if new_snapshot.version != registered_buffer.snapshot.version {
200 let old_snapshot =
201 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
202 Self::push_event(
203 zeta_project,
204 Event::BufferChange {
205 old_snapshot,
206 new_snapshot: new_snapshot.clone(),
207 timestamp: Instant::now(),
208 },
209 );
210 }
211
212 new_snapshot
213 }
214
215 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
216 let events = &mut zeta_project.events;
217
218 if let Some(Event::BufferChange {
219 new_snapshot: last_new_snapshot,
220 timestamp: last_timestamp,
221 ..
222 }) = events.back_mut()
223 {
224 // Coalesce edits for the same buffer when they happen one after the other.
225 let Event::BufferChange {
226 old_snapshot,
227 new_snapshot,
228 timestamp,
229 } = &event;
230
231 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
232 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
233 && old_snapshot.version == last_new_snapshot.version
234 {
235 *last_new_snapshot = new_snapshot.clone();
236 *last_timestamp = *timestamp;
237 return;
238 }
239 }
240
241 if events.len() >= MAX_EVENT_COUNT {
242 // These are halved instead of popping to improve prompt caching.
243 events.drain(..MAX_EVENT_COUNT / 2);
244 }
245
246 events.push_back(event);
247 }
248
249 pub fn request_prediction(
250 &mut self,
251 project: &Entity<Project>,
252 buffer: &Entity<Buffer>,
253 position: language::Anchor,
254 cx: &mut Context<Self>,
255 ) -> Task<Result<Option<EditPrediction>>> {
256 let project_state = self.projects.get(&project.entity_id());
257
258 let index_state = project_state.map(|state| {
259 state
260 .syntax_index
261 .read_with(cx, |index, _cx| index.state().clone())
262 });
263 let excerpt_options = self.excerpt_options.clone();
264 let snapshot = buffer.read(cx).snapshot();
265 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
266 return Task::ready(Err(anyhow!("No file path for excerpt")));
267 };
268 let client = self.client.clone();
269 let llm_token = self.llm_token.clone();
270 let app_version = AppVersion::global(cx);
271 let worktree_snapshots = project
272 .read(cx)
273 .worktrees(cx)
274 .map(|worktree| worktree.read(cx).snapshot())
275 .collect::<Vec<_>>();
276
277 let request_task = cx.background_spawn({
278 let snapshot = snapshot.clone();
279 async move {
280 let index_state = if let Some(index_state) = index_state {
281 Some(index_state.lock_owned().await)
282 } else {
283 None
284 };
285
286 let cursor_point = position.to_point(&snapshot);
287
288 // TODO: make this only true if debug view is open
289 let debug_info = true;
290
291 let Some(request) = EditPredictionContext::gather_context(
292 cursor_point,
293 &snapshot,
294 &excerpt_options,
295 index_state.as_deref(),
296 )
297 .map(|context| {
298 make_cloud_request(
299 excerpt_path.clone(),
300 context,
301 // TODO pass everything
302 Vec::new(),
303 false,
304 Vec::new(),
305 None,
306 debug_info,
307 &worktree_snapshots,
308 index_state.as_deref(),
309 )
310 }) else {
311 return Ok(None);
312 };
313
314 anyhow::Ok(Some(
315 Self::perform_request(client, llm_token, app_version, request).await?,
316 ))
317 }
318 });
319
320 let buffer = buffer.clone();
321
322 cx.spawn(async move |this, cx| {
323 match request_task.await {
324 Ok(Some((response, usage))) => {
325 log::debug!("predicted edits: {:?}", &response.edits);
326
327 if let Some(usage) = usage {
328 this.update(cx, |this, cx| {
329 this.user_store.update(cx, |user_store, cx| {
330 user_store.update_edit_prediction_usage(usage, cx);
331 });
332 })
333 .ok();
334 }
335
336 // TODO telemetry: duration, etc
337
338 // TODO produce smaller edits by diffing against snapshot first
339 //
340 // Cloud returns entire snippets/excerpts ranges as they were included
341 // in the request, but we should display smaller edits to the user.
342 //
343 // We can do this by computing a diff of each one against the snapshot.
344 // Similar to zeta::Zeta::compute_edits, but per edit.
345 let edits = response
346 .edits
347 .into_iter()
348 .map(|edit| {
349 // TODO edits to different files
350 (
351 snapshot.anchor_before(edit.range.start)
352 ..snapshot.anchor_before(edit.range.end),
353 edit.content,
354 )
355 })
356 .collect::<Vec<_>>()
357 .into();
358
359 let Some((edits, snapshot, edit_preview_task)) =
360 buffer.read_with(cx, |buffer, cx| {
361 let new_snapshot = buffer.snapshot();
362 let edits: Arc<[_]> =
363 interpolate(&snapshot, &new_snapshot, edits)?.into();
364 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
365 })?
366 else {
367 return Ok(None);
368 };
369
370 Ok(Some(EditPrediction {
371 id: EditPredictionId(response.request_id),
372 edits,
373 snapshot,
374 edit_preview: edit_preview_task.await,
375 }))
376 }
377 Ok(None) => Ok(None),
378 Err(err) => {
379 if err.is::<ZedUpdateRequiredError>() {
380 cx.update(|cx| {
381 this.update(cx, |this, _cx| {
382 this.update_required = true;
383 })
384 .ok();
385
386 let error_message: SharedString = err.to_string().into();
387 show_app_notification(
388 NotificationId::unique::<ZedUpdateRequiredError>(),
389 cx,
390 move |cx| {
391 cx.new(|cx| {
392 ErrorMessagePrompt::new(error_message.clone(), cx)
393 .with_link_button(
394 "Update Zed",
395 "https://zed.dev/releases",
396 )
397 })
398 },
399 );
400 })
401 .ok();
402 }
403
404 Err(err)
405 }
406 }
407 })
408 }
409
410 async fn perform_request(
411 client: Arc<Client>,
412 llm_token: LlmApiToken,
413 app_version: SemanticVersion,
414 request: predict_edits_v3::PredictEditsRequest,
415 ) -> Result<(
416 predict_edits_v3::PredictEditsResponse,
417 Option<EditPredictionUsage>,
418 )> {
419 let http_client = client.http_client();
420 let mut token = llm_token.acquire(&client).await?;
421 let mut did_retry = false;
422
423 loop {
424 let request_builder = http_client::Request::builder().method(Method::POST);
425 let request_builder =
426 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
427 request_builder.uri(predict_edits_url)
428 } else {
429 request_builder.uri(
430 http_client
431 .build_zed_llm_url("/predict_edits/v3", &[])?
432 .as_ref(),
433 )
434 };
435 let request = request_builder
436 .header("Content-Type", "application/json")
437 .header("Authorization", format!("Bearer {}", token))
438 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
439 .body(serde_json::to_string(&request)?.into())?;
440
441 let mut response = http_client.send(request).await?;
442
443 if let Some(minimum_required_version) = response
444 .headers()
445 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
446 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
447 {
448 anyhow::ensure!(
449 app_version >= minimum_required_version,
450 ZedUpdateRequiredError {
451 minimum_version: minimum_required_version
452 }
453 );
454 }
455
456 if response.status().is_success() {
457 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
458
459 let mut body = Vec::new();
460 response.body_mut().read_to_end(&mut body).await?;
461 return Ok((serde_json::from_slice(&body)?, usage));
462 } else if !did_retry
463 && response
464 .headers()
465 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
466 .is_some()
467 {
468 did_retry = true;
469 token = llm_token.refresh(&client).await?;
470 } else {
471 let mut body = String::new();
472 response.body_mut().read_to_string(&mut body).await?;
473 anyhow::bail!(
474 "error predicting edits.\nStatus: {:?}\nBody: {}",
475 response.status(),
476 body
477 );
478 }
479 }
480 }
481}
482
/// Error returned by `Zeta::perform_request` when the server's minimum required
/// version header names a version newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
490
/// `EditPredictionProvider` implementation backed by the shared [`Zeta`] engine.
pub struct ZetaEditPredictionProvider {
    /// The shared prediction engine.
    zeta: Entity<Zeta>,
    /// The prediction currently offered to the editor, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next refresh task; used to tell pending tasks apart.
    next_pending_prediction_id: usize,
    /// In-flight refresh tasks (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was started; used to throttle new requests.
    last_request_timestamp: Instant,
}
498
499impl ZetaEditPredictionProvider {
500 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
501
502 pub fn new(
503 project: Option<&Entity<Project>>,
504 client: &Arc<Client>,
505 user_store: &Entity<UserStore>,
506 cx: &mut App,
507 ) -> Self {
508 let zeta = Zeta::global(client, user_store, cx);
509 if let Some(project) = project {
510 zeta.update(cx, |zeta, cx| {
511 zeta.register_project(project, cx);
512 });
513 }
514
515 Self {
516 zeta,
517 current_prediction: None,
518 next_pending_prediction_id: 0,
519 pending_predictions: ArrayVec::new(),
520 last_request_timestamp: Instant::now(),
521 }
522 }
523}
524
/// A prediction together with the buffer it was produced for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    prediction: EditPrediction,
}
530
531impl CurrentEditPrediction {
532 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
533 if self.buffer_id != old_prediction.buffer_id {
534 return true;
535 }
536
537 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
538 return true;
539 };
540 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
541 return false;
542 };
543
544 if old_edits.len() == 1 && new_edits.len() == 1 {
545 let (old_range, old_text) = &old_edits[0];
546 let (new_range, new_text) = &new_edits[0];
547 new_range == old_range && new_text.starts_with(old_text)
548 } else {
549 true
550 }
551 }
552}
553
/// Identifier of a prediction, taken from the server response's `request_id`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
556
557impl From<EditPredictionId> for gpui::ElementId {
558 fn from(value: EditPredictionId) -> Self {
559 gpui::ElementId::Uuid(value.0)
560 }
561}
562
563impl std::fmt::Display for EditPredictionId {
564 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565 write!(f, "{}", self.0)
566 }
567}
568
/// A set of predicted edits, anchored in the buffer.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// Predicted replacements, as anchor ranges and the new text for each.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the `edits` were last reconciled against.
    snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
}
576
577impl EditPrediction {
578 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
579 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
580 }
581}
582
/// An in-flight refresh task.
struct PendingPrediction {
    /// Value of `next_pending_prediction_id` when the task was created.
    id: usize,
    /// Held so the task stays owned by the provider while pending.
    _task: Task<()>,
}
587
588impl EditPredictionProvider for ZetaEditPredictionProvider {
589 fn name() -> &'static str {
590 "zed-predict2"
591 }
592
593 fn display_name() -> &'static str {
594 "Zed's Edit Predictions 2"
595 }
596
597 fn show_completions_in_menu() -> bool {
598 true
599 }
600
601 fn show_tab_accept_marker() -> bool {
602 true
603 }
604
605 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
606 // TODO [zeta2]
607 DataCollectionState::Unsupported
608 }
609
610 fn toggle_data_collection(&mut self, _cx: &mut App) {
611 // TODO [zeta2]
612 }
613
614 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
615 self.zeta.read(cx).usage(cx)
616 }
617
618 fn is_enabled(
619 &self,
620 _buffer: &Entity<language::Buffer>,
621 _cursor_position: language::Anchor,
622 _cx: &App,
623 ) -> bool {
624 true
625 }
626
627 fn is_refreshing(&self) -> bool {
628 !self.pending_predictions.is_empty()
629 }
630
631 fn refresh(
632 &mut self,
633 project: Option<Entity<project::Project>>,
634 buffer: Entity<language::Buffer>,
635 cursor_position: language::Anchor,
636 _debounce: bool,
637 cx: &mut Context<Self>,
638 ) {
639 let Some(project) = project else {
640 return;
641 };
642
643 if self
644 .zeta
645 .read(cx)
646 .user_store
647 .read_with(cx, |user_store, _cx| {
648 user_store.account_too_young() || user_store.has_overdue_invoices()
649 })
650 {
651 return;
652 }
653
654 if let Some(current_prediction) = self.current_prediction.as_ref() {
655 let snapshot = buffer.read(cx).snapshot();
656 if current_prediction
657 .prediction
658 .interpolate(&snapshot)
659 .is_some()
660 {
661 return;
662 }
663 }
664
665 let pending_prediction_id = self.next_pending_prediction_id;
666 self.next_pending_prediction_id += 1;
667 let last_request_timestamp = self.last_request_timestamp;
668
669 let task = cx.spawn(async move |this, cx| {
670 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
671 .checked_duration_since(Instant::now())
672 {
673 cx.background_executor().timer(timeout).await;
674 }
675
676 let prediction_request = this.update(cx, |this, cx| {
677 this.last_request_timestamp = Instant::now();
678 this.zeta.update(cx, |zeta, cx| {
679 zeta.request_prediction(&project, &buffer, cursor_position, cx)
680 })
681 });
682
683 let prediction = match prediction_request {
684 Ok(prediction_request) => {
685 let prediction_request = prediction_request.await;
686 prediction_request.map(|c| {
687 c.map(|prediction| CurrentEditPrediction {
688 buffer_id: buffer.entity_id(),
689 prediction,
690 })
691 })
692 }
693 Err(error) => Err(error),
694 };
695
696 this.update(cx, |this, cx| {
697 if this.pending_predictions[0].id == pending_prediction_id {
698 this.pending_predictions.remove(0);
699 } else {
700 this.pending_predictions.clear();
701 }
702
703 let Some(new_prediction) = prediction
704 .context("edit prediction failed")
705 .log_err()
706 .flatten()
707 else {
708 cx.notify();
709 return;
710 };
711
712 if let Some(old_prediction) = this.current_prediction.as_ref() {
713 let snapshot = buffer.read(cx).snapshot();
714 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
715 this.current_prediction = Some(new_prediction);
716 }
717 } else {
718 this.current_prediction = Some(new_prediction);
719 }
720
721 cx.notify();
722 })
723 .ok();
724 });
725
726 // We always maintain at most two pending predictions. When we already
727 // have two, we replace the newest one.
728 if self.pending_predictions.len() <= 1 {
729 self.pending_predictions.push(PendingPrediction {
730 id: pending_prediction_id,
731 _task: task,
732 });
733 } else if self.pending_predictions.len() == 2 {
734 self.pending_predictions.pop();
735 self.pending_predictions.push(PendingPrediction {
736 id: pending_prediction_id,
737 _task: task,
738 });
739 }
740
741 cx.notify();
742 }
743
744 fn cycle(
745 &mut self,
746 _buffer: Entity<language::Buffer>,
747 _cursor_position: language::Anchor,
748 _direction: Direction,
749 _cx: &mut Context<Self>,
750 ) {
751 }
752
753 fn accept(&mut self, _cx: &mut Context<Self>) {
754 // TODO [zeta2] report accept
755 self.current_prediction.take();
756 self.pending_predictions.clear();
757 }
758
759 fn discard(&mut self, _cx: &mut Context<Self>) {
760 self.pending_predictions.clear();
761 self.current_prediction.take();
762 }
763
764 fn suggest(
765 &mut self,
766 buffer: &Entity<language::Buffer>,
767 cursor_position: language::Anchor,
768 cx: &mut Context<Self>,
769 ) -> Option<edit_prediction::EditPrediction> {
770 let CurrentEditPrediction {
771 buffer_id,
772 prediction,
773 ..
774 } = self.current_prediction.as_mut()?;
775
776 // Invalidate previous prediction if it was generated for a different buffer.
777 if *buffer_id != buffer.entity_id() {
778 self.current_prediction.take();
779 return None;
780 }
781
782 let buffer = buffer.read(cx);
783 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
784 self.current_prediction.take();
785 return None;
786 };
787
788 let cursor_row = cursor_position.to_point(buffer).row;
789 let (closest_edit_ix, (closest_edit_range, _)) =
790 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
791 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
792 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
793 cmp::min(distance_from_start, distance_from_end)
794 })?;
795
796 let mut edit_start_ix = closest_edit_ix;
797 for (range, _) in edits[..edit_start_ix].iter().rev() {
798 let distance_from_closest_edit =
799 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
800 if distance_from_closest_edit <= 1 {
801 edit_start_ix -= 1;
802 } else {
803 break;
804 }
805 }
806
807 let mut edit_end_ix = closest_edit_ix + 1;
808 for (range, _) in &edits[edit_end_ix..] {
809 let distance_from_closest_edit =
810 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
811 if distance_from_closest_edit <= 1 {
812 edit_end_ix += 1;
813 } else {
814 break;
815 }
816 }
817
818 Some(edit_prediction::EditPrediction {
819 id: Some(prediction.id.to_string().into()),
820 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
821 edit_preview: Some(prediction.edit_preview.clone()),
822 })
823 }
824}
825
826fn make_cloud_request(
827 excerpt_path: PathBuf,
828 context: EditPredictionContext,
829 events: Vec<predict_edits_v3::Event>,
830 can_collect_data: bool,
831 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
832 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
833 debug_info: bool,
834 worktrees: &Vec<worktree::Snapshot>,
835 index_state: Option<&SyntaxIndexState>,
836) -> predict_edits_v3::PredictEditsRequest {
837 let mut signatures = Vec::new();
838 let mut declaration_to_signature_index = HashMap::default();
839 let mut referenced_declarations = Vec::new();
840
841 for snippet in context.snippets {
842 let project_entry_id = snippet.declaration.project_entry_id();
843 // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
844 // Note that currently full_path is currently being used for excerpt_path.
845 let Some(path) = worktrees.iter().find_map(|worktree| {
846 let abs_path = worktree.abs_path();
847 worktree
848 .entry_for_id(project_entry_id)
849 .map(|e| abs_path.join(&e.path))
850 }) else {
851 continue;
852 };
853
854 let parent_index = index_state.and_then(|index_state| {
855 snippet.declaration.parent().and_then(|parent| {
856 add_signature(
857 parent,
858 &mut declaration_to_signature_index,
859 &mut signatures,
860 index_state,
861 )
862 })
863 });
864
865 let (text, text_is_truncated) = snippet.declaration.item_text();
866 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
867 path,
868 text: text.into(),
869 range: snippet.declaration.item_range(),
870 text_is_truncated,
871 signature_range: snippet.declaration.signature_range_in_item_text(),
872 parent_index,
873 score_components: snippet.score_components,
874 signature_score: snippet.scores.signature,
875 declaration_score: snippet.scores.declaration,
876 });
877 }
878
879 let excerpt_parent = index_state.and_then(|index_state| {
880 context
881 .excerpt
882 .parent_declarations
883 .last()
884 .and_then(|(parent, _)| {
885 add_signature(
886 *parent,
887 &mut declaration_to_signature_index,
888 &mut signatures,
889 index_state,
890 )
891 })
892 });
893
894 predict_edits_v3::PredictEditsRequest {
895 excerpt_path,
896 excerpt: context.excerpt_text.body,
897 excerpt_range: context.excerpt.range,
898 cursor_offset: context.cursor_offset_in_excerpt,
899 referenced_declarations,
900 signatures,
901 excerpt_parent,
902 events,
903 can_collect_data,
904 diagnostic_groups,
905 git_info,
906 debug_info,
907 }
908}
909
910fn add_signature(
911 declaration_id: DeclarationId,
912 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
913 signatures: &mut Vec<Signature>,
914 index: &SyntaxIndexState,
915) -> Option<usize> {
916 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
917 return Some(*signature_index);
918 }
919 let Some(parent_declaration) = index.declaration(declaration_id) else {
920 log::error!("bug: missing parent declaration");
921 return None;
922 };
923 let parent_index = parent_declaration.parent().and_then(|parent| {
924 add_signature(parent, declaration_to_signature_index, signatures, index)
925 });
926 let (text, text_is_truncated) = parent_declaration.signature_text();
927 let signature_index = signatures.len();
928 signatures.push(Signature {
929 text: text.into(),
930 text_is_truncated,
931 parent_index,
932 });
933 declaration_to_signature_index.insert(declaration_id, signature_index);
934 Some(signature_index)
935}
936
/// Re-expresses `current_edits`, computed against `old_snapshot`, so that they
/// apply to `new_snapshot`, taking into account the edits the user made between
/// the two versions.
///
/// The user's edits since `old_snapshot` are walked in order alongside the
/// predicted edits:
/// - A predicted edit whose range ends before a user edit starts is kept unchanged.
/// - A user edit is accepted only when it replaces exactly the range of the next
///   predicted edit and the text the user inserted is a prefix of the predicted
///   text. The predicted edit is then reduced to the remaining suffix, inserted
///   after the user's edit, or dropped if nothing remains.
/// - Any other user edit invalidates the whole prediction.
///
/// Predicted edits left over after the last user edit are kept unchanged.
///
/// Returns `None` when the prediction is invalidated or no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every predicted edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit must cover exactly the next predicted edit's range, and
        // the user's text must be a prefix of the predicted text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Only the part the user has not typed yet is still suggested.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit does not match the prediction: invalidate it.
        return None;
    }

    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
982
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let anchored = to_prediction_edits(
                [offset_edit(2..5, "REM"), offset_edit(9..11, "")],
                &buffer,
                cx,
            );
            anchored.into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // Unchanged buffer: the prediction applies as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(2..5, "REM"), offset_edit(9..11, "")]
            );

            // Deleting the replaced text leaves the predicted text as an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(2..2, "REM"), offset_edit(6..8, "")]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(2..5, "REM"), offset_edit(9..11, "")]
            );

            // Typing a prefix of the predicted text leaves only the remainder.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(3..3, "EM"), offset_edit(7..9, "")]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(4..4, "M"), offset_edit(8..10, "")]
            );

            // Typing the whole predicted text consumes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(9..11, "")]
            );

            // Deleting a typed character brings its remainder back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(4..4, "M"), offset_edit(8..10, "")]
            );

            // Performing the predicted deletion consumes it.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![offset_edit(4..4, "M")]
            );

            // An edit that does not match the prediction invalidates it entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Shorthand for an offset-based edit with owned replacement text.
    fn offset_edit(range: Range<usize>, text: &str) -> (Range<usize>, String) {
        (range, text.to_string())
    }

    /// Interpolates `prediction` against `buffer`'s current snapshot and converts
    /// the result back to offsets. Panics if the prediction no longer applies.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let current_snapshot = buffer.read(cx).snapshot();
        let anchored = prediction.interpolate(&current_snapshot).unwrap();
        from_prediction_edits(&anchored, buffer, cx)
    }

    /// Converts offset ranges into anchor ranges in `buffer`. The start is anchored
    /// after its offset and the end before its offset.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let snapshot = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = snapshot.anchor_after(range.start);
            let end = snapshot.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor ranges back into offsets in `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx);
        let mut offsets = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(snapshot);
            let end = range.end.to_offset(snapshot);
            offsets.push((start..end, text.clone()));
        }
        offsets
    }
}