1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use chrono::TimeDelta;
4use client::{Client, EditPredictionUsage, UserStore};
5use cloud_llm_client::predict_edits_v3::{self, Signature};
6use cloud_llm_client::{
7 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
11use edit_prediction_context::{
12 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
13 SyntaxIndexState,
14};
15use futures::AsyncReadExt as _;
16use futures::channel::mpsc;
17use gpui::http_client::Method;
18use gpui::{
19 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
20 http_client, prelude::*,
21};
22use language::{
23 Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
24 text_diff,
25};
26use language::{BufferSnapshot, EditPreview};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use project::Project;
29use release_channel::AppVersion;
30use std::borrow::Cow;
31use std::cmp;
32use std::collections::{HashMap, VecDeque, hash_map};
33use std::path::PathBuf;
34use std::str::FromStr as _;
35use std::time::{Duration, Instant};
36use std::{ops::Range, sync::Arc};
37use thiserror::Error;
38use util::{ResultExt as _, some_or_debug_panic};
39use uuid::Uuid;
40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
41
/// Maximum time between two consecutive edits of the same buffer for them to be
/// coalesced into a single [`Event::BufferChange`] (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// Once a project's history reaches this size, the oldest half of the events is
/// dropped before a new one is appended.
const MAX_EVENT_COUNT: usize = 16;
46
/// Default options used to select the excerpt around the cursor.
///
/// Excerpts are bounded to between `min_bytes` and `max_bytes`; judging by the
/// field name, `target_before_cursor_over_total_bytes` of 0.5 aims to place about
/// half of the excerpt before the cursor (TODO confirm against
/// `EditPredictionExcerptOptions`).
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default [`ZetaOptions`], installed by [`Zeta::new`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    // Budget for serialized diagnostic groups included in a request.
    max_diagnostic_bytes: 2048,
};
58
/// App-global handle to the shared [`Zeta`] entity; installed lazily by
/// [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
63
/// Edit-prediction engine: tracks per-project buffer history and sends
/// prediction requests to the cloud `predict_edits/v3` endpoint.
pub struct Zeta {
    /// Client used for HTTP requests and for acquiring/refreshing the LLM token.
    client: Arc<Client>,
    /// Source of edit-prediction usage information; updated from response headers.
    user_store: Entity<UserStore>,
    /// Bearer token sent with prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription (which refreshes
    /// `llm_token`) alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied when building prediction requests.
    options: ZetaOptions,
    /// Set once the server reports that a newer Zed version is required.
    update_required: bool,
    /// When set via [`Zeta::debug_info`], each prediction request sends its debug
    /// info (or an error message) on this channel.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
74
/// Tunable limits used when building a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How the excerpt around the cursor is selected.
    pub excerpt: EditPredictionExcerptOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups; groups beyond it are dropped
    /// and the request is marked as truncated.
    pub max_diagnostic_bytes: usize,
}
81
/// Debug information emitted on the [`Zeta::debug_info`] channel for each
/// successful prediction request.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time spent from just before context gathering until the request was built
    /// (excludes the network round-trip).
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-side debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
91
/// Prediction state kept for a single project.
struct ZetaProject {
    /// Syntax index used to retrieve declarations referenced near the cursor.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by [`MAX_EVENT_COUNT`].
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A tracked buffer together with the subscriptions that keep it up to date.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; diffed against the next one.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
102
/// A change recorded in a project's event history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive edits
    /// may be coalesced, in which case `new_snapshot` and `timestamp` reflect the
    /// latest merged edit.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
111
112impl Zeta {
113 pub fn global(
114 client: &Arc<Client>,
115 user_store: &Entity<UserStore>,
116 cx: &mut App,
117 ) -> Entity<Self> {
118 cx.try_global::<ZetaGlobal>()
119 .map(|global| global.0.clone())
120 .unwrap_or_else(|| {
121 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
122 cx.set_global(ZetaGlobal(zeta.clone()));
123 zeta
124 })
125 }
126
127 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
128 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
129
130 Self {
131 projects: HashMap::new(),
132 client,
133 user_store,
134 options: DEFAULT_OPTIONS,
135 llm_token: LlmApiToken::default(),
136 _llm_token_subscription: cx.subscribe(
137 &refresh_llm_token_listener,
138 |this, _listener, _event, cx| {
139 let client = this.client.clone();
140 let llm_token = this.llm_token.clone();
141 cx.spawn(async move |_this, _cx| {
142 llm_token.refresh(&client).await?;
143 anyhow::Ok(())
144 })
145 .detach_and_log_err(cx);
146 },
147 ),
148 update_required: false,
149 debug_tx: None,
150 }
151 }
152
153 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
154 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
155 self.debug_tx = Some(debug_watch_tx);
156 debug_watch_rx
157 }
158
159 pub fn options(&self) -> &ZetaOptions {
160 &self.options
161 }
162
163 pub fn set_options(&mut self, options: ZetaOptions) {
164 self.options = options;
165 }
166
167 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
168 self.user_store.read(cx).edit_prediction_usage()
169 }
170
171 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
172 self.get_or_init_zeta_project(project, cx);
173 }
174
175 pub fn register_buffer(
176 &mut self,
177 buffer: &Entity<Buffer>,
178 project: &Entity<Project>,
179 cx: &mut Context<Self>,
180 ) {
181 let zeta_project = self.get_or_init_zeta_project(project, cx);
182 Self::register_buffer_impl(zeta_project, buffer, project, cx);
183 }
184
185 fn get_or_init_zeta_project(
186 &mut self,
187 project: &Entity<Project>,
188 cx: &mut App,
189 ) -> &mut ZetaProject {
190 self.projects
191 .entry(project.entity_id())
192 .or_insert_with(|| ZetaProject {
193 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
194 events: VecDeque::new(),
195 registered_buffers: HashMap::new(),
196 })
197 }
198
    /// Returns the [`RegisteredBuffer`] for `buffer` in `zeta_project`, registering
    /// it first if it has not been seen before.
    ///
    /// On first registration this records the buffer's current snapshot and
    /// installs two subscriptions:
    /// - on `BufferEvent::Edited`, the change is reported through
    ///   `report_changes_for_buffer` (only while the project can still be upgraded);
    /// - on release of the buffer entity, its registration is removed from the
    ///   owning project's state.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture a weak handle so the subscription does not hold
                            // the project strongly.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the buffer once its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
236
237 fn report_changes_for_buffer(
238 &mut self,
239 buffer: &Entity<Buffer>,
240 project: &Entity<Project>,
241 cx: &mut Context<Self>,
242 ) -> BufferSnapshot {
243 let zeta_project = self.get_or_init_zeta_project(project, cx);
244 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
245
246 let new_snapshot = buffer.read(cx).snapshot();
247 if new_snapshot.version != registered_buffer.snapshot.version {
248 let old_snapshot =
249 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
250 Self::push_event(
251 zeta_project,
252 Event::BufferChange {
253 old_snapshot,
254 new_snapshot: new_snapshot.clone(),
255 timestamp: Instant::now(),
256 },
257 );
258 }
259
260 new_snapshot
261 }
262
263 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
264 let events = &mut zeta_project.events;
265
266 if let Some(Event::BufferChange {
267 new_snapshot: last_new_snapshot,
268 timestamp: last_timestamp,
269 ..
270 }) = events.back_mut()
271 {
272 // Coalesce edits for the same buffer when they happen one after the other.
273 let Event::BufferChange {
274 old_snapshot,
275 new_snapshot,
276 timestamp,
277 } = &event;
278
279 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
280 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
281 && old_snapshot.version == last_new_snapshot.version
282 {
283 *last_new_snapshot = new_snapshot.clone();
284 *last_timestamp = *timestamp;
285 return;
286 }
287 }
288
289 if events.len() >= MAX_EVENT_COUNT {
290 // These are halved instead of popping to improve prompt caching.
291 events.drain(..MAX_EVENT_COUNT / 2);
292 }
293
294 events.push_back(event);
295 }
296
297 pub fn request_prediction(
298 &mut self,
299 project: &Entity<Project>,
300 buffer: &Entity<Buffer>,
301 position: language::Anchor,
302 cx: &mut Context<Self>,
303 ) -> Task<Result<Option<EditPrediction>>> {
304 let project_state = self.projects.get(&project.entity_id());
305
306 let index_state = project_state.map(|state| {
307 state
308 .syntax_index
309 .read_with(cx, |index, _cx| index.state().clone())
310 });
311 let options = self.options.clone();
312 let snapshot = buffer.read(cx).snapshot();
313 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
314 return Task::ready(Err(anyhow!("No file path for excerpt")));
315 };
316 let client = self.client.clone();
317 let llm_token = self.llm_token.clone();
318 let app_version = AppVersion::global(cx);
319 let worktree_snapshots = project
320 .read(cx)
321 .worktrees(cx)
322 .map(|worktree| worktree.read(cx).snapshot())
323 .collect::<Vec<_>>();
324 let debug_tx = self.debug_tx.clone();
325
326 let events = project_state
327 .map(|state| {
328 state
329 .events
330 .iter()
331 .map(|event| match event {
332 Event::BufferChange {
333 old_snapshot,
334 new_snapshot,
335 ..
336 } => {
337 let path = new_snapshot.file().map(|f| f.path().to_path_buf());
338
339 let old_path = old_snapshot.file().and_then(|f| {
340 let old_path = f.path().as_ref();
341 if Some(old_path) != path.as_deref() {
342 Some(old_path.to_path_buf())
343 } else {
344 None
345 }
346 });
347
348 predict_edits_v3::Event::BufferChange {
349 old_path,
350 path,
351 diff: language::unified_diff(
352 &old_snapshot.text(),
353 &new_snapshot.text(),
354 ),
355 //todo: Actually detect if this edit was predicted or not
356 predicted: false,
357 }
358 }
359 })
360 .collect::<Vec<_>>()
361 })
362 .unwrap_or_default();
363
364 let diagnostics = snapshot.diagnostic_sets().clone();
365
366 let request_task = cx.background_spawn({
367 let snapshot = snapshot.clone();
368 let buffer = buffer.clone();
369 async move {
370 let index_state = if let Some(index_state) = index_state {
371 Some(index_state.lock_owned().await)
372 } else {
373 None
374 };
375
376 let cursor_offset = position.to_offset(&snapshot);
377 let cursor_point = cursor_offset.to_point(&snapshot);
378
379 let before_retrieval = chrono::Utc::now();
380
381 let Some(context) = EditPredictionContext::gather_context(
382 cursor_point,
383 &snapshot,
384 &options.excerpt,
385 index_state.as_deref(),
386 ) else {
387 return Ok(None);
388 };
389
390 let debug_context = if let Some(debug_tx) = debug_tx {
391 Some((debug_tx, context.clone()))
392 } else {
393 None
394 };
395
396 let (diagnostic_groups, diagnostic_groups_truncated) =
397 Self::gather_nearby_diagnostics(
398 cursor_offset,
399 &diagnostics,
400 &snapshot,
401 options.max_diagnostic_bytes,
402 );
403
404 let request = make_cloud_request(
405 excerpt_path.clone(),
406 context,
407 events,
408 // TODO data collection
409 false,
410 diagnostic_groups,
411 diagnostic_groups_truncated,
412 None,
413 debug_context.is_some(),
414 &worktree_snapshots,
415 index_state.as_deref(),
416 Some(options.max_prompt_bytes),
417 );
418
419 let retrieval_time = chrono::Utc::now() - before_retrieval;
420 let response = Self::perform_request(client, llm_token, app_version, request).await;
421
422 if let Some((debug_tx, context)) = debug_context {
423 debug_tx
424 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
425 |response| {
426 let Some(request) =
427 some_or_debug_panic(response.0.debug_info.clone())
428 else {
429 return Err("Missing debug info".to_string());
430 };
431 Ok(PredictionDebugInfo {
432 context,
433 request,
434 retrieval_time,
435 buffer: buffer.downgrade(),
436 position,
437 })
438 },
439 ))
440 .ok();
441 }
442
443 let (response, usage) = response?;
444 let edits = Self::compute_edits(&response.edits, &snapshot);
445
446 anyhow::Ok(Some((response.request_id, edits, usage)))
447 }
448 });
449
450 let buffer = buffer.clone();
451
452 cx.spawn(async move |this, cx| {
453 match request_task.await {
454 Ok(Some((id, edits, usage))) => {
455 if let Some(usage) = usage {
456 this.update(cx, |this, cx| {
457 this.user_store.update(cx, |user_store, cx| {
458 user_store.update_edit_prediction_usage(usage, cx);
459 });
460 })
461 .ok();
462 }
463
464 // TODO telemetry: duration, etc
465 let Some((edits, snapshot, edit_preview_task)) =
466 buffer.read_with(cx, |buffer, cx| {
467 let new_snapshot = buffer.snapshot();
468 let edits: Arc<[_]> =
469 interpolate(&snapshot, &new_snapshot, edits)?.into();
470 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
471 })?
472 else {
473 return Ok(None);
474 };
475
476 Ok(Some(EditPrediction {
477 id: EditPredictionId(id),
478 edits,
479 snapshot,
480 edit_preview: edit_preview_task.await,
481 }))
482 }
483 Ok(None) => Ok(None),
484 Err(err) => {
485 if err.is::<ZedUpdateRequiredError>() {
486 cx.update(|cx| {
487 this.update(cx, |this, _cx| {
488 this.update_required = true;
489 })
490 .ok();
491
492 let error_message: SharedString = err.to_string().into();
493 show_app_notification(
494 NotificationId::unique::<ZedUpdateRequiredError>(),
495 cx,
496 move |cx| {
497 cx.new(|cx| {
498 ErrorMessagePrompt::new(error_message.clone(), cx)
499 .with_link_button(
500 "Update Zed",
501 "https://zed.dev/releases",
502 )
503 })
504 },
505 );
506 })
507 .ok();
508 }
509
510 Err(err)
511 }
512 }
513 })
514 }
515
516 async fn perform_request(
517 client: Arc<Client>,
518 llm_token: LlmApiToken,
519 app_version: SemanticVersion,
520 request: predict_edits_v3::PredictEditsRequest,
521 ) -> Result<(
522 predict_edits_v3::PredictEditsResponse,
523 Option<EditPredictionUsage>,
524 )> {
525 let http_client = client.http_client();
526 let mut token = llm_token.acquire(&client).await?;
527 let mut did_retry = false;
528
529 loop {
530 let request_builder = http_client::Request::builder().method(Method::POST);
531 let request_builder =
532 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
533 request_builder.uri(predict_edits_url)
534 } else {
535 request_builder.uri(
536 http_client
537 .build_zed_llm_url("/predict_edits/v3", &[])?
538 .as_ref(),
539 )
540 };
541 let request = request_builder
542 .header("Content-Type", "application/json")
543 .header("Authorization", format!("Bearer {}", token))
544 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
545 .body(serde_json::to_string(&request)?.into())?;
546
547 let mut response = http_client.send(request).await?;
548
549 if let Some(minimum_required_version) = response
550 .headers()
551 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
552 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
553 {
554 anyhow::ensure!(
555 app_version >= minimum_required_version,
556 ZedUpdateRequiredError {
557 minimum_version: minimum_required_version
558 }
559 );
560 }
561
562 if response.status().is_success() {
563 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
564
565 let mut body = Vec::new();
566 response.body_mut().read_to_end(&mut body).await?;
567 return Ok((serde_json::from_slice(&body)?, usage));
568 } else if !did_retry
569 && response
570 .headers()
571 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
572 .is_some()
573 {
574 did_retry = true;
575 token = llm_token.refresh(&client).await?;
576 } else {
577 let mut body = String::new();
578 response.body_mut().read_to_string(&mut body).await?;
579 anyhow::bail!(
580 "error predicting edits.\nStatus: {:?}\nBody: {}",
581 response.status(),
582 body
583 );
584 }
585 }
586 }
587
588 fn compute_edits(
589 edits: &[predict_edits_v3::Edit],
590 snapshot: &BufferSnapshot,
591 ) -> Arc<[(Range<Anchor>, String)]> {
592 edits
593 .iter()
594 .flat_map(|edit| {
595 // TODO multi-file edits
596 let old_text = snapshot.text_for_range(edit.range.clone());
597
598 Self::compute_excerpt_edits(
599 old_text.collect::<Cow<str>>(),
600 &edit.content,
601 edit.range.start,
602 &snapshot,
603 )
604 })
605 .collect::<Vec<_>>()
606 .into()
607 }
608
    /// Diffs `old_text` (the buffer text starting at byte `offset`) against
    /// `new_text` and yields one anchored edit per changed hunk.
    ///
    /// Each hunk's range is further tightened by trimming any common prefix and
    /// suffix between the buffer text in that range and the hunk's replacement.
    fn compute_excerpt_edits(
        old_text: Cow<str>,
        new_text: &str,
        offset: usize,
        snapshot: &BufferSnapshot,
    ) -> impl Iterator<Item = (Range<Anchor>, String)> {
        text_diff(&old_text, new_text)
            .into_iter()
            .map(move |(mut old_range, new_text)| {
                // `text_diff` ranges are relative to `old_text`; shift them into
                // buffer offsets.
                old_range.start += offset;
                old_range.end += offset;

                // Byte length of the shared leading characters.
                let prefix_len = common_prefix(
                    snapshot.chars_for_range(old_range.clone()),
                    new_text.chars(),
                );
                old_range.start += prefix_len;

                // Byte length of the shared trailing characters, computed on what
                // remains after the prefix so the two trims cannot overlap.
                let suffix_len = common_prefix(
                    snapshot.reversed_chars_for_range(old_range.clone()),
                    new_text[prefix_len..].chars().rev(),
                );
                old_range.end = old_range.end.saturating_sub(suffix_len);

                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
                // Pure insertions get a single anchor. Otherwise the start is
                // anchored after and the end before, presumably so text inserted at
                // either boundary is not absorbed into the edit (TODO confirm
                // against `Anchor` bias semantics).
                let range = if old_range.is_empty() {
                    let anchor = snapshot.anchor_after(old_range.start);
                    anchor..anchor
                } else {
                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
                };
                (range, new_text)
            })
    }
643
644 fn gather_nearby_diagnostics(
645 cursor_offset: usize,
646 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
647 snapshot: &BufferSnapshot,
648 max_diagnostics_bytes: usize,
649 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
650 // TODO: Could make this more efficient
651 let mut diagnostic_groups = Vec::new();
652 for (language_server_id, diagnostics) in diagnostic_sets {
653 let mut groups = Vec::new();
654 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
655 diagnostic_groups.extend(
656 groups
657 .into_iter()
658 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
659 );
660 }
661
662 // sort by proximity to cursor
663 diagnostic_groups.sort_by_key(|group| {
664 let range = &group.entries[group.primary_ix].range;
665 if range.start >= cursor_offset {
666 range.start - cursor_offset
667 } else if cursor_offset >= range.end {
668 cursor_offset - range.end
669 } else {
670 (cursor_offset - range.start).min(range.end - cursor_offset)
671 }
672 });
673
674 let mut results = Vec::new();
675 let mut diagnostic_groups_truncated = false;
676 let mut diagnostics_byte_count = 0;
677 for group in diagnostic_groups {
678 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
679 diagnostics_byte_count += raw_value.get().len();
680 if diagnostics_byte_count > max_diagnostics_bytes {
681 diagnostic_groups_truncated = true;
682 break;
683 }
684 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
685 }
686
687 (results, diagnostic_groups_truncated)
688 }
689
690 // TODO: Dedupe with similar code in request_prediction?
691 pub fn cloud_request_for_zeta_cli(
692 &mut self,
693 project: &Entity<Project>,
694 buffer: &Entity<Buffer>,
695 position: language::Anchor,
696 cx: &mut Context<Self>,
697 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
698 let project_state = self.projects.get(&project.entity_id());
699
700 let index_state = project_state.map(|state| {
701 state
702 .syntax_index
703 .read_with(cx, |index, _cx| index.state().clone())
704 });
705 let options = self.options.clone();
706 let snapshot = buffer.read(cx).snapshot();
707 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
708 return Task::ready(Err(anyhow!("No file path for excerpt")));
709 };
710 let worktree_snapshots = project
711 .read(cx)
712 .worktrees(cx)
713 .map(|worktree| worktree.read(cx).snapshot())
714 .collect::<Vec<_>>();
715
716 cx.background_spawn(async move {
717 let index_state = if let Some(index_state) = index_state {
718 Some(index_state.lock_owned().await)
719 } else {
720 None
721 };
722
723 let cursor_point = position.to_point(&snapshot);
724
725 let debug_info = true;
726 EditPredictionContext::gather_context(
727 cursor_point,
728 &snapshot,
729 &options.excerpt,
730 index_state.as_deref(),
731 )
732 .context("Failed to select excerpt")
733 .map(|context| {
734 make_cloud_request(
735 excerpt_path.clone(),
736 context,
737 // TODO pass everything
738 Vec::new(),
739 false,
740 Vec::new(),
741 false,
742 None,
743 debug_info,
744 &worktree_snapshots,
745 index_state.as_deref(),
746 Some(options.max_prompt_bytes),
747 )
748 })
749 })
750 }
751}
752
/// Returns the length in bytes (UTF-8) of the longest common prefix of two
/// character sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
759
/// Returned when the server's minimum-required-version header names a version
/// newer than the running Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: SemanticVersion,
}
767
/// [`EditPredictionProvider`] backed by the global [`Zeta`] engine.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next request started by `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight requests, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the most recent request was started; used for throttling.
    last_request_timestamp: Instant,
}
775
776impl ZetaEditPredictionProvider {
777 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
778
779 pub fn new(
780 project: Option<&Entity<Project>>,
781 client: &Arc<Client>,
782 user_store: &Entity<UserStore>,
783 cx: &mut App,
784 ) -> Self {
785 let zeta = Zeta::global(client, user_store, cx);
786 if let Some(project) = project {
787 zeta.update(cx, |zeta, cx| {
788 zeta.register_project(project, cx);
789 });
790 }
791
792 Self {
793 zeta,
794 current_prediction: None,
795 next_pending_prediction_id: 0,
796 pending_predictions: ArrayVec::new(),
797 last_request_timestamp: Instant::now(),
798 }
799 }
800}
801
/// A prediction currently offered to the user, with the buffer it belongs to.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was computed for.
    buffer_id: EntityId,
    prediction: EditPrediction,
}
807
808impl CurrentEditPrediction {
809 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
810 if self.buffer_id != old_prediction.buffer_id {
811 return true;
812 }
813
814 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
815 return true;
816 };
817 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
818 return false;
819 };
820
821 if old_edits.len() == 1 && new_edits.len() == 1 {
822 let (old_range, old_text) = &old_edits[0];
823 let (new_range, new_text) = &new_edits[0];
824 new_range == old_range && new_text.starts_with(old_text)
825 } else {
826 true
827 }
828 }
829}
830
831#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
832pub struct EditPredictionId(Uuid);
833
834impl From<EditPredictionId> for gpui::ElementId {
835 fn from(value: EditPredictionId) -> Self {
836 gpui::ElementId::Uuid(value.0)
837 }
838}
839
840impl std::fmt::Display for EditPredictionId {
841 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
842 write!(f, "{}", self.0)
843 }
844}
845
/// A completed prediction: anchored edits plus a rendered preview.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// Edits anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto when the prediction was built.
    snapshot: BufferSnapshot,
    /// Preview of the edits applied to the buffer.
    edit_preview: EditPreview,
}
853
854impl EditPrediction {
855 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
856 interpolate(&self.snapshot, new_snapshot, self.edits.clone())
857 }
858}
859
/// An in-flight prediction request started by `refresh`.
struct PendingPrediction {
    /// Value of `next_pending_prediction_id` when the request was started.
    id: usize,
    /// Handle to the request task; dropping it drops the task.
    _task: Task<()>,
}
864
865impl EditPredictionProvider for ZetaEditPredictionProvider {
    /// Stable identifier of this provider.
    fn name() -> &'static str {
        "zed-predict2"
    }

    /// Human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions 2"
    }

    /// Predictions are shown in the completions menu.
    fn show_completions_in_menu() -> bool {
        true
    }

    /// A tab-to-accept marker is shown.
    fn show_tab_accept_marker() -> bool {
        true
    }
881
    /// Data collection is not implemented for this provider yet.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        // TODO [zeta2]
        DataCollectionState::Unsupported
    }

    /// No-op until data collection is supported.
    fn toggle_data_collection(&mut self, _cx: &mut App) {
        // TODO [zeta2]
    }

    /// Delegates to [`Zeta::usage`].
    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
894
    /// Always enabled, regardless of buffer or cursor position.
    fn is_enabled(
        &self,
        _buffer: &Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }

    /// True while at least one prediction request is still tracked as pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_predictions.is_empty()
    }
907
908 fn refresh(
909 &mut self,
910 project: Option<Entity<project::Project>>,
911 buffer: Entity<language::Buffer>,
912 cursor_position: language::Anchor,
913 _debounce: bool,
914 cx: &mut Context<Self>,
915 ) {
916 let Some(project) = project else {
917 return;
918 };
919
920 if self
921 .zeta
922 .read(cx)
923 .user_store
924 .read_with(cx, |user_store, _cx| {
925 user_store.account_too_young() || user_store.has_overdue_invoices()
926 })
927 {
928 return;
929 }
930
931 if let Some(current_prediction) = self.current_prediction.as_ref() {
932 let snapshot = buffer.read(cx).snapshot();
933 if current_prediction
934 .prediction
935 .interpolate(&snapshot)
936 .is_some()
937 {
938 return;
939 }
940 }
941
942 let pending_prediction_id = self.next_pending_prediction_id;
943 self.next_pending_prediction_id += 1;
944 let last_request_timestamp = self.last_request_timestamp;
945
946 let task = cx.spawn(async move |this, cx| {
947 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
948 .checked_duration_since(Instant::now())
949 {
950 cx.background_executor().timer(timeout).await;
951 }
952
953 let prediction_request = this.update(cx, |this, cx| {
954 this.last_request_timestamp = Instant::now();
955 this.zeta.update(cx, |zeta, cx| {
956 zeta.request_prediction(&project, &buffer, cursor_position, cx)
957 })
958 });
959
960 let prediction = match prediction_request {
961 Ok(prediction_request) => {
962 let prediction_request = prediction_request.await;
963 prediction_request.map(|c| {
964 c.map(|prediction| CurrentEditPrediction {
965 buffer_id: buffer.entity_id(),
966 prediction,
967 })
968 })
969 }
970 Err(error) => Err(error),
971 };
972
973 this.update(cx, |this, cx| {
974 if this.pending_predictions[0].id == pending_prediction_id {
975 this.pending_predictions.remove(0);
976 } else {
977 this.pending_predictions.clear();
978 }
979
980 let Some(new_prediction) = prediction
981 .context("edit prediction failed")
982 .log_err()
983 .flatten()
984 else {
985 cx.notify();
986 return;
987 };
988
989 if let Some(old_prediction) = this.current_prediction.as_ref() {
990 let snapshot = buffer.read(cx).snapshot();
991 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
992 this.current_prediction = Some(new_prediction);
993 }
994 } else {
995 this.current_prediction = Some(new_prediction);
996 }
997
998 cx.notify();
999 })
1000 .ok();
1001 });
1002
1003 // We always maintain at most two pending predictions. When we already
1004 // have two, we replace the newest one.
1005 if self.pending_predictions.len() <= 1 {
1006 self.pending_predictions.push(PendingPrediction {
1007 id: pending_prediction_id,
1008 _task: task,
1009 });
1010 } else if self.pending_predictions.len() == 2 {
1011 self.pending_predictions.pop();
1012 self.pending_predictions.push(PendingPrediction {
1013 id: pending_prediction_id,
1014 _task: task,
1015 });
1016 }
1017
1018 cx.notify();
1019 }
1020
    /// Cycling between alternative predictions is not supported; this is a no-op.
    fn cycle(
        &mut self,
        _buffer: Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _direction: Direction,
        _cx: &mut Context<Self>,
    ) {
    }
1029
    /// Clears the current prediction and drops all pending request tasks (which
    /// presumably cancels them, per gpui `Task` drop semantics — confirm).
    fn accept(&mut self, _cx: &mut Context<Self>) {
        // TODO [zeta2] report accept
        self.current_prediction.take();
        self.pending_predictions.clear();
    }

    /// Drops all pending request tasks and clears the current prediction.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.pending_predictions.clear();
        self.current_prediction.take();
    }
1040
1041 fn suggest(
1042 &mut self,
1043 buffer: &Entity<language::Buffer>,
1044 cursor_position: language::Anchor,
1045 cx: &mut Context<Self>,
1046 ) -> Option<edit_prediction::EditPrediction> {
1047 let CurrentEditPrediction {
1048 buffer_id,
1049 prediction,
1050 ..
1051 } = self.current_prediction.as_mut()?;
1052
1053 // Invalidate previous prediction if it was generated for a different buffer.
1054 if *buffer_id != buffer.entity_id() {
1055 self.current_prediction.take();
1056 return None;
1057 }
1058
1059 let buffer = buffer.read(cx);
1060 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1061 self.current_prediction.take();
1062 return None;
1063 };
1064
1065 let cursor_row = cursor_position.to_point(buffer).row;
1066 let (closest_edit_ix, (closest_edit_range, _)) =
1067 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1068 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1069 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1070 cmp::min(distance_from_start, distance_from_end)
1071 })?;
1072
1073 let mut edit_start_ix = closest_edit_ix;
1074 for (range, _) in edits[..edit_start_ix].iter().rev() {
1075 let distance_from_closest_edit =
1076 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1077 if distance_from_closest_edit <= 1 {
1078 edit_start_ix -= 1;
1079 } else {
1080 break;
1081 }
1082 }
1083
1084 let mut edit_end_ix = closest_edit_ix + 1;
1085 for (range, _) in &edits[edit_end_ix..] {
1086 let distance_from_closest_edit =
1087 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1088 if distance_from_closest_edit <= 1 {
1089 edit_end_ix += 1;
1090 } else {
1091 break;
1092 }
1093 }
1094
1095 Some(edit_prediction::EditPrediction {
1096 id: Some(prediction.id.to_string().into()),
1097 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1098 edit_preview: Some(prediction.edit_preview.clone()),
1099 })
1100 }
1101}
1102
1103fn make_cloud_request(
1104 excerpt_path: PathBuf,
1105 context: EditPredictionContext,
1106 events: Vec<predict_edits_v3::Event>,
1107 can_collect_data: bool,
1108 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1109 diagnostic_groups_truncated: bool,
1110 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1111 debug_info: bool,
1112 worktrees: &Vec<worktree::Snapshot>,
1113 index_state: Option<&SyntaxIndexState>,
1114 prompt_max_bytes: Option<usize>,
1115) -> predict_edits_v3::PredictEditsRequest {
1116 let mut signatures = Vec::new();
1117 let mut declaration_to_signature_index = HashMap::default();
1118 let mut referenced_declarations = Vec::new();
1119
1120 for snippet in context.snippets {
1121 let project_entry_id = snippet.declaration.project_entry_id();
1122 let Some(path) = worktrees.iter().find_map(|worktree| {
1123 worktree.entry_for_id(project_entry_id).map(|entry| {
1124 let mut full_path = PathBuf::new();
1125 full_path.push(worktree.root_name());
1126 full_path.push(&entry.path);
1127 full_path
1128 })
1129 }) else {
1130 continue;
1131 };
1132
1133 let parent_index = index_state.and_then(|index_state| {
1134 snippet.declaration.parent().and_then(|parent| {
1135 add_signature(
1136 parent,
1137 &mut declaration_to_signature_index,
1138 &mut signatures,
1139 index_state,
1140 )
1141 })
1142 });
1143
1144 let (text, text_is_truncated) = snippet.declaration.item_text();
1145 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1146 path,
1147 text: text.into(),
1148 range: snippet.declaration.item_range(),
1149 text_is_truncated,
1150 signature_range: snippet.declaration.signature_range_in_item_text(),
1151 parent_index,
1152 score_components: snippet.score_components,
1153 signature_score: snippet.scores.signature,
1154 declaration_score: snippet.scores.declaration,
1155 });
1156 }
1157
1158 let excerpt_parent = index_state.and_then(|index_state| {
1159 context
1160 .excerpt
1161 .parent_declarations
1162 .last()
1163 .and_then(|(parent, _)| {
1164 add_signature(
1165 *parent,
1166 &mut declaration_to_signature_index,
1167 &mut signatures,
1168 index_state,
1169 )
1170 })
1171 });
1172
1173 predict_edits_v3::PredictEditsRequest {
1174 excerpt_path,
1175 excerpt: context.excerpt_text.body,
1176 excerpt_range: context.excerpt.range,
1177 cursor_offset: context.cursor_offset_in_excerpt,
1178 referenced_declarations,
1179 signatures,
1180 excerpt_parent,
1181 events,
1182 can_collect_data,
1183 diagnostic_groups,
1184 diagnostic_groups_truncated,
1185 git_info,
1186 debug_info,
1187 prompt_max_bytes,
1188 }
1189}
1190
1191fn add_signature(
1192 declaration_id: DeclarationId,
1193 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1194 signatures: &mut Vec<Signature>,
1195 index: &SyntaxIndexState,
1196) -> Option<usize> {
1197 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1198 return Some(*signature_index);
1199 }
1200 let Some(parent_declaration) = index.declaration(declaration_id) else {
1201 log::error!("bug: missing parent declaration");
1202 return None;
1203 };
1204 let parent_index = parent_declaration.parent().and_then(|parent| {
1205 add_signature(parent, declaration_to_signature_index, signatures, index)
1206 });
1207 let (text, text_is_truncated) = parent_declaration.signature_text();
1208 let signature_index = signatures.len();
1209 signatures.push(Signature {
1210 text: text.into(),
1211 text_is_truncated,
1212 parent_index,
1213 range: parent_declaration.signature_range(),
1214 });
1215 declaration_to_signature_index.insert(declaration_id, signature_index);
1216 Some(signature_index)
1217}
1218
/// Re-bases a prediction's `current_edits` from `old_snapshot` onto `new_snapshot`.
///
/// `current_edits` are expressed against `old_snapshot`. The function accounts for the edits
/// the user has made since then. User edits (from `edits_since`) and model edits are walked in
/// lockstep, which assumes both are ordered by position. NOTE(review): confirm that
/// `current_edits` is always sorted.
///
/// - A model edit whose old range ends strictly before the next user edit starts is kept
///   unchanged.
/// - A user edit consumes the next model edit when both conditions hold:
///   - the user edit's old range equals that model edit's old range;
///   - the text the user inserted is a prefix of the model's replacement text.
///   The not-yet-typed remainder, if any, is kept as an insertion right after the user's text.
/// - Any other user edit invalidates the prediction, and `None` is returned. This includes a
///   partial overlap, diverging text, or a user edit with no model edit left to match.
///
/// Model edits remaining after the last user edit are kept unchanged. Returns `None` when no
/// edits remain, for example once the user has typed every suggested edit.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over every model edit lying entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit must replace exactly the range of the next model edit...
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // ...and what the user typed must be a prefix of the model's suggestion.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // Offer the rest of the suggestion as an insertion after the user's text.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // This user edit could not be reconciled with the prediction: discard it.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use indoc::indoc;

    /// A single server edit that replaces the whole buffer should be reduced by
    /// `Zeta::compute_edits` to the two minimal edits that actually change text.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt"),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = Zeta::compute_edits(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        // End of `let args =` on row 1.
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        // End of `println!("{}", args[1])` on row 2.
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// Walks a prediction through a sequence of user edits and checks how `interpolate`
    /// re-bases it after each step. Each step builds on the buffer state left by the previous
    /// one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" (2..5) with "REM", and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is returned as-is.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes "rem" ("Lo ipsum dolor"): "REM" becomes a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original buffer and prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User types "R" over "rem": only "EM" remains to be inserted after it.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // User continues with "E": only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User finishes "REM": the first edit is fully typed and disappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // User deletes the "M" again: the "M" suggestion comes back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User performs the predicted deletion of "um" themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // User edit extends past the predicted range and diverges: prediction invalidated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits against the current buffer.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in the buffer's current state.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}