1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use edit_prediction_context::{
11 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
12 SyntaxIndexState,
13};
14use futures::AsyncReadExt as _;
15use futures::channel::mpsc;
16use gpui::http_client::Method;
17use gpui::{
18 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
19 http_client, prelude::*,
20};
21use language::{
22 Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
23};
24use language::{BufferSnapshot, EditPreview};
25use language_model::{LlmApiToken, RefreshLlmTokenListener};
26use project::Project;
27use release_channel::AppVersion;
28use std::cmp;
29use std::collections::{HashMap, VecDeque, hash_map};
30use std::path::PathBuf;
31use std::str::FromStr as _;
32use std::time::{Duration, Instant};
33use std::{ops::Range, sync::Arc};
34use thiserror::Error;
35use util::{ResultExt as _, some_or_debug_panic};
36use uuid::Uuid;
37use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
38
/// Maximum time between two consecutive changes to the same buffer for them to be
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project. When this limit is reached, the
/// oldest half of the events is dropped (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

/// Default options used when selecting the excerpt around the cursor.
// NOTE(review): exact semantics of these fields live in `edit_prediction_context`;
// 0.5 presumably aims to place the cursor near the middle of the excerpt — confirm.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default [`ZetaOptions`] used by newly created [`Zeta`] instances.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_diagnostic_bytes: 2048,
};
54
/// Holds the app-wide [`Zeta`] entity as a gpui global; see [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
59
/// Edit prediction engine: records buffer changes per project and requests edit
/// predictions from the `predict_edits/v3` endpoint.
pub struct Zeta {
    client: Arc<Client>,
    /// Used to read and update edit prediction usage.
    user_store: Entity<UserStore>,
    /// Cached LLM API token, refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is too old.
    update_required: bool,
    /// When set, debug information about each prediction request is sent here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
70
/// Tunable parameters used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Controls how the excerpt around the cursor is selected.
    pub excerpt: EditPredictionExcerptOptions,
    /// Budget, in bytes of serialized JSON, for the diagnostics sent with a request.
    pub max_diagnostic_bytes: usize,
}
76
/// Debug information about one prediction request, delivered through the channel
/// returned by [`Zeta::debug_info`].
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor.
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request was built.
    pub retrieval_time: TimeDelta,
    /// Debug information returned by the server.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-side debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
86
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index used when gathering context for a prediction.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; holds at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are tracked by [`Zeta`].
struct RegisteredBuffer {
    /// Snapshot at the last recorded change; the baseline for the next `BufferChange`.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's edit events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}
97
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes to
    /// the same buffer may be coalesced into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent (possibly coalesced) change was recorded.
        timestamp: Instant,
    },
}
106
107impl Zeta {
108 pub fn global(
109 client: &Arc<Client>,
110 user_store: &Entity<UserStore>,
111 cx: &mut App,
112 ) -> Entity<Self> {
113 cx.try_global::<ZetaGlobal>()
114 .map(|global| global.0.clone())
115 .unwrap_or_else(|| {
116 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
117 cx.set_global(ZetaGlobal(zeta.clone()));
118 zeta
119 })
120 }
121
122 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
123 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
124
125 Self {
126 projects: HashMap::new(),
127 client,
128 user_store,
129 options: DEFAULT_OPTIONS,
130 llm_token: LlmApiToken::default(),
131 _llm_token_subscription: cx.subscribe(
132 &refresh_llm_token_listener,
133 |this, _listener, _event, cx| {
134 let client = this.client.clone();
135 let llm_token = this.llm_token.clone();
136 cx.spawn(async move |_this, _cx| {
137 llm_token.refresh(&client).await?;
138 anyhow::Ok(())
139 })
140 .detach_and_log_err(cx);
141 },
142 ),
143 update_required: false,
144 debug_tx: None,
145 }
146 }
147
148 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
149 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
150 self.debug_tx = Some(debug_watch_tx);
151 debug_watch_rx
152 }
153
154 pub fn options(&self) -> &ZetaOptions {
155 &self.options
156 }
157
158 pub fn set_options(&mut self, options: ZetaOptions) {
159 self.options = options;
160 }
161
162 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
163 self.user_store.read(cx).edit_prediction_usage()
164 }
165
166 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
167 self.get_or_init_zeta_project(project, cx);
168 }
169
170 pub fn register_buffer(
171 &mut self,
172 buffer: &Entity<Buffer>,
173 project: &Entity<Project>,
174 cx: &mut Context<Self>,
175 ) {
176 let zeta_project = self.get_or_init_zeta_project(project, cx);
177 Self::register_buffer_impl(zeta_project, buffer, project, cx);
178 }
179
180 fn get_or_init_zeta_project(
181 &mut self,
182 project: &Entity<Project>,
183 cx: &mut App,
184 ) -> &mut ZetaProject {
185 self.projects
186 .entry(project.entity_id())
187 .or_insert_with(|| ZetaProject {
188 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
189 events: VecDeque::new(),
190 registered_buffers: HashMap::new(),
191 })
192 }
193
194 fn register_buffer_impl<'a>(
195 zeta_project: &'a mut ZetaProject,
196 buffer: &Entity<Buffer>,
197 project: &Entity<Project>,
198 cx: &mut Context<Self>,
199 ) -> &'a mut RegisteredBuffer {
200 let buffer_id = buffer.entity_id();
201 match zeta_project.registered_buffers.entry(buffer_id) {
202 hash_map::Entry::Occupied(entry) => entry.into_mut(),
203 hash_map::Entry::Vacant(entry) => {
204 let snapshot = buffer.read(cx).snapshot();
205 let project_entity_id = project.entity_id();
206 entry.insert(RegisteredBuffer {
207 snapshot,
208 _subscriptions: [
209 cx.subscribe(buffer, {
210 let project = project.downgrade();
211 move |this, buffer, event, cx| {
212 if let language::BufferEvent::Edited = event
213 && let Some(project) = project.upgrade()
214 {
215 this.report_changes_for_buffer(&buffer, &project, cx);
216 }
217 }
218 }),
219 cx.observe_release(buffer, move |this, _buffer, _cx| {
220 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
221 else {
222 return;
223 };
224 zeta_project.registered_buffers.remove(&buffer_id);
225 }),
226 ],
227 })
228 }
229 }
230 }
231
232 fn report_changes_for_buffer(
233 &mut self,
234 buffer: &Entity<Buffer>,
235 project: &Entity<Project>,
236 cx: &mut Context<Self>,
237 ) -> BufferSnapshot {
238 let zeta_project = self.get_or_init_zeta_project(project, cx);
239 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
240
241 let new_snapshot = buffer.read(cx).snapshot();
242 if new_snapshot.version != registered_buffer.snapshot.version {
243 let old_snapshot =
244 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
245 Self::push_event(
246 zeta_project,
247 Event::BufferChange {
248 old_snapshot,
249 new_snapshot: new_snapshot.clone(),
250 timestamp: Instant::now(),
251 },
252 );
253 }
254
255 new_snapshot
256 }
257
258 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
259 let events = &mut zeta_project.events;
260
261 if let Some(Event::BufferChange {
262 new_snapshot: last_new_snapshot,
263 timestamp: last_timestamp,
264 ..
265 }) = events.back_mut()
266 {
267 // Coalesce edits for the same buffer when they happen one after the other.
268 let Event::BufferChange {
269 old_snapshot,
270 new_snapshot,
271 timestamp,
272 } = &event;
273
274 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
275 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
276 && old_snapshot.version == last_new_snapshot.version
277 {
278 *last_new_snapshot = new_snapshot.clone();
279 *last_timestamp = *timestamp;
280 return;
281 }
282 }
283
284 if events.len() >= MAX_EVENT_COUNT {
285 // These are halved instead of popping to improve prompt caching.
286 events.drain(..MAX_EVENT_COUNT / 2);
287 }
288
289 events.push_back(event);
290 }
291
292 pub fn request_prediction(
293 &mut self,
294 project: &Entity<Project>,
295 buffer: &Entity<Buffer>,
296 position: language::Anchor,
297 cx: &mut Context<Self>,
298 ) -> Task<Result<Option<EditPrediction>>> {
299 let project_state = self.projects.get(&project.entity_id());
300
301 let index_state = project_state.map(|state| {
302 state
303 .syntax_index
304 .read_with(cx, |index, _cx| index.state().clone())
305 });
306 let options = self.options.clone();
307 let snapshot = buffer.read(cx).snapshot();
308 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
309 return Task::ready(Err(anyhow!("No file path for excerpt")));
310 };
311 let client = self.client.clone();
312 let llm_token = self.llm_token.clone();
313 let app_version = AppVersion::global(cx);
314 let worktree_snapshots = project
315 .read(cx)
316 .worktrees(cx)
317 .map(|worktree| worktree.read(cx).snapshot())
318 .collect::<Vec<_>>();
319 let debug_tx = self.debug_tx.clone();
320
321 let events = project_state
322 .map(|state| {
323 state
324 .events
325 .iter()
326 .map(|event| match event {
327 Event::BufferChange {
328 old_snapshot,
329 new_snapshot,
330 ..
331 } => {
332 let path = new_snapshot.file().map(|f| f.path().to_path_buf());
333
334 let old_path = old_snapshot.file().and_then(|f| {
335 let old_path = f.path().as_ref();
336 if Some(old_path) != path.as_deref() {
337 Some(old_path.to_path_buf())
338 } else {
339 None
340 }
341 });
342
343 predict_edits_v3::Event::BufferChange {
344 old_path,
345 path,
346 diff: language::unified_diff(
347 &old_snapshot.text(),
348 &new_snapshot.text(),
349 ),
350 //todo: Actually detect if this edit was predicted or not
351 predicted: false,
352 }
353 }
354 })
355 .collect::<Vec<_>>()
356 })
357 .unwrap_or_default();
358
359 let diagnostics = snapshot.diagnostic_sets().clone();
360
361 let request_task = cx.background_spawn({
362 let snapshot = snapshot.clone();
363 let buffer = buffer.clone();
364 async move {
365 let index_state = if let Some(index_state) = index_state {
366 Some(index_state.lock_owned().await)
367 } else {
368 None
369 };
370
371 let cursor_offset = position.to_offset(&snapshot);
372 let cursor_point = cursor_offset.to_point(&snapshot);
373
374 let before_retrieval = chrono::Utc::now();
375
376 let Some(context) = EditPredictionContext::gather_context(
377 cursor_point,
378 &snapshot,
379 &options.excerpt,
380 index_state.as_deref(),
381 ) else {
382 return Ok(None);
383 };
384
385 let debug_context = if let Some(debug_tx) = debug_tx {
386 Some((debug_tx, context.clone()))
387 } else {
388 None
389 };
390
391 let (diagnostic_groups, diagnostic_groups_truncated) =
392 Self::gather_nearby_diagnostics(
393 cursor_offset,
394 &diagnostics,
395 &snapshot,
396 options.max_diagnostic_bytes,
397 );
398
399 let request = make_cloud_request(
400 excerpt_path.clone(),
401 context,
402 events,
403 // TODO data collection
404 false,
405 diagnostic_groups,
406 diagnostic_groups_truncated,
407 None,
408 debug_context.is_some(),
409 &worktree_snapshots,
410 index_state.as_deref(),
411 );
412
413 let retrieval_time = chrono::Utc::now() - before_retrieval;
414 let response = Self::perform_request(client, llm_token, app_version, request).await;
415
416 if let Some((debug_tx, context)) = debug_context {
417 debug_tx
418 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
419 |response| {
420 let Some(request) =
421 some_or_debug_panic(response.0.debug_info.clone())
422 else {
423 return Err("Missing debug info".to_string());
424 };
425 Ok(PredictionDebugInfo {
426 context,
427 request,
428 retrieval_time,
429 buffer: buffer.downgrade(),
430 position,
431 })
432 },
433 ))
434 .ok();
435 }
436
437 anyhow::Ok(Some(response?))
438 }
439 });
440
441 let buffer = buffer.clone();
442
443 cx.spawn(async move |this, cx| {
444 match request_task.await {
445 Ok(Some((response, usage))) => {
446 log::debug!("predicted edits: {:?}", &response.edits);
447
448 if let Some(usage) = usage {
449 this.update(cx, |this, cx| {
450 this.user_store.update(cx, |user_store, cx| {
451 user_store.update_edit_prediction_usage(usage, cx);
452 });
453 })
454 .ok();
455 }
456
457 // TODO telemetry: duration, etc
458
459 // TODO produce smaller edits by diffing against snapshot first
460 //
461 // Cloud returns entire snippets/excerpts ranges as they were included
462 // in the request, but we should display smaller edits to the user.
463 //
464 // We can do this by computing a diff of each one against the snapshot.
465 // Similar to zeta::Zeta::compute_edits, but per edit.
466 let edits = response
467 .edits
468 .into_iter()
469 .map(|edit| {
470 // TODO edits to different files
471 (
472 snapshot.anchor_before(edit.range.start)
473 ..snapshot.anchor_before(edit.range.end),
474 edit.content,
475 )
476 })
477 .collect::<Vec<_>>()
478 .into();
479
480 let Some((edits, snapshot, edit_preview_task)) =
481 buffer.read_with(cx, |buffer, cx| {
482 let new_snapshot = buffer.snapshot();
483 let edits: Arc<[_]> =
484 interpolate(&snapshot, &new_snapshot, edits)?.into();
485 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
486 })?
487 else {
488 return Ok(None);
489 };
490
491 Ok(Some(EditPrediction {
492 id: EditPredictionId(response.request_id),
493 edits,
494 snapshot,
495 edit_preview: edit_preview_task.await,
496 }))
497 }
498 Ok(None) => Ok(None),
499 Err(err) => {
500 if err.is::<ZedUpdateRequiredError>() {
501 cx.update(|cx| {
502 this.update(cx, |this, _cx| {
503 this.update_required = true;
504 })
505 .ok();
506
507 let error_message: SharedString = err.to_string().into();
508 show_app_notification(
509 NotificationId::unique::<ZedUpdateRequiredError>(),
510 cx,
511 move |cx| {
512 cx.new(|cx| {
513 ErrorMessagePrompt::new(error_message.clone(), cx)
514 .with_link_button(
515 "Update Zed",
516 "https://zed.dev/releases",
517 )
518 })
519 },
520 );
521 })
522 .ok();
523 }
524
525 Err(err)
526 }
527 }
528 })
529 }
530
531 async fn perform_request(
532 client: Arc<Client>,
533 llm_token: LlmApiToken,
534 app_version: SemanticVersion,
535 request: predict_edits_v3::PredictEditsRequest,
536 ) -> Result<(
537 predict_edits_v3::PredictEditsResponse,
538 Option<EditPredictionUsage>,
539 )> {
540 let http_client = client.http_client();
541 let mut token = llm_token.acquire(&client).await?;
542 let mut did_retry = false;
543
544 loop {
545 let request_builder = http_client::Request::builder().method(Method::POST);
546 let request_builder =
547 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
548 request_builder.uri(predict_edits_url)
549 } else {
550 request_builder.uri(
551 http_client
552 .build_zed_llm_url("/predict_edits/v3", &[])?
553 .as_ref(),
554 )
555 };
556 let request = request_builder
557 .header("Content-Type", "application/json")
558 .header("Authorization", format!("Bearer {}", token))
559 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
560 .body(serde_json::to_string(&request)?.into())?;
561
562 let mut response = http_client.send(request).await?;
563
564 if let Some(minimum_required_version) = response
565 .headers()
566 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
567 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
568 {
569 anyhow::ensure!(
570 app_version >= minimum_required_version,
571 ZedUpdateRequiredError {
572 minimum_version: minimum_required_version
573 }
574 );
575 }
576
577 if response.status().is_success() {
578 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
579
580 let mut body = Vec::new();
581 response.body_mut().read_to_end(&mut body).await?;
582 return Ok((serde_json::from_slice(&body)?, usage));
583 } else if !did_retry
584 && response
585 .headers()
586 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
587 .is_some()
588 {
589 did_retry = true;
590 token = llm_token.refresh(&client).await?;
591 } else {
592 let mut body = String::new();
593 response.body_mut().read_to_string(&mut body).await?;
594 anyhow::bail!(
595 "error predicting edits.\nStatus: {:?}\nBody: {}",
596 response.status(),
597 body
598 );
599 }
600 }
601 }
602
603 fn gather_nearby_diagnostics(
604 cursor_offset: usize,
605 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
606 snapshot: &BufferSnapshot,
607 max_diagnostics_bytes: usize,
608 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
609 // TODO: Could make this more efficient
610 let mut diagnostic_groups = Vec::new();
611 for (language_server_id, diagnostics) in diagnostic_sets {
612 let mut groups = Vec::new();
613 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
614 diagnostic_groups.extend(
615 groups
616 .into_iter()
617 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
618 );
619 }
620
621 // sort by proximity to cursor
622 diagnostic_groups.sort_by_key(|group| {
623 let range = &group.entries[group.primary_ix].range;
624 if range.start >= cursor_offset {
625 range.start - cursor_offset
626 } else if cursor_offset >= range.end {
627 cursor_offset - range.end
628 } else {
629 (cursor_offset - range.start).min(range.end - cursor_offset)
630 }
631 });
632
633 let mut results = Vec::new();
634 let mut diagnostic_groups_truncated = false;
635 let mut diagnostics_byte_count = 0;
636 for group in diagnostic_groups {
637 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
638 diagnostics_byte_count += raw_value.get().len();
639 if diagnostics_byte_count > max_diagnostics_bytes {
640 diagnostic_groups_truncated = true;
641 break;
642 }
643 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
644 }
645
646 (results, diagnostic_groups_truncated)
647 }
648
649 // TODO: Dedupe with similar code in request_prediction?
650 pub fn cloud_request_for_zeta_cli(
651 &mut self,
652 project: &Entity<Project>,
653 buffer: &Entity<Buffer>,
654 position: language::Anchor,
655 cx: &mut Context<Self>,
656 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
657 let project_state = self.projects.get(&project.entity_id());
658
659 let index_state = project_state.map(|state| {
660 state
661 .syntax_index
662 .read_with(cx, |index, _cx| index.state().clone())
663 });
664 let options = self.options.clone();
665 let snapshot = buffer.read(cx).snapshot();
666 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
667 return Task::ready(Err(anyhow!("No file path for excerpt")));
668 };
669 let worktree_snapshots = project
670 .read(cx)
671 .worktrees(cx)
672 .map(|worktree| worktree.read(cx).snapshot())
673 .collect::<Vec<_>>();
674
675 cx.background_spawn(async move {
676 let index_state = if let Some(index_state) = index_state {
677 Some(index_state.lock_owned().await)
678 } else {
679 None
680 };
681
682 let cursor_point = position.to_point(&snapshot);
683
684 let debug_info = true;
685 EditPredictionContext::gather_context(
686 cursor_point,
687 &snapshot,
688 &options.excerpt,
689 index_state.as_deref(),
690 )
691 .context("Failed to select excerpt")
692 .map(|context| {
693 make_cloud_request(
694 excerpt_path.clone(),
695 context,
696 // TODO pass everything
697 Vec::new(),
698 false,
699 Vec::new(),
700 false,
701 None,
702 debug_info,
703 &worktree_snapshots,
704 index_state.as_deref(),
705 )
706 })
707 })
708 }
709}
710
/// Returned when the prediction server requires a newer Zed version than the one
/// currently running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
718
/// [`EditPredictionProvider`] backed by the app-wide [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next request spawned by `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight requests, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the most recent prediction request was started; used for throttling.
    last_request_timestamp: Instant,
}
726
727impl ZetaEditPredictionProvider {
728 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
729
730 pub fn new(
731 project: Option<&Entity<Project>>,
732 client: &Arc<Client>,
733 user_store: &Entity<UserStore>,
734 cx: &mut App,
735 ) -> Self {
736 let zeta = Zeta::global(client, user_store, cx);
737 if let Some(project) = project {
738 zeta.update(cx, |zeta, cx| {
739 zeta.register_project(project, cx);
740 });
741 }
742
743 Self {
744 zeta,
745 current_prediction: None,
746 next_pending_prediction_id: 0,
747 pending_predictions: ArrayVec::new(),
748 last_request_timestamp: Instant::now(),
749 }
750 }
751}
752
/// The prediction currently offered by [`ZetaEditPredictionProvider`], together with
/// the entity id of the buffer it was requested for.
#[derive(Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    prediction: EditPrediction,
}
758
759impl CurrentEditPrediction {
760 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
761 if self.buffer_id != old_prediction.buffer_id {
762 return true;
763 }
764
765 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
766 return true;
767 };
768 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
769 return false;
770 };
771
772 if old_edits.len() == 1 && new_edits.len() == 1 {
773 let (old_range, old_text) = &old_edits[0];
774 let (new_range, new_text) = &new_edits[0];
775 new_range == old_range && new_text.starts_with(old_text)
776 } else {
777 true
778 }
779 }
780}
781
782#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
783pub struct EditPredictionId(Uuid);
784
785impl From<EditPredictionId> for gpui::ElementId {
786 fn from(value: EditPredictionId) -> Self {
787 gpui::ElementId::Uuid(value.0)
788 }
789}
790
791impl std::fmt::Display for EditPredictionId {
792 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
793 write!(f, "{}", self.0)
794 }
795}
796
797#[derive(Clone)]
798pub struct EditPrediction {
799 id: EditPredictionId,
800 edits: Arc<[(Range<Anchor>, String)]>,
801 snapshot: BufferSnapshot,
802 edit_preview: EditPreview,
803}
804
805impl EditPrediction {
806 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
807 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
808 }
809}
810
/// A prediction request spawned by `ZetaEditPredictionProvider::refresh`.
struct PendingPrediction {
    /// Sequential id used to tell which request has finished.
    id: usize,
    /// The request task, held so it keeps running.
    // NOTE(review): presumably dropping this cancels the request (gpui `Task`
    // semantics), which is how `accept`/`discard`/replacement abort requests — confirm.
    _task: Task<()>,
}
815
816impl EditPredictionProvider for ZetaEditPredictionProvider {
817 fn name() -> &'static str {
818 "zed-predict2"
819 }
820
821 fn display_name() -> &'static str {
822 "Zed's Edit Predictions 2"
823 }
824
825 fn show_completions_in_menu() -> bool {
826 true
827 }
828
829 fn show_tab_accept_marker() -> bool {
830 true
831 }
832
833 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
834 // TODO [zeta2]
835 DataCollectionState::Unsupported
836 }
837
838 fn toggle_data_collection(&mut self, _cx: &mut App) {
839 // TODO [zeta2]
840 }
841
842 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
843 self.zeta.read(cx).usage(cx)
844 }
845
846 fn is_enabled(
847 &self,
848 _buffer: &Entity<language::Buffer>,
849 _cursor_position: language::Anchor,
850 _cx: &App,
851 ) -> bool {
852 true
853 }
854
855 fn is_refreshing(&self) -> bool {
856 !self.pending_predictions.is_empty()
857 }
858
859 fn refresh(
860 &mut self,
861 project: Option<Entity<project::Project>>,
862 buffer: Entity<language::Buffer>,
863 cursor_position: language::Anchor,
864 _debounce: bool,
865 cx: &mut Context<Self>,
866 ) {
867 let Some(project) = project else {
868 return;
869 };
870
871 if self
872 .zeta
873 .read(cx)
874 .user_store
875 .read_with(cx, |user_store, _cx| {
876 user_store.account_too_young() || user_store.has_overdue_invoices()
877 })
878 {
879 return;
880 }
881
882 if let Some(current_prediction) = self.current_prediction.as_ref() {
883 let snapshot = buffer.read(cx).snapshot();
884 if current_prediction
885 .prediction
886 .interpolate(&snapshot)
887 .is_some()
888 {
889 return;
890 }
891 }
892
893 let pending_prediction_id = self.next_pending_prediction_id;
894 self.next_pending_prediction_id += 1;
895 let last_request_timestamp = self.last_request_timestamp;
896
897 let task = cx.spawn(async move |this, cx| {
898 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
899 .checked_duration_since(Instant::now())
900 {
901 cx.background_executor().timer(timeout).await;
902 }
903
904 let prediction_request = this.update(cx, |this, cx| {
905 this.last_request_timestamp = Instant::now();
906 this.zeta.update(cx, |zeta, cx| {
907 zeta.request_prediction(&project, &buffer, cursor_position, cx)
908 })
909 });
910
911 let prediction = match prediction_request {
912 Ok(prediction_request) => {
913 let prediction_request = prediction_request.await;
914 prediction_request.map(|c| {
915 c.map(|prediction| CurrentEditPrediction {
916 buffer_id: buffer.entity_id(),
917 prediction,
918 })
919 })
920 }
921 Err(error) => Err(error),
922 };
923
924 this.update(cx, |this, cx| {
925 if this.pending_predictions[0].id == pending_prediction_id {
926 this.pending_predictions.remove(0);
927 } else {
928 this.pending_predictions.clear();
929 }
930
931 let Some(new_prediction) = prediction
932 .context("edit prediction failed")
933 .log_err()
934 .flatten()
935 else {
936 cx.notify();
937 return;
938 };
939
940 if let Some(old_prediction) = this.current_prediction.as_ref() {
941 let snapshot = buffer.read(cx).snapshot();
942 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
943 this.current_prediction = Some(new_prediction);
944 }
945 } else {
946 this.current_prediction = Some(new_prediction);
947 }
948
949 cx.notify();
950 })
951 .ok();
952 });
953
954 // We always maintain at most two pending predictions. When we already
955 // have two, we replace the newest one.
956 if self.pending_predictions.len() <= 1 {
957 self.pending_predictions.push(PendingPrediction {
958 id: pending_prediction_id,
959 _task: task,
960 });
961 } else if self.pending_predictions.len() == 2 {
962 self.pending_predictions.pop();
963 self.pending_predictions.push(PendingPrediction {
964 id: pending_prediction_id,
965 _task: task,
966 });
967 }
968
969 cx.notify();
970 }
971
972 fn cycle(
973 &mut self,
974 _buffer: Entity<language::Buffer>,
975 _cursor_position: language::Anchor,
976 _direction: Direction,
977 _cx: &mut Context<Self>,
978 ) {
979 }
980
981 fn accept(&mut self, _cx: &mut Context<Self>) {
982 // TODO [zeta2] report accept
983 self.current_prediction.take();
984 self.pending_predictions.clear();
985 }
986
987 fn discard(&mut self, _cx: &mut Context<Self>) {
988 self.pending_predictions.clear();
989 self.current_prediction.take();
990 }
991
992 fn suggest(
993 &mut self,
994 buffer: &Entity<language::Buffer>,
995 cursor_position: language::Anchor,
996 cx: &mut Context<Self>,
997 ) -> Option<edit_prediction::EditPrediction> {
998 let CurrentEditPrediction {
999 buffer_id,
1000 prediction,
1001 ..
1002 } = self.current_prediction.as_mut()?;
1003
1004 // Invalidate previous prediction if it was generated for a different buffer.
1005 if *buffer_id != buffer.entity_id() {
1006 self.current_prediction.take();
1007 return None;
1008 }
1009
1010 let buffer = buffer.read(cx);
1011 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1012 self.current_prediction.take();
1013 return None;
1014 };
1015
1016 let cursor_row = cursor_position.to_point(buffer).row;
1017 let (closest_edit_ix, (closest_edit_range, _)) =
1018 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1019 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1020 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1021 cmp::min(distance_from_start, distance_from_end)
1022 })?;
1023
1024 let mut edit_start_ix = closest_edit_ix;
1025 for (range, _) in edits[..edit_start_ix].iter().rev() {
1026 let distance_from_closest_edit =
1027 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1028 if distance_from_closest_edit <= 1 {
1029 edit_start_ix -= 1;
1030 } else {
1031 break;
1032 }
1033 }
1034
1035 let mut edit_end_ix = closest_edit_ix + 1;
1036 for (range, _) in &edits[edit_end_ix..] {
1037 let distance_from_closest_edit =
1038 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1039 if distance_from_closest_edit <= 1 {
1040 edit_end_ix += 1;
1041 } else {
1042 break;
1043 }
1044 }
1045
1046 Some(edit_prediction::EditPrediction {
1047 id: Some(prediction.id.to_string().into()),
1048 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1049 edit_preview: Some(prediction.edit_preview.clone()),
1050 })
1051 }
1052}
1053
1054fn make_cloud_request(
1055 excerpt_path: PathBuf,
1056 context: EditPredictionContext,
1057 events: Vec<predict_edits_v3::Event>,
1058 can_collect_data: bool,
1059 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1060 diagnostic_groups_truncated: bool,
1061 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1062 debug_info: bool,
1063 worktrees: &Vec<worktree::Snapshot>,
1064 index_state: Option<&SyntaxIndexState>,
1065) -> predict_edits_v3::PredictEditsRequest {
1066 let mut signatures = Vec::new();
1067 let mut declaration_to_signature_index = HashMap::default();
1068 let mut referenced_declarations = Vec::new();
1069
1070 for snippet in context.snippets {
1071 let project_entry_id = snippet.declaration.project_entry_id();
1072 let Some(path) = worktrees.iter().find_map(|worktree| {
1073 worktree.entry_for_id(project_entry_id).map(|entry| {
1074 let mut full_path = PathBuf::new();
1075 full_path.push(worktree.root_name());
1076 full_path.push(&entry.path);
1077 full_path
1078 })
1079 }) else {
1080 continue;
1081 };
1082
1083 let parent_index = index_state.and_then(|index_state| {
1084 snippet.declaration.parent().and_then(|parent| {
1085 add_signature(
1086 parent,
1087 &mut declaration_to_signature_index,
1088 &mut signatures,
1089 index_state,
1090 )
1091 })
1092 });
1093
1094 let (text, text_is_truncated) = snippet.declaration.item_text();
1095 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1096 path,
1097 text: text.into(),
1098 range: snippet.declaration.item_range(),
1099 text_is_truncated,
1100 signature_range: snippet.declaration.signature_range_in_item_text(),
1101 parent_index,
1102 score_components: snippet.score_components,
1103 signature_score: snippet.scores.signature,
1104 declaration_score: snippet.scores.declaration,
1105 });
1106 }
1107
1108 let excerpt_parent = index_state.and_then(|index_state| {
1109 context
1110 .excerpt
1111 .parent_declarations
1112 .last()
1113 .and_then(|(parent, _)| {
1114 add_signature(
1115 *parent,
1116 &mut declaration_to_signature_index,
1117 &mut signatures,
1118 index_state,
1119 )
1120 })
1121 });
1122
1123 predict_edits_v3::PredictEditsRequest {
1124 excerpt_path,
1125 excerpt: context.excerpt_text.body,
1126 excerpt_range: context.excerpt.range,
1127 cursor_offset: context.cursor_offset_in_excerpt,
1128 referenced_declarations,
1129 signatures,
1130 excerpt_parent,
1131 events,
1132 can_collect_data,
1133 diagnostic_groups,
1134 diagnostic_groups_truncated,
1135
1136 git_info,
1137 debug_info,
1138 }
1139}
1140
1141fn add_signature(
1142 declaration_id: DeclarationId,
1143 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1144 signatures: &mut Vec<Signature>,
1145 index: &SyntaxIndexState,
1146) -> Option<usize> {
1147 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1148 return Some(*signature_index);
1149 }
1150 let Some(parent_declaration) = index.declaration(declaration_id) else {
1151 log::error!("bug: missing parent declaration");
1152 return None;
1153 };
1154 let parent_index = parent_declaration.parent().and_then(|parent| {
1155 add_signature(parent, declaration_to_signature_index, signatures, index)
1156 });
1157 let (text, text_is_truncated) = parent_declaration.signature_text();
1158 let signature_index = signatures.len();
1159 signatures.push(Signature {
1160 text: text.into(),
1161 text_is_truncated,
1162 parent_index,
1163 range: parent_declaration.signature_range(),
1164 });
1165 declaration_to_signature_index.insert(declaration_id, signature_index);
1166 Some(signature_index)
1167}
1168
1169fn interpolate(
1170 old_snapshot: &BufferSnapshot,
1171 new_snapshot: &BufferSnapshot,
1172 current_edits: Arc<[(Range<Anchor>, String)]>,
1173) -> Option<Vec<(Range<Anchor>, String)>> {
1174 let mut edits = Vec::new();
1175
1176 let mut model_edits = current_edits.iter().peekable();
1177 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1178 while let Some((model_old_range, _)) = model_edits.peek() {
1179 let model_old_range = model_old_range.to_offset(old_snapshot);
1180 if model_old_range.end < user_edit.old.start {
1181 let (model_old_range, model_new_text) = model_edits.next().unwrap();
1182 edits.push((model_old_range.clone(), model_new_text.clone()));
1183 } else {
1184 break;
1185 }
1186 }
1187
1188 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1189 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1190 if user_edit.old == model_old_offset_range {
1191 let user_new_text = new_snapshot
1192 .text_for_range(user_edit.new.clone())
1193 .collect::<String>();
1194
1195 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1196 if !model_suffix.is_empty() {
1197 let anchor = old_snapshot.anchor_after(user_edit.old.end);
1198 edits.push((anchor..anchor, model_suffix.to_string()));
1199 }
1200
1201 model_edits.next();
1202 continue;
1203 }
1204 }
1205 }
1206
1207 return None;
1208 }
1209
1210 edits.extend(model_edits.cloned());
1211
1212 if edits.is_empty() { None } else { Some(edits) }
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Walks a two-edit prediction ("rem" -> "REM", delete "um") through a
    /// sequence of user edits and checks how it is rebased after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let predicted = [(2..5, "REM".to_string()), (9..11, "".to_string())];
            to_prediction_edits(predicted, &buffer, cx).into()
        });

        let preview_task = cx.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx));
        let edit_preview = preview_task.await;

        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());
        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // Deleting the replaced text turns the replacement into a pure
            // insertion and shifts the later deletion left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // Typing a prefix of the predicted text leaves only the suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // Typing the full predicted text consumes that edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            // Backspacing brings the pending suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // Performing the predicted deletion by hand consumes it as well.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` against `buffer`'s current snapshot and
    /// converts the result to offset ranges; `None` whenever
    /// `EditPrediction::interpolate` returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let snapshot = buffer.read(cx).snapshot();
        prediction
            .interpolate(&snapshot)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offsets in `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::new();
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}