1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
11use edit_prediction_context::{
12 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
13 SyntaxIndexState,
14};
15use futures::AsyncReadExt as _;
16use futures::channel::mpsc;
17use gpui::http_client::Method;
18use gpui::{
19 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
20 http_client, prelude::*,
21};
22use language::{
23 Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
24};
25use language::{BufferSnapshot, EditPreview};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use std::cmp;
30use std::collections::{HashMap, VecDeque, hash_map};
31use std::path::PathBuf;
32use std::str::FromStr as _;
33use std::time::{Duration, Instant};
34use std::{ops::Range, sync::Arc};
35use thiserror::Error;
36use util::{ResultExt as _, some_or_debug_panic};
37use uuid::Uuid;
38use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
39
/// Consecutive changes to the same buffer that arrive within this window of the
/// previously recorded change are merged into a single event (see `push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of events kept in a project's history.
const MAX_EVENT_COUNT: usize = 16;
44
45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
46 max_bytes: 512,
47 min_bytes: 128,
48 target_before_cursor_over_total_bytes: 0.5,
49};
50
51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
52 excerpt: DEFAULT_EXCERPT_OPTIONS,
53 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
54 max_diagnostic_bytes: 2048,
55};
56
/// App-global handle to the shared [`Zeta`] entity; installed lazily by
/// [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
61
/// Edit-prediction engine: records per-project buffer change history and sends
/// prediction requests to the `/predict_edits/v3` LLM endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests; refreshed whenever the
    /// global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server has reported that this Zed version is too old.
    update_required: bool,
    /// When installed via `debug_info`, receives debug info (or an error
    /// message) for each prediction request sent to the server.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
72
/// Tunable limits used when gathering context and building requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Sizing of the excerpt gathered around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Forwarded to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
}
79
/// Debug information published on the channel returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request was built
    /// (measured before the request is sent).
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested for.
    pub position: language::Anchor,
}

/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
89
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; never longer than `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// Tracking state for a single registered buffer.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next change is recorded
    /// relative to it.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's edit events and to its release; they live
    /// as long as this registration.
    _subscriptions: [gpui::Subscription; 2],
}
100
/// An entry in a project's change history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Consecutive quick
    /// edits to the same buffer may have been merged into one event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent (possibly merged) change was recorded.
        timestamp: Instant,
    },
}
109
110impl Zeta {
111 pub fn global(
112 client: &Arc<Client>,
113 user_store: &Entity<UserStore>,
114 cx: &mut App,
115 ) -> Entity<Self> {
116 cx.try_global::<ZetaGlobal>()
117 .map(|global| global.0.clone())
118 .unwrap_or_else(|| {
119 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
120 cx.set_global(ZetaGlobal(zeta.clone()));
121 zeta
122 })
123 }
124
125 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
126 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
127
128 Self {
129 projects: HashMap::new(),
130 client,
131 user_store,
132 options: DEFAULT_OPTIONS,
133 llm_token: LlmApiToken::default(),
134 _llm_token_subscription: cx.subscribe(
135 &refresh_llm_token_listener,
136 |this, _listener, _event, cx| {
137 let client = this.client.clone();
138 let llm_token = this.llm_token.clone();
139 cx.spawn(async move |_this, _cx| {
140 llm_token.refresh(&client).await?;
141 anyhow::Ok(())
142 })
143 .detach_and_log_err(cx);
144 },
145 ),
146 update_required: false,
147 debug_tx: None,
148 }
149 }
150
151 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
152 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
153 self.debug_tx = Some(debug_watch_tx);
154 debug_watch_rx
155 }
156
157 pub fn options(&self) -> &ZetaOptions {
158 &self.options
159 }
160
161 pub fn set_options(&mut self, options: ZetaOptions) {
162 self.options = options;
163 }
164
165 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
166 self.user_store.read(cx).edit_prediction_usage()
167 }
168
169 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
170 self.get_or_init_zeta_project(project, cx);
171 }
172
173 pub fn register_buffer(
174 &mut self,
175 buffer: &Entity<Buffer>,
176 project: &Entity<Project>,
177 cx: &mut Context<Self>,
178 ) {
179 let zeta_project = self.get_or_init_zeta_project(project, cx);
180 Self::register_buffer_impl(zeta_project, buffer, project, cx);
181 }
182
183 fn get_or_init_zeta_project(
184 &mut self,
185 project: &Entity<Project>,
186 cx: &mut App,
187 ) -> &mut ZetaProject {
188 self.projects
189 .entry(project.entity_id())
190 .or_insert_with(|| ZetaProject {
191 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
192 events: VecDeque::new(),
193 registered_buffers: HashMap::new(),
194 })
195 }
196
197 fn register_buffer_impl<'a>(
198 zeta_project: &'a mut ZetaProject,
199 buffer: &Entity<Buffer>,
200 project: &Entity<Project>,
201 cx: &mut Context<Self>,
202 ) -> &'a mut RegisteredBuffer {
203 let buffer_id = buffer.entity_id();
204 match zeta_project.registered_buffers.entry(buffer_id) {
205 hash_map::Entry::Occupied(entry) => entry.into_mut(),
206 hash_map::Entry::Vacant(entry) => {
207 let snapshot = buffer.read(cx).snapshot();
208 let project_entity_id = project.entity_id();
209 entry.insert(RegisteredBuffer {
210 snapshot,
211 _subscriptions: [
212 cx.subscribe(buffer, {
213 let project = project.downgrade();
214 move |this, buffer, event, cx| {
215 if let language::BufferEvent::Edited = event
216 && let Some(project) = project.upgrade()
217 {
218 this.report_changes_for_buffer(&buffer, &project, cx);
219 }
220 }
221 }),
222 cx.observe_release(buffer, move |this, _buffer, _cx| {
223 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
224 else {
225 return;
226 };
227 zeta_project.registered_buffers.remove(&buffer_id);
228 }),
229 ],
230 })
231 }
232 }
233 }
234
235 fn report_changes_for_buffer(
236 &mut self,
237 buffer: &Entity<Buffer>,
238 project: &Entity<Project>,
239 cx: &mut Context<Self>,
240 ) -> BufferSnapshot {
241 let zeta_project = self.get_or_init_zeta_project(project, cx);
242 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
243
244 let new_snapshot = buffer.read(cx).snapshot();
245 if new_snapshot.version != registered_buffer.snapshot.version {
246 let old_snapshot =
247 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
248 Self::push_event(
249 zeta_project,
250 Event::BufferChange {
251 old_snapshot,
252 new_snapshot: new_snapshot.clone(),
253 timestamp: Instant::now(),
254 },
255 );
256 }
257
258 new_snapshot
259 }
260
261 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
262 let events = &mut zeta_project.events;
263
264 if let Some(Event::BufferChange {
265 new_snapshot: last_new_snapshot,
266 timestamp: last_timestamp,
267 ..
268 }) = events.back_mut()
269 {
270 // Coalesce edits for the same buffer when they happen one after the other.
271 let Event::BufferChange {
272 old_snapshot,
273 new_snapshot,
274 timestamp,
275 } = &event;
276
277 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
278 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
279 && old_snapshot.version == last_new_snapshot.version
280 {
281 *last_new_snapshot = new_snapshot.clone();
282 *last_timestamp = *timestamp;
283 return;
284 }
285 }
286
287 if events.len() >= MAX_EVENT_COUNT {
288 // These are halved instead of popping to improve prompt caching.
289 events.drain(..MAX_EVENT_COUNT / 2);
290 }
291
292 events.push_back(event);
293 }
294
295 pub fn request_prediction(
296 &mut self,
297 project: &Entity<Project>,
298 buffer: &Entity<Buffer>,
299 position: language::Anchor,
300 cx: &mut Context<Self>,
301 ) -> Task<Result<Option<EditPrediction>>> {
302 let project_state = self.projects.get(&project.entity_id());
303
304 let index_state = project_state.map(|state| {
305 state
306 .syntax_index
307 .read_with(cx, |index, _cx| index.state().clone())
308 });
309 let options = self.options.clone();
310 let snapshot = buffer.read(cx).snapshot();
311 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
312 return Task::ready(Err(anyhow!("No file path for excerpt")));
313 };
314 let client = self.client.clone();
315 let llm_token = self.llm_token.clone();
316 let app_version = AppVersion::global(cx);
317 let worktree_snapshots = project
318 .read(cx)
319 .worktrees(cx)
320 .map(|worktree| worktree.read(cx).snapshot())
321 .collect::<Vec<_>>();
322 let debug_tx = self.debug_tx.clone();
323
324 let events = project_state
325 .map(|state| {
326 state
327 .events
328 .iter()
329 .map(|event| match event {
330 Event::BufferChange {
331 old_snapshot,
332 new_snapshot,
333 ..
334 } => {
335 let path = new_snapshot.file().map(|f| f.path().to_path_buf());
336
337 let old_path = old_snapshot.file().and_then(|f| {
338 let old_path = f.path().as_ref();
339 if Some(old_path) != path.as_deref() {
340 Some(old_path.to_path_buf())
341 } else {
342 None
343 }
344 });
345
346 predict_edits_v3::Event::BufferChange {
347 old_path,
348 path,
349 diff: language::unified_diff(
350 &old_snapshot.text(),
351 &new_snapshot.text(),
352 ),
353 //todo: Actually detect if this edit was predicted or not
354 predicted: false,
355 }
356 }
357 })
358 .collect::<Vec<_>>()
359 })
360 .unwrap_or_default();
361
362 let diagnostics = snapshot.diagnostic_sets().clone();
363
364 let request_task = cx.background_spawn({
365 let snapshot = snapshot.clone();
366 let buffer = buffer.clone();
367 async move {
368 let index_state = if let Some(index_state) = index_state {
369 Some(index_state.lock_owned().await)
370 } else {
371 None
372 };
373
374 let cursor_offset = position.to_offset(&snapshot);
375 let cursor_point = cursor_offset.to_point(&snapshot);
376
377 let before_retrieval = chrono::Utc::now();
378
379 let Some(context) = EditPredictionContext::gather_context(
380 cursor_point,
381 &snapshot,
382 &options.excerpt,
383 index_state.as_deref(),
384 ) else {
385 return Ok(None);
386 };
387
388 let debug_context = if let Some(debug_tx) = debug_tx {
389 Some((debug_tx, context.clone()))
390 } else {
391 None
392 };
393
394 let (diagnostic_groups, diagnostic_groups_truncated) =
395 Self::gather_nearby_diagnostics(
396 cursor_offset,
397 &diagnostics,
398 &snapshot,
399 options.max_diagnostic_bytes,
400 );
401
402 let request = make_cloud_request(
403 excerpt_path.clone(),
404 context,
405 events,
406 // TODO data collection
407 false,
408 diagnostic_groups,
409 diagnostic_groups_truncated,
410 None,
411 debug_context.is_some(),
412 &worktree_snapshots,
413 index_state.as_deref(),
414 Some(options.max_prompt_bytes),
415 );
416
417 let retrieval_time = chrono::Utc::now() - before_retrieval;
418 let response = Self::perform_request(client, llm_token, app_version, request).await;
419
420 if let Some((debug_tx, context)) = debug_context {
421 debug_tx
422 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
423 |response| {
424 let Some(request) =
425 some_or_debug_panic(response.0.debug_info.clone())
426 else {
427 return Err("Missing debug info".to_string());
428 };
429 Ok(PredictionDebugInfo {
430 context,
431 request,
432 retrieval_time,
433 buffer: buffer.downgrade(),
434 position,
435 })
436 },
437 ))
438 .ok();
439 }
440
441 anyhow::Ok(Some(response?))
442 }
443 });
444
445 let buffer = buffer.clone();
446
447 cx.spawn(async move |this, cx| {
448 match request_task.await {
449 Ok(Some((response, usage))) => {
450 log::debug!("predicted edits: {:?}", &response.edits);
451
452 if let Some(usage) = usage {
453 this.update(cx, |this, cx| {
454 this.user_store.update(cx, |user_store, cx| {
455 user_store.update_edit_prediction_usage(usage, cx);
456 });
457 })
458 .ok();
459 }
460
461 // TODO telemetry: duration, etc
462
463 // TODO produce smaller edits by diffing against snapshot first
464 //
465 // Cloud returns entire snippets/excerpts ranges as they were included
466 // in the request, but we should display smaller edits to the user.
467 //
468 // We can do this by computing a diff of each one against the snapshot.
469 // Similar to zeta::Zeta::compute_edits, but per edit.
470 let edits = response
471 .edits
472 .into_iter()
473 .map(|edit| {
474 // TODO edits to different files
475 (
476 snapshot.anchor_before(edit.range.start)
477 ..snapshot.anchor_before(edit.range.end),
478 edit.content,
479 )
480 })
481 .collect::<Vec<_>>()
482 .into();
483
484 let Some((edits, snapshot, edit_preview_task)) =
485 buffer.read_with(cx, |buffer, cx| {
486 let new_snapshot = buffer.snapshot();
487 let edits: Arc<[_]> =
488 interpolate(&snapshot, &new_snapshot, edits)?.into();
489 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
490 })?
491 else {
492 return Ok(None);
493 };
494
495 Ok(Some(EditPrediction {
496 id: EditPredictionId(response.request_id),
497 edits,
498 snapshot,
499 edit_preview: edit_preview_task.await,
500 }))
501 }
502 Ok(None) => Ok(None),
503 Err(err) => {
504 if err.is::<ZedUpdateRequiredError>() {
505 cx.update(|cx| {
506 this.update(cx, |this, _cx| {
507 this.update_required = true;
508 })
509 .ok();
510
511 let error_message: SharedString = err.to_string().into();
512 show_app_notification(
513 NotificationId::unique::<ZedUpdateRequiredError>(),
514 cx,
515 move |cx| {
516 cx.new(|cx| {
517 ErrorMessagePrompt::new(error_message.clone(), cx)
518 .with_link_button(
519 "Update Zed",
520 "https://zed.dev/releases",
521 )
522 })
523 },
524 );
525 })
526 .ok();
527 }
528
529 Err(err)
530 }
531 }
532 })
533 }
534
535 async fn perform_request(
536 client: Arc<Client>,
537 llm_token: LlmApiToken,
538 app_version: SemanticVersion,
539 request: predict_edits_v3::PredictEditsRequest,
540 ) -> Result<(
541 predict_edits_v3::PredictEditsResponse,
542 Option<EditPredictionUsage>,
543 )> {
544 let http_client = client.http_client();
545 let mut token = llm_token.acquire(&client).await?;
546 let mut did_retry = false;
547
548 loop {
549 let request_builder = http_client::Request::builder().method(Method::POST);
550 let request_builder =
551 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
552 request_builder.uri(predict_edits_url)
553 } else {
554 request_builder.uri(
555 http_client
556 .build_zed_llm_url("/predict_edits/v3", &[])?
557 .as_ref(),
558 )
559 };
560 let request = request_builder
561 .header("Content-Type", "application/json")
562 .header("Authorization", format!("Bearer {}", token))
563 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
564 .body(serde_json::to_string(&request)?.into())?;
565
566 let mut response = http_client.send(request).await?;
567
568 if let Some(minimum_required_version) = response
569 .headers()
570 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
571 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
572 {
573 anyhow::ensure!(
574 app_version >= minimum_required_version,
575 ZedUpdateRequiredError {
576 minimum_version: minimum_required_version
577 }
578 );
579 }
580
581 if response.status().is_success() {
582 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
583
584 let mut body = Vec::new();
585 response.body_mut().read_to_end(&mut body).await?;
586 return Ok((serde_json::from_slice(&body)?, usage));
587 } else if !did_retry
588 && response
589 .headers()
590 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
591 .is_some()
592 {
593 did_retry = true;
594 token = llm_token.refresh(&client).await?;
595 } else {
596 let mut body = String::new();
597 response.body_mut().read_to_string(&mut body).await?;
598 anyhow::bail!(
599 "error predicting edits.\nStatus: {:?}\nBody: {}",
600 response.status(),
601 body
602 );
603 }
604 }
605 }
606
607 fn gather_nearby_diagnostics(
608 cursor_offset: usize,
609 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
610 snapshot: &BufferSnapshot,
611 max_diagnostics_bytes: usize,
612 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
613 // TODO: Could make this more efficient
614 let mut diagnostic_groups = Vec::new();
615 for (language_server_id, diagnostics) in diagnostic_sets {
616 let mut groups = Vec::new();
617 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
618 diagnostic_groups.extend(
619 groups
620 .into_iter()
621 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
622 );
623 }
624
625 // sort by proximity to cursor
626 diagnostic_groups.sort_by_key(|group| {
627 let range = &group.entries[group.primary_ix].range;
628 if range.start >= cursor_offset {
629 range.start - cursor_offset
630 } else if cursor_offset >= range.end {
631 cursor_offset - range.end
632 } else {
633 (cursor_offset - range.start).min(range.end - cursor_offset)
634 }
635 });
636
637 let mut results = Vec::new();
638 let mut diagnostic_groups_truncated = false;
639 let mut diagnostics_byte_count = 0;
640 for group in diagnostic_groups {
641 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
642 diagnostics_byte_count += raw_value.get().len();
643 if diagnostics_byte_count > max_diagnostics_bytes {
644 diagnostic_groups_truncated = true;
645 break;
646 }
647 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
648 }
649
650 (results, diagnostic_groups_truncated)
651 }
652
653 // TODO: Dedupe with similar code in request_prediction?
654 pub fn cloud_request_for_zeta_cli(
655 &mut self,
656 project: &Entity<Project>,
657 buffer: &Entity<Buffer>,
658 position: language::Anchor,
659 cx: &mut Context<Self>,
660 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
661 let project_state = self.projects.get(&project.entity_id());
662
663 let index_state = project_state.map(|state| {
664 state
665 .syntax_index
666 .read_with(cx, |index, _cx| index.state().clone())
667 });
668 let options = self.options.clone();
669 let snapshot = buffer.read(cx).snapshot();
670 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
671 return Task::ready(Err(anyhow!("No file path for excerpt")));
672 };
673 let worktree_snapshots = project
674 .read(cx)
675 .worktrees(cx)
676 .map(|worktree| worktree.read(cx).snapshot())
677 .collect::<Vec<_>>();
678
679 cx.background_spawn(async move {
680 let index_state = if let Some(index_state) = index_state {
681 Some(index_state.lock_owned().await)
682 } else {
683 None
684 };
685
686 let cursor_point = position.to_point(&snapshot);
687
688 let debug_info = true;
689 EditPredictionContext::gather_context(
690 cursor_point,
691 &snapshot,
692 &options.excerpt,
693 index_state.as_deref(),
694 )
695 .context("Failed to select excerpt")
696 .map(|context| {
697 make_cloud_request(
698 excerpt_path.clone(),
699 context,
700 // TODO pass everything
701 Vec::new(),
702 false,
703 Vec::new(),
704 false,
705 None,
706 debug_info,
707 &worktree_snapshots,
708 index_state.as_deref(),
709 Some(options.max_prompt_bytes),
710 )
711 })
712 })
713 }
714}
715
/// Returned by a prediction request when the server requires a newer Zed
/// version than the one making the request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server's response header.
    minimum_version: SemanticVersion,
}
723
/// Edit prediction provider backed by the app-global `Zeta` engine.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the editor, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id handed to the next request spawned by `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight requests; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
731
732impl ZetaEditPredictionProvider {
733 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
734
735 pub fn new(
736 project: Option<&Entity<Project>>,
737 client: &Arc<Client>,
738 user_store: &Entity<UserStore>,
739 cx: &mut App,
740 ) -> Self {
741 let zeta = Zeta::global(client, user_store, cx);
742 if let Some(project) = project {
743 zeta.update(cx, |zeta, cx| {
744 zeta.register_project(project, cx);
745 });
746 }
747
748 Self {
749 zeta,
750 current_prediction: None,
751 next_pending_prediction_id: 0,
752 pending_predictions: ArrayVec::new(),
753 last_request_timestamp: Instant::now(),
754 }
755 }
756}
757
758#[derive(Clone)]
759struct CurrentEditPrediction {
760 buffer_id: EntityId,
761 prediction: EditPrediction,
762}
763
764impl CurrentEditPrediction {
765 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
766 if self.buffer_id != old_prediction.buffer_id {
767 return true;
768 }
769
770 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
771 return true;
772 };
773 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
774 return false;
775 };
776
777 if old_edits.len() == 1 && new_edits.len() == 1 {
778 let (old_range, old_text) = &old_edits[0];
779 let (new_range, new_text) = &new_edits[0];
780 new_range == old_range && new_text.starts_with(old_text)
781 } else {
782 true
783 }
784 }
785}
786
787#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
788pub struct EditPredictionId(Uuid);
789
790impl From<EditPredictionId> for gpui::ElementId {
791 fn from(value: EditPredictionId) -> Self {
792 gpui::ElementId::Uuid(value.0)
793 }
794}
795
796impl std::fmt::Display for EditPredictionId {
797 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
798 write!(f, "{}", self.0)
799 }
800}
801
802#[derive(Clone)]
803pub struct EditPrediction {
804 id: EditPredictionId,
805 edits: Arc<[(Range<Anchor>, String)]>,
806 snapshot: BufferSnapshot,
807 edit_preview: EditPreview,
808}
809
810impl EditPrediction {
811 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
812 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
813 }
814}
815
/// A prediction request spawned by `refresh`.
struct PendingPrediction {
    /// Sequence number assigned from `next_pending_prediction_id`.
    id: usize,
    /// Keeps the request task alive. NOTE(review): presumably dropping it
    /// cancels the request (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
820
821impl EditPredictionProvider for ZetaEditPredictionProvider {
822 fn name() -> &'static str {
823 "zed-predict2"
824 }
825
826 fn display_name() -> &'static str {
827 "Zed's Edit Predictions 2"
828 }
829
830 fn show_completions_in_menu() -> bool {
831 true
832 }
833
834 fn show_tab_accept_marker() -> bool {
835 true
836 }
837
838 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
839 // TODO [zeta2]
840 DataCollectionState::Unsupported
841 }
842
843 fn toggle_data_collection(&mut self, _cx: &mut App) {
844 // TODO [zeta2]
845 }
846
847 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
848 self.zeta.read(cx).usage(cx)
849 }
850
851 fn is_enabled(
852 &self,
853 _buffer: &Entity<language::Buffer>,
854 _cursor_position: language::Anchor,
855 _cx: &App,
856 ) -> bool {
857 true
858 }
859
860 fn is_refreshing(&self) -> bool {
861 !self.pending_predictions.is_empty()
862 }
863
864 fn refresh(
865 &mut self,
866 project: Option<Entity<project::Project>>,
867 buffer: Entity<language::Buffer>,
868 cursor_position: language::Anchor,
869 _debounce: bool,
870 cx: &mut Context<Self>,
871 ) {
872 let Some(project) = project else {
873 return;
874 };
875
876 if self
877 .zeta
878 .read(cx)
879 .user_store
880 .read_with(cx, |user_store, _cx| {
881 user_store.account_too_young() || user_store.has_overdue_invoices()
882 })
883 {
884 return;
885 }
886
887 if let Some(current_prediction) = self.current_prediction.as_ref() {
888 let snapshot = buffer.read(cx).snapshot();
889 if current_prediction
890 .prediction
891 .interpolate(&snapshot)
892 .is_some()
893 {
894 return;
895 }
896 }
897
898 let pending_prediction_id = self.next_pending_prediction_id;
899 self.next_pending_prediction_id += 1;
900 let last_request_timestamp = self.last_request_timestamp;
901
902 let task = cx.spawn(async move |this, cx| {
903 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
904 .checked_duration_since(Instant::now())
905 {
906 cx.background_executor().timer(timeout).await;
907 }
908
909 let prediction_request = this.update(cx, |this, cx| {
910 this.last_request_timestamp = Instant::now();
911 this.zeta.update(cx, |zeta, cx| {
912 zeta.request_prediction(&project, &buffer, cursor_position, cx)
913 })
914 });
915
916 let prediction = match prediction_request {
917 Ok(prediction_request) => {
918 let prediction_request = prediction_request.await;
919 prediction_request.map(|c| {
920 c.map(|prediction| CurrentEditPrediction {
921 buffer_id: buffer.entity_id(),
922 prediction,
923 })
924 })
925 }
926 Err(error) => Err(error),
927 };
928
929 this.update(cx, |this, cx| {
930 if this.pending_predictions[0].id == pending_prediction_id {
931 this.pending_predictions.remove(0);
932 } else {
933 this.pending_predictions.clear();
934 }
935
936 let Some(new_prediction) = prediction
937 .context("edit prediction failed")
938 .log_err()
939 .flatten()
940 else {
941 cx.notify();
942 return;
943 };
944
945 if let Some(old_prediction) = this.current_prediction.as_ref() {
946 let snapshot = buffer.read(cx).snapshot();
947 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
948 this.current_prediction = Some(new_prediction);
949 }
950 } else {
951 this.current_prediction = Some(new_prediction);
952 }
953
954 cx.notify();
955 })
956 .ok();
957 });
958
959 // We always maintain at most two pending predictions. When we already
960 // have two, we replace the newest one.
961 if self.pending_predictions.len() <= 1 {
962 self.pending_predictions.push(PendingPrediction {
963 id: pending_prediction_id,
964 _task: task,
965 });
966 } else if self.pending_predictions.len() == 2 {
967 self.pending_predictions.pop();
968 self.pending_predictions.push(PendingPrediction {
969 id: pending_prediction_id,
970 _task: task,
971 });
972 }
973
974 cx.notify();
975 }
976
977 fn cycle(
978 &mut self,
979 _buffer: Entity<language::Buffer>,
980 _cursor_position: language::Anchor,
981 _direction: Direction,
982 _cx: &mut Context<Self>,
983 ) {
984 }
985
986 fn accept(&mut self, _cx: &mut Context<Self>) {
987 // TODO [zeta2] report accept
988 self.current_prediction.take();
989 self.pending_predictions.clear();
990 }
991
992 fn discard(&mut self, _cx: &mut Context<Self>) {
993 self.pending_predictions.clear();
994 self.current_prediction.take();
995 }
996
997 fn suggest(
998 &mut self,
999 buffer: &Entity<language::Buffer>,
1000 cursor_position: language::Anchor,
1001 cx: &mut Context<Self>,
1002 ) -> Option<edit_prediction::EditPrediction> {
1003 let CurrentEditPrediction {
1004 buffer_id,
1005 prediction,
1006 ..
1007 } = self.current_prediction.as_mut()?;
1008
1009 // Invalidate previous prediction if it was generated for a different buffer.
1010 if *buffer_id != buffer.entity_id() {
1011 self.current_prediction.take();
1012 return None;
1013 }
1014
1015 let buffer = buffer.read(cx);
1016 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1017 self.current_prediction.take();
1018 return None;
1019 };
1020
1021 let cursor_row = cursor_position.to_point(buffer).row;
1022 let (closest_edit_ix, (closest_edit_range, _)) =
1023 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1024 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1025 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1026 cmp::min(distance_from_start, distance_from_end)
1027 })?;
1028
1029 let mut edit_start_ix = closest_edit_ix;
1030 for (range, _) in edits[..edit_start_ix].iter().rev() {
1031 let distance_from_closest_edit =
1032 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1033 if distance_from_closest_edit <= 1 {
1034 edit_start_ix -= 1;
1035 } else {
1036 break;
1037 }
1038 }
1039
1040 let mut edit_end_ix = closest_edit_ix + 1;
1041 for (range, _) in &edits[edit_end_ix..] {
1042 let distance_from_closest_edit =
1043 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1044 if distance_from_closest_edit <= 1 {
1045 edit_end_ix += 1;
1046 } else {
1047 break;
1048 }
1049 }
1050
1051 Some(edit_prediction::EditPrediction {
1052 id: Some(prediction.id.to_string().into()),
1053 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1054 edit_preview: Some(prediction.edit_preview.clone()),
1055 })
1056 }
1057}
1058
1059fn make_cloud_request(
1060 excerpt_path: PathBuf,
1061 context: EditPredictionContext,
1062 events: Vec<predict_edits_v3::Event>,
1063 can_collect_data: bool,
1064 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1065 diagnostic_groups_truncated: bool,
1066 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1067 debug_info: bool,
1068 worktrees: &Vec<worktree::Snapshot>,
1069 index_state: Option<&SyntaxIndexState>,
1070 prompt_max_bytes: Option<usize>,
1071) -> predict_edits_v3::PredictEditsRequest {
1072 let mut signatures = Vec::new();
1073 let mut declaration_to_signature_index = HashMap::default();
1074 let mut referenced_declarations = Vec::new();
1075
1076 for snippet in context.snippets {
1077 let project_entry_id = snippet.declaration.project_entry_id();
1078 let Some(path) = worktrees.iter().find_map(|worktree| {
1079 worktree.entry_for_id(project_entry_id).map(|entry| {
1080 let mut full_path = PathBuf::new();
1081 full_path.push(worktree.root_name());
1082 full_path.push(&entry.path);
1083 full_path
1084 })
1085 }) else {
1086 continue;
1087 };
1088
1089 let parent_index = index_state.and_then(|index_state| {
1090 snippet.declaration.parent().and_then(|parent| {
1091 add_signature(
1092 parent,
1093 &mut declaration_to_signature_index,
1094 &mut signatures,
1095 index_state,
1096 )
1097 })
1098 });
1099
1100 let (text, text_is_truncated) = snippet.declaration.item_text();
1101 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1102 path,
1103 text: text.into(),
1104 range: snippet.declaration.item_range(),
1105 text_is_truncated,
1106 signature_range: snippet.declaration.signature_range_in_item_text(),
1107 parent_index,
1108 score_components: snippet.score_components,
1109 signature_score: snippet.scores.signature,
1110 declaration_score: snippet.scores.declaration,
1111 });
1112 }
1113
1114 let excerpt_parent = index_state.and_then(|index_state| {
1115 context
1116 .excerpt
1117 .parent_declarations
1118 .last()
1119 .and_then(|(parent, _)| {
1120 add_signature(
1121 *parent,
1122 &mut declaration_to_signature_index,
1123 &mut signatures,
1124 index_state,
1125 )
1126 })
1127 });
1128
1129 predict_edits_v3::PredictEditsRequest {
1130 excerpt_path,
1131 excerpt: context.excerpt_text.body,
1132 excerpt_range: context.excerpt.range,
1133 cursor_offset: context.cursor_offset_in_excerpt,
1134 referenced_declarations,
1135 signatures,
1136 excerpt_parent,
1137 events,
1138 can_collect_data,
1139 diagnostic_groups,
1140 diagnostic_groups_truncated,
1141 git_info,
1142 debug_info,
1143 prompt_max_bytes,
1144 }
1145}
1146
1147fn add_signature(
1148 declaration_id: DeclarationId,
1149 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1150 signatures: &mut Vec<Signature>,
1151 index: &SyntaxIndexState,
1152) -> Option<usize> {
1153 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1154 return Some(*signature_index);
1155 }
1156 let Some(parent_declaration) = index.declaration(declaration_id) else {
1157 log::error!("bug: missing parent declaration");
1158 return None;
1159 };
1160 let parent_index = parent_declaration.parent().and_then(|parent| {
1161 add_signature(parent, declaration_to_signature_index, signatures, index)
1162 });
1163 let (text, text_is_truncated) = parent_declaration.signature_text();
1164 let signature_index = signatures.len();
1165 signatures.push(Signature {
1166 text: text.into(),
1167 text_is_truncated,
1168 parent_index,
1169 range: parent_declaration.signature_range(),
1170 });
1171 declaration_to_signature_index.insert(declaration_id, signature_index);
1172 Some(signature_index)
1173}
1174
1175fn interpolate(
1176 old_snapshot: &BufferSnapshot,
1177 new_snapshot: &BufferSnapshot,
1178 current_edits: Arc<[(Range<Anchor>, String)]>,
1179) -> Option<Vec<(Range<Anchor>, String)>> {
1180 let mut edits = Vec::new();
1181
1182 let mut model_edits = current_edits.iter().peekable();
1183 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1184 while let Some((model_old_range, _)) = model_edits.peek() {
1185 let model_old_range = model_old_range.to_offset(old_snapshot);
1186 if model_old_range.end < user_edit.old.start {
1187 let (model_old_range, model_new_text) = model_edits.next().unwrap();
1188 edits.push((model_old_range.clone(), model_new_text.clone()));
1189 } else {
1190 break;
1191 }
1192 }
1193
1194 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1195 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1196 if user_edit.old == model_old_offset_range {
1197 let user_new_text = new_snapshot
1198 .text_for_range(user_edit.new.clone())
1199 .collect::<String>();
1200
1201 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1202 if !model_suffix.is_empty() {
1203 let anchor = old_snapshot.anchor_after(user_edit.old.end);
1204 edits.push((anchor..anchor, model_suffix.to_string()));
1205 }
1206
1207 model_edits.next();
1208 continue;
1209 }
1210 }
1211 }
1212
1213 return None;
1214 }
1215
1216 edits.extend(model_edits.cloned());
1217
1218 if edits.is_empty() { None } else { Some(edits) }
1219}
1220
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));

        // The model proposes rewriting "rem" -> "REM" and deleting "um".
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let anchored = to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            );
            anchored.into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // Nothing typed yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the text the model rewrites leaves the whole replacement as an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing the prediction one character at a time shrinks the pending insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Once the whole replacement is typed, only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Backspacing brings the remaining suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand drops it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and converts the result
    /// to offset ranges. Panics if the prediction was invalidated.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`. The start is anchored
    /// after its offset and the end before its offset.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buf.anchor_after(range.start);
            let end = buf.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offsets in the current state of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buf = buffer.read(cx);
        let mut resolved = Vec::new();
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buf);
            let end = range.end.to_offset(buf);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}