1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use workspace::Workspace;
38
39use std::ops::Range;
40use std::path::Path;
41use std::rc::Rc;
42use std::str::FromStr as _;
43use std::sync::{Arc, LazyLock};
44use std::time::{Duration, Instant};
45use std::{env, mem};
46use thiserror::Error;
47use util::{RangeExt as _, ResultExt as _};
48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
49
50mod cursor_excerpt;
51mod license_detection;
52pub mod mercury;
53mod onboarding_modal;
54pub mod open_ai_response;
55mod prediction;
56pub mod sweep_ai;
57
58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
59pub mod udiff;
60
61mod zed_edit_prediction_delegate;
62pub mod zeta1;
63pub mod zeta2;
64
65#[cfg(test)]
66mod edit_prediction_tests;
67
68use crate::license_detection::LicenseDetectionWatcher;
69use crate::mercury::Mercury;
70use crate::onboarding_modal::ZedPredictModal;
71pub use crate::prediction::EditPrediction;
72pub use crate::prediction::EditPredictionId;
73use crate::prediction::EditPredictionResult;
74pub use crate::sweep_ai::SweepAi;
75pub use telemetry_events::EditPredictionRating;
76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
77
78actions!(
79 edit_prediction,
80 [
81 /// Resets the edit prediction onboarding state.
82 ResetOnboarding,
83 /// Clears the edit prediction history.
84 ClearHistory,
85 ]
86);
87
88/// Maximum number of events to track.
89const EVENT_COUNT_MAX: usize = 6;
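/// Consecutive edits within this many lines of the previous edit are coalesced
/// into a single event.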
90const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
91const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
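/// How long to keep batching rejected predictions before flushing them to the
/// reject endpoint.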
92const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
93
94pub struct SweepFeatureFlag;
95
96impl FeatureFlag for SweepFeatureFlag {
97 const NAME: &str = "sweep-ai";
98}
99
100pub struct MercuryFeatureFlag;
101
102impl FeatureFlag for MercuryFeatureFlag {
103 const NAME: &str = "mercury";
104}
105
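/// Default excerpt and prompt options, used until overridden via
/// [`EditPredictionStore::set_options`].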
106pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
107 context: EditPredictionExcerptOptions {
108 max_bytes: 512,
109 min_bytes: 128,
110 target_before_cursor_over_total_bytes: 0.5,
111 },
112 prompt_format: PromptFormat::DEFAULT,
113};
114
115static USE_OLLAMA: LazyLock<bool> =
116 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
117
118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
119 match env::var("ZED_ZETA2_MODEL").as_deref() {
120 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
121 Ok(model) => model,
122 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
123 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
124 }
125 .to_string()
126});
127static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
128 env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
129 if *USE_OLLAMA {
130 Some("http://localhost:11434/v1/chat/completions".into())
131 } else {
132 None
133 }
134 })
135});
136
137pub struct Zeta2FeatureFlag;
138
139impl FeatureFlag for Zeta2FeatureFlag {
140 const NAME: &'static str = "zeta2";
141
142 fn enabled_for_staff() -> bool {
143 true
144 }
145}
146
147#[derive(Clone)]
148struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
149
150impl Global for EditPredictionStoreGlobal {}
151
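/// Global store backing edit predictions: tracks per-project edit history and
/// context, requests predictions from the configured model, and reports
/// accepted, rejected, and rated predictions.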
152pub struct EditPredictionStore {
153 client: Arc<Client>,
154 user_store: Entity<UserStore>,
155 llm_token: LlmApiToken,
156 _llm_token_subscription: Subscription,
157 projects: HashMap<EntityId, ProjectState>,
158 use_context: bool,
159 options: ZetaOptions,
160 update_required: bool,
161 #[cfg(feature = "cli-support")]
162 eval_cache: Option<Arc<dyn EvalCache>>,
163 edit_prediction_model: EditPredictionModel,
164 pub sweep_ai: SweepAi,
165 pub mercury: Mercury,
166 data_collection_choice: DataCollectionChoice,
167 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
168 shown_predictions: VecDeque<EditPrediction>,
169 rated_predictions: HashSet<EditPredictionId>,
170}
171
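/// Which backend is used to generate edit predictions.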
172#[derive(Copy, Clone, Default, PartialEq, Eq)]
173pub enum EditPredictionModel {
174 #[default]
175 Zeta1,
176 Zeta2,
177 Sweep,
178 Mercury,
179}
180
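/// Everything gathered for a single prediction request, handed to the active
/// model backend.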
181pub struct EditPredictionModelInput {
182 project: Entity<Project>,
183 buffer: Entity<Buffer>,
184 snapshot: BufferSnapshot,
185 position: Anchor,
186 events: Vec<Arc<zeta_prompt::Event>>,
187 related_files: Arc<[RelatedFile]>,
188 recent_paths: VecDeque<ProjectPath>,
189 trigger: PredictEditsRequestTrigger,
190 diagnostic_search_range: Range<Point>,
191 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
192}
193
194#[derive(Debug, Clone, PartialEq)]
195pub struct ZetaOptions {
196 pub context: EditPredictionExcerptOptions,
197 pub prompt_format: predict_edits_v3::PromptFormat,
198}
199
200#[derive(Debug)]
201pub enum DebugEvent {
202 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
203 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
204 EditPredictionStarted(EditPredictionStartedDebugEvent),
205 EditPredictionFinished(EditPredictionFinishedDebugEvent),
206}
207
208#[derive(Debug)]
209pub struct ContextRetrievalStartedDebugEvent {
210 pub project_entity_id: EntityId,
211 pub timestamp: Instant,
212 pub search_prompt: String,
213}
214
215#[derive(Debug)]
216pub struct ContextRetrievalFinishedDebugEvent {
217 pub project_entity_id: EntityId,
218 pub timestamp: Instant,
219 pub metadata: Vec<(&'static str, SharedString)>,
220}
221
222#[derive(Debug)]
223pub struct EditPredictionStartedDebugEvent {
224 pub buffer: WeakEntity<Buffer>,
225 pub position: Anchor,
226 pub prompt: Option<String>,
227}
228
229#[derive(Debug)]
230pub struct EditPredictionFinishedDebugEvent {
231 pub buffer: WeakEntity<Buffer>,
232 pub position: Anchor,
233 pub model_output: Option<String>,
234}
235
236pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
237
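/// Per-project prediction state: recent edit events, registered buffers,
/// related-excerpt context, and the current and pending predictions.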
238struct ProjectState {
239 events: VecDeque<Arc<zeta_prompt::Event>>,
240 last_event: Option<LastEvent>,
241 recent_paths: VecDeque<ProjectPath>,
242 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
243 current_prediction: Option<CurrentEditPrediction>,
244 next_pending_prediction_id: usize,
245 pending_predictions: ArrayVec<PendingPrediction, 2>,
246 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
247 last_prediction_refresh: Option<(EntityId, Instant)>,
248 cancelled_predictions: HashSet<usize>,
249 context: Entity<RelatedExcerptStore>,
250 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
251 _subscription: gpui::Subscription,
252}
253
254impl ProjectState {
255 pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
256 self.events
257 .iter()
258 .cloned()
259 .chain(
260 self.last_event
261 .as_ref()
262 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
263 )
264 .collect()
265 }
266
267 fn cancel_pending_prediction(
268 &mut self,
269 pending_prediction: PendingPrediction,
270 cx: &mut Context<EditPredictionStore>,
271 ) {
272 self.cancelled_predictions.insert(pending_prediction.id);
273
274 cx.spawn(async move |this, cx| {
275 let Some(prediction_id) = pending_prediction.task.await else {
276 return;
277 };
278
279 this.update(cx, |this, _cx| {
280 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
281 })
282 .ok();
283 })
284 .detach()
285 }
286
287 fn active_buffer(
288 &self,
289 project: &Entity<Project>,
290 cx: &App,
291 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
292 let project = project.read(cx);
293 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
294 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
295 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
296 Some((active_buffer, registered_buffer.last_position))
297 }
298}
299
300#[derive(Debug, Clone)]
301struct CurrentEditPrediction {
302 pub requested_by: PredictionRequestedBy,
303 pub prediction: EditPrediction,
304 pub was_shown: bool,
305}
306
307impl CurrentEditPrediction {
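    /// Whether this prediction should replace `old_prediction`. When both
    /// predictions are single edits in the buffer that requested them, the old
    /// one is only replaced if the new edit extends it at the same range,
    /// which reduces UI thrash; otherwise the newer prediction wins, unless it
    /// no longer interpolates against the current buffer contents.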
308 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
309 let Some(new_edits) = self
310 .prediction
311 .interpolate(&self.prediction.buffer.read(cx))
312 else {
313 return false;
314 };
315
316 if self.prediction.buffer != old_prediction.prediction.buffer {
317 return true;
318 }
319
320 let Some(old_edits) = old_prediction
321 .prediction
322 .interpolate(&old_prediction.prediction.buffer.read(cx))
323 else {
324 return true;
325 };
326
327 let requested_by_buffer_id = self.requested_by.buffer_id();
328
329 // This reduces the occurrence of UI thrash from replacing edits
330 //
331 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
332 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
333 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
334 && old_edits.len() == 1
335 && new_edits.len() == 1
336 {
337 let (old_range, old_text) = &old_edits[0];
338 let (new_range, new_text) = &new_edits[0];
339 new_range == old_range && new_text.starts_with(old_text.as_ref())
340 } else {
341 true
342 }
343 }
344}
345
346#[derive(Debug, Clone)]
347enum PredictionRequestedBy {
348 DiagnosticsUpdate,
349 Buffer(EntityId),
350}
351
352impl PredictionRequestedBy {
353 pub fn buffer_id(&self) -> Option<EntityId> {
354 match self {
355 PredictionRequestedBy::DiagnosticsUpdate => None,
356 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
357 }
358 }
359}
360
361#[derive(Debug)]
362struct PendingPrediction {
363 id: usize,
364 task: Task<Option<EditPredictionId>>,
365}
366
367/// A prediction from the perspective of a buffer.
368#[derive(Debug)]
369enum BufferEditPrediction<'a> {
370 Local { prediction: &'a EditPrediction },
371 Jump { prediction: &'a EditPrediction },
372}
373
374#[cfg(test)]
375impl std::ops::Deref for BufferEditPrediction<'_> {
376 type Target = EditPrediction;
377
378 fn deref(&self) -> &Self::Target {
379 match self {
380 BufferEditPrediction::Local { prediction } => prediction,
381 BufferEditPrediction::Jump { prediction } => prediction,
382 }
383 }
384}
385
386struct RegisteredBuffer {
387 snapshot: BufferSnapshot,
388 last_position: Option<Anchor>,
389 _subscriptions: [gpui::Subscription; 2],
390}
391
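/// The most recent buffer change, kept un-finalized so that consecutive edits
/// can be coalesced before being converted into a [`zeta_prompt::Event`].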
392struct LastEvent {
393 old_snapshot: BufferSnapshot,
394 new_snapshot: BufferSnapshot,
395 end_edit_anchor: Option<Anchor>,
396}
397
398impl LastEvent {
399 pub fn finalize(
400 &self,
401 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
402 cx: &App,
403 ) -> Option<Arc<zeta_prompt::Event>> {
404 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
405 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
406
407 let file = self.new_snapshot.file();
408 let old_file = self.old_snapshot.file();
409
410 let in_open_source_repo = [file, old_file].iter().all(|file| {
411 file.is_some_and(|file| {
412 license_detection_watchers
413 .get(&file.worktree_id(cx))
414 .is_some_and(|watcher| watcher.is_project_open_source())
415 })
416 });
417
418 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
419
420 if path == old_path && diff.is_empty() {
421 None
422 } else {
423 Some(Arc::new(zeta_prompt::Event::BufferChange {
424 old_path,
425 path,
426 diff,
427 in_open_source_repo,
428 // TODO: Actually detect if this edit was predicted or not
429 predicted: false,
430 }))
431 }
432 }
433}
434
435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
436 if let Some(file) = snapshot.file() {
437 file.full_path(cx).into()
438 } else {
439 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
440 }
441}
442
443impl EditPredictionStore {
444 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
445 cx.try_global::<EditPredictionStoreGlobal>()
446 .map(|global| global.0.clone())
447 }
448
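    /// Returns the shared store, creating it and registering it as a global on
    /// first use.
    ///
    /// A minimal usage sketch (assuming a `client`, `user_store`, and
    /// `project` are already available in the calling context):
    ///
    /// ```ignore
    /// let store = EditPredictionStore::global(&client, &user_store, cx);
    /// store.update(cx, |store, cx| {
    ///     store.register_project(&project, cx);
    /// });
    /// ```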
449 pub fn global(
450 client: &Arc<Client>,
451 user_store: &Entity<UserStore>,
452 cx: &mut App,
453 ) -> Entity<Self> {
454 cx.try_global::<EditPredictionStoreGlobal>()
455 .map(|global| global.0.clone())
456 .unwrap_or_else(|| {
457 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
458 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
459 ep_store
460 })
461 }
462
463 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
464 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
465 let data_collection_choice = Self::load_data_collection_choice();
466
467 let llm_token = LlmApiToken::default();
468
469 let (reject_tx, reject_rx) = mpsc::unbounded();
470 cx.background_spawn({
471 let client = client.clone();
472 let llm_token = llm_token.clone();
473 let app_version = AppVersion::global(cx);
474 let background_executor = cx.background_executor().clone();
475 async move {
476 Self::handle_rejected_predictions(
477 reject_rx,
478 client,
479 llm_token,
480 app_version,
481 background_executor,
482 )
483 .await
484 }
485 })
486 .detach();
487
488 let mut this = Self {
489 projects: HashMap::default(),
490 client,
491 user_store,
492 options: DEFAULT_OPTIONS,
493 use_context: false,
494 llm_token,
495 _llm_token_subscription: cx.subscribe(
496 &refresh_llm_token_listener,
497 |this, _listener, _event, cx| {
498 let client = this.client.clone();
499 let llm_token = this.llm_token.clone();
500 cx.spawn(async move |_this, _cx| {
501 llm_token.refresh(&client).await?;
502 anyhow::Ok(())
503 })
504 .detach_and_log_err(cx);
505 },
506 ),
507 update_required: false,
508 #[cfg(feature = "cli-support")]
509 eval_cache: None,
510 edit_prediction_model: EditPredictionModel::Zeta2,
511 sweep_ai: SweepAi::new(cx),
512 mercury: Mercury::new(cx),
513 data_collection_choice,
514 reject_predictions_tx: reject_tx,
515 rated_predictions: Default::default(),
516 shown_predictions: Default::default(),
517 };
518
519 this.configure_context_retrieval(cx);
520 let weak_this = cx.weak_entity();
521 cx.on_flags_ready(move |_, cx| {
522 weak_this
523 .update(cx, |this, cx| this.configure_context_retrieval(cx))
524 .ok();
525 })
526 .detach();
527 cx.observe_global::<SettingsStore>(|this, cx| {
528 this.configure_context_retrieval(cx);
529 })
530 .detach();
531
532 this
533 }
534
535 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
536 self.edit_prediction_model = model;
537 }
538
539 pub fn has_sweep_api_token(&self) -> bool {
540 self.sweep_ai
541 .api_token
542 .clone()
543 .now_or_never()
544 .flatten()
545 .is_some()
546 }
547
548 pub fn has_mercury_api_token(&self) -> bool {
549 self.mercury
550 .api_token
551 .clone()
552 .now_or_never()
553 .flatten()
554 .is_some()
555 }
556
557 #[cfg(feature = "cli-support")]
558 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
559 self.eval_cache = Some(cache);
560 }
561
562 pub fn options(&self) -> &ZetaOptions {
563 &self.options
564 }
565
566 pub fn set_options(&mut self, options: ZetaOptions) {
567 self.options = options;
568 }
569
570 pub fn set_use_context(&mut self, use_context: bool) {
571 self.use_context = use_context;
572 }
573
574 pub fn clear_history(&mut self) {
575 for project_state in self.projects.values_mut() {
576 project_state.events.clear();
577 }
578 }
579
580 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
581 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
582 project_state.events.clear();
583 }
584 }
585
586 pub fn edit_history_for_project(
587 &self,
588 project: &Entity<Project>,
589 ) -> Vec<Arc<zeta_prompt::Event>> {
590 self.projects
591 .get(&project.entity_id())
592 .map(|project_state| project_state.events.iter().cloned().collect())
593 .unwrap_or_default()
594 }
595
596 pub fn context_for_project<'a>(
597 &'a self,
598 project: &Entity<Project>,
599 cx: &'a App,
600 ) -> Arc<[RelatedFile]> {
601 self.projects
602 .get(&project.entity_id())
603 .map(|project| project.context.read(cx).related_files())
604 .unwrap_or_else(|| vec![].into())
605 }
606
607 pub fn context_for_project_with_buffers<'a>(
608 &'a self,
609 project: &Entity<Project>,
610 cx: &'a App,
611 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
612 self.projects
613 .get(&project.entity_id())
614 .map(|project| project.context.read(cx).related_files_with_buffers())
615 }
616
617 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
618 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
619 self.user_store.read(cx).edit_prediction_usage()
620 } else {
621 None
622 }
623 }
624
625 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
626 self.get_or_init_project(project, cx);
627 }
628
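    /// Registers `buffer` so that its edits are tracked as events for
    /// `project` and can inform future predictions.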
629 pub fn register_buffer(
630 &mut self,
631 buffer: &Entity<Buffer>,
632 project: &Entity<Project>,
633 cx: &mut Context<Self>,
634 ) {
635 let project_state = self.get_or_init_project(project, cx);
636 Self::register_buffer_impl(project_state, buffer, project, cx);
637 }
638
639 fn get_or_init_project(
640 &mut self,
641 project: &Entity<Project>,
642 cx: &mut Context<Self>,
643 ) -> &mut ProjectState {
644 let entity_id = project.entity_id();
645 self.projects
646 .entry(entity_id)
647 .or_insert_with(|| ProjectState {
648 context: {
649 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
650 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
651 this.handle_excerpt_store_event(entity_id, event);
652 })
653 .detach();
654 related_excerpt_store
655 },
656 events: VecDeque::new(),
657 last_event: None,
658 recent_paths: VecDeque::new(),
659 debug_tx: None,
660 registered_buffers: HashMap::default(),
661 current_prediction: None,
662 cancelled_predictions: HashSet::default(),
663 pending_predictions: ArrayVec::new(),
664 next_pending_prediction_id: 0,
665 last_prediction_refresh: None,
666 license_detection_watchers: HashMap::default(),
667 _subscription: cx.subscribe(&project, Self::handle_project_event),
668 })
669 }
670
671 pub fn remove_project(&mut self, project: &Entity<Project>) {
672 self.projects.remove(&project.entity_id());
673 }
674
675 fn handle_excerpt_store_event(
676 &mut self,
677 project_entity_id: EntityId,
678 event: &RelatedExcerptStoreEvent,
679 ) {
680 if let Some(project_state) = self.projects.get(&project_entity_id) {
681 if let Some(debug_tx) = project_state.debug_tx.clone() {
682 match event {
683 RelatedExcerptStoreEvent::StartedRefresh => {
684 debug_tx
685 .unbounded_send(DebugEvent::ContextRetrievalStarted(
686 ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
688 timestamp: Instant::now(),
689 search_prompt: String::new(),
690 },
691 ))
692 .ok();
693 }
694 RelatedExcerptStoreEvent::FinishedRefresh {
695 cache_hit_count,
696 cache_miss_count,
697 mean_definition_latency,
698 max_definition_latency,
699 } => {
700 debug_tx
701 .unbounded_send(DebugEvent::ContextRetrievalFinished(
702 ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
704 timestamp: Instant::now(),
705 metadata: vec![
706 (
707 "Cache Hits",
708 format!(
709 "{}/{}",
710 cache_hit_count,
711 cache_hit_count + cache_miss_count
712 )
713 .into(),
714 ),
715 (
716 "Max LSP Time",
717 format!("{} ms", max_definition_latency.as_millis())
718 .into(),
719 ),
720 (
721 "Mean LSP Time",
722 format!("{} ms", mean_definition_latency.as_millis())
723 .into(),
724 ),
725 ],
726 },
727 ))
728 .ok();
729 }
730 }
731 }
732 }
733 }
734
735 pub fn debug_info(
736 &mut self,
737 project: &Entity<Project>,
738 cx: &mut Context<Self>,
739 ) -> mpsc::UnboundedReceiver<DebugEvent> {
740 let project_state = self.get_or_init_project(project, cx);
741 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
742 project_state.debug_tx = Some(debug_watch_tx);
743 debug_watch_rx
744 }
745
746 fn handle_project_event(
747 &mut self,
748 project: Entity<Project>,
749 event: &project::Event,
750 cx: &mut Context<Self>,
751 ) {
752 // TODO [zeta2] init with recent paths
753 match event {
754 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
755 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
756 return;
757 };
758 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
759 if let Some(path) = path {
760 if let Some(ix) = project_state
761 .recent_paths
762 .iter()
763 .position(|probe| probe == &path)
764 {
765 project_state.recent_paths.remove(ix);
766 }
767 project_state.recent_paths.push_front(path);
768 }
769 }
770 project::Event::DiagnosticsUpdated { .. } => {
771 if cx.has_flag::<Zeta2FeatureFlag>() {
772 self.refresh_prediction_from_diagnostics(project, cx);
773 }
774 }
775 _ => (),
776 }
777 }
778
779 fn register_buffer_impl<'a>(
780 project_state: &'a mut ProjectState,
781 buffer: &Entity<Buffer>,
782 project: &Entity<Project>,
783 cx: &mut Context<Self>,
784 ) -> &'a mut RegisteredBuffer {
785 let buffer_id = buffer.entity_id();
786
787 if let Some(file) = buffer.read(cx).file() {
788 let worktree_id = file.worktree_id(cx);
789 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
790 project_state
791 .license_detection_watchers
792 .entry(worktree_id)
793 .or_insert_with(|| {
794 let project_entity_id = project.entity_id();
795 cx.observe_release(&worktree, move |this, _worktree, _cx| {
796 let Some(project_state) = this.projects.get_mut(&project_entity_id)
797 else {
798 return;
799 };
800 project_state
801 .license_detection_watchers
802 .remove(&worktree_id);
803 })
804 .detach();
805 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
806 });
807 }
808 }
809
810 match project_state.registered_buffers.entry(buffer_id) {
811 hash_map::Entry::Occupied(entry) => entry.into_mut(),
812 hash_map::Entry::Vacant(entry) => {
813 let snapshot = buffer.read(cx).snapshot();
814 let project_entity_id = project.entity_id();
815 entry.insert(RegisteredBuffer {
816 snapshot,
817 last_position: None,
818 _subscriptions: [
819 cx.subscribe(buffer, {
820 let project = project.downgrade();
821 move |this, buffer, event, cx| {
822 if let language::BufferEvent::Edited = event
823 && let Some(project) = project.upgrade()
824 {
825 this.report_changes_for_buffer(&buffer, &project, cx);
826 }
827 }
828 }),
829 cx.observe_release(buffer, move |this, _buffer, _cx| {
830 let Some(project_state) = this.projects.get_mut(&project_entity_id)
831 else {
832 return;
833 };
834 project_state.registered_buffers.remove(&buffer_id);
835 }),
836 ],
837 })
838 }
839 }
840 }
841
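    /// Records the latest edits to `buffer` as an event, coalescing them with
    /// the previous event when they continue it in the same buffer within
    /// `CHANGE_GROUPING_LINE_SPAN` lines, and capping history at
    /// `EVENT_COUNT_MAX` events.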
842 fn report_changes_for_buffer(
843 &mut self,
844 buffer: &Entity<Buffer>,
845 project: &Entity<Project>,
846 cx: &mut Context<Self>,
847 ) {
848 let project_state = self.get_or_init_project(project, cx);
849 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
850
851 let new_snapshot = buffer.read(cx).snapshot();
852 if new_snapshot.version == registered_buffer.snapshot.version {
853 return;
854 }
855
856 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
857 let end_edit_anchor = new_snapshot
858 .anchored_edits_since::<Point>(&old_snapshot.version)
859 .last()
860 .map(|(_, range)| range.end);
861 let events = &mut project_state.events;
862
863 if let Some(LastEvent {
864 new_snapshot: last_new_snapshot,
865 end_edit_anchor: last_end_edit_anchor,
866 ..
867 }) = project_state.last_event.as_mut()
868 {
869 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
870 == last_new_snapshot.remote_id()
871 && old_snapshot.version == last_new_snapshot.version;
872
873 let should_coalesce = is_next_snapshot_of_same_buffer
874 && end_edit_anchor
875 .as_ref()
876 .zip(last_end_edit_anchor.as_ref())
877 .is_some_and(|(a, b)| {
878 let a = a.to_point(&new_snapshot);
879 let b = b.to_point(&new_snapshot);
880 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
881 });
882
883 if should_coalesce {
884 *last_end_edit_anchor = end_edit_anchor;
885 *last_new_snapshot = new_snapshot;
886 return;
887 }
888 }
889
890 if events.len() + 1 >= EVENT_COUNT_MAX {
891 events.pop_front();
892 }
893
894 if let Some(event) = project_state.last_event.take() {
895 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
896 }
897
898 project_state.last_event = Some(LastEvent {
899 old_snapshot,
900 new_snapshot,
901 end_edit_anchor,
902 });
903 }
904
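    /// Returns the current prediction from the perspective of `buffer`:
    /// `Local` if the prediction targets this buffer, `Jump` if it targets
    /// another buffer and this buffer is allowed to offer a jump to it.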
905 fn prediction_at(
906 &mut self,
907 buffer: &Entity<Buffer>,
908 position: Option<language::Anchor>,
909 project: &Entity<Project>,
910 cx: &App,
911 ) -> Option<BufferEditPrediction<'_>> {
912 let project_state = self.projects.get_mut(&project.entity_id())?;
913 if let Some(position) = position
914 && let Some(buffer) = project_state
915 .registered_buffers
916 .get_mut(&buffer.entity_id())
917 {
918 buffer.last_position = Some(position);
919 }
920
921 let CurrentEditPrediction {
922 requested_by,
923 prediction,
924 ..
925 } = project_state.current_prediction.as_ref()?;
926
927 if prediction.targets_buffer(buffer.read(cx)) {
928 Some(BufferEditPrediction::Local { prediction })
929 } else {
930 let show_jump = match requested_by {
931 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
932 requested_by_buffer_id == &buffer.entity_id()
933 }
934 PredictionRequestedBy::DiagnosticsUpdate => true,
935 };
936
937 if show_jump {
938 Some(BufferEditPrediction::Jump { prediction })
939 } else {
940 None
941 }
942 }
943 }
944
945 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
946 match self.edit_prediction_model {
947 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
948 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
949 }
950
951 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
952 return;
953 };
954
955 let Some(prediction) = project_state.current_prediction.take() else {
956 return;
957 };
958 let request_id = prediction.prediction.id.to_string();
959 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
960 project_state.cancel_pending_prediction(pending_prediction, cx);
961 }
962
963 let client = self.client.clone();
964 let llm_token = self.llm_token.clone();
965 let app_version = AppVersion::global(cx);
966 cx.spawn(async move |this, cx| {
967 let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
968 http_client::Url::parse(&predict_edits_url)?
969 } else {
970 client
971 .http_client()
972 .build_zed_llm_url("/predict_edits/accept", &[])?
973 };
974
975 let response = cx
976 .background_spawn(Self::send_api_request::<()>(
977 move |builder| {
978 let req = builder.uri(url.as_ref()).body(
979 serde_json::to_string(&AcceptEditPredictionBody {
980 request_id: request_id.clone(),
981 })?
982 .into(),
983 );
984 Ok(req?)
985 },
986 client,
987 llm_token,
988 app_version,
989 ))
990 .await;
991
992 Self::handle_api_response(&this, response, cx)?;
993 anyhow::Ok(())
994 })
995 .detach_and_log_err(cx);
996 }
997
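    /// Batches rejection reports and flushes them to the reject endpoint,
    /// debounced by `REJECT_REQUEST_DEBOUNCE` and capped at
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` per request. Rejections
    /// are retained for the next flush if a request fails.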
998 async fn handle_rejected_predictions(
999 rx: UnboundedReceiver<EditPredictionRejection>,
1000 client: Arc<Client>,
1001 llm_token: LlmApiToken,
1002 app_version: Version,
1003 background_executor: BackgroundExecutor,
1004 ) {
1005 let mut rx = std::pin::pin!(rx.peekable());
1006 let mut batched = Vec::new();
1007
1008 while let Some(rejection) = rx.next().await {
1009 batched.push(rejection);
1010
1011 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1012 select_biased! {
1013 next = rx.as_mut().peek().fuse() => {
1014 if next.is_some() {
1015 continue;
1016 }
1017 }
1018 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1019 }
1020 }
1021
1022 let url = client
1023 .http_client()
1024 .build_zed_llm_url("/predict_edits/reject", &[])
1025 .unwrap();
1026
1027 let flush_count = batched
1028 .len()
1029 // in case items have accumulated after failure
1030 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1031 let start = batched.len() - flush_count;
1032
1033 let body = RejectEditPredictionsBodyRef {
1034 rejections: &batched[start..],
1035 };
1036
1037 let result = Self::send_api_request::<()>(
1038 |builder| {
1039 let req = builder
1040 .uri(url.as_ref())
1041 .body(serde_json::to_string(&body)?.into());
1042 anyhow::Ok(req?)
1043 },
1044 client.clone(),
1045 llm_token.clone(),
1046 app_version.clone(),
1047 )
1048 .await;
1049
1050 if result.log_err().is_some() {
1051 batched.drain(start..);
1052 }
1053 }
1054 }
1055
1056 fn reject_current_prediction(
1057 &mut self,
1058 reason: EditPredictionRejectReason,
1059 project: &Entity<Project>,
1060 ) {
1061 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1062 project_state.pending_predictions.clear();
1063 if let Some(prediction) = project_state.current_prediction.take() {
1064 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1065 }
1066 };
1067 }
1068
1069 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1070 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1071 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1072 if !current_prediction.was_shown {
1073 current_prediction.was_shown = true;
1074 self.shown_predictions
1075 .push_front(current_prediction.prediction.clone());
1076 if self.shown_predictions.len() > 50 {
1077 let completion = self.shown_predictions.pop_back().unwrap();
1078 self.rated_predictions.remove(&completion.id);
1079 }
1080 }
1081 }
1082 }
1083 }
1084
1085 fn reject_prediction(
1086 &mut self,
1087 prediction_id: EditPredictionId,
1088 reason: EditPredictionRejectReason,
1089 was_shown: bool,
1090 ) {
1091 match self.edit_prediction_model {
1092 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1093 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1094 }
1095
1096 self.reject_predictions_tx
1097 .unbounded_send(EditPredictionRejection {
1098 request_id: prediction_id.to_string(),
1099 reason,
1100 was_shown,
1101 })
1102 .log_err();
1103 }
1104
1105 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1106 self.projects
1107 .get(&project.entity_id())
1108 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1109 }
1110
1111 pub fn refresh_prediction_from_buffer(
1112 &mut self,
1113 project: Entity<Project>,
1114 buffer: Entity<Buffer>,
1115 position: language::Anchor,
1116 cx: &mut Context<Self>,
1117 ) {
1118 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1119 let Some(request_task) = this
1120 .update(cx, |this, cx| {
1121 this.request_prediction(
1122 &project,
1123 &buffer,
1124 position,
1125 PredictEditsRequestTrigger::Other,
1126 cx,
1127 )
1128 })
1129 .log_err()
1130 else {
1131 return Task::ready(anyhow::Ok(None));
1132 };
1133
1134 cx.spawn(async move |_cx| {
1135 request_task.await.map(|prediction_result| {
1136 prediction_result.map(|prediction_result| {
1137 (
1138 prediction_result,
1139 PredictionRequestedBy::Buffer(buffer.entity_id()),
1140 )
1141 })
1142 })
1143 })
1144 })
1145 }
1146
1147 pub fn refresh_prediction_from_diagnostics(
1148 &mut self,
1149 project: Entity<Project>,
1150 cx: &mut Context<Self>,
1151 ) {
1152 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1153 return;
1154 };
1155
1156 // Prefer predictions from buffer
1157 if project_state.current_prediction.is_some() {
1158 return;
1159 };
1160
1161 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1162 let Some((active_buffer, snapshot, cursor_point)) = this
1163 .read_with(cx, |this, cx| {
1164 let project_state = this.projects.get(&project.entity_id())?;
1165 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1166 let snapshot = buffer.read(cx).snapshot();
1167
1168 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1169 return None;
1170 }
1171
1172 let cursor_point = position
1173 .map(|pos| pos.to_point(&snapshot))
1174 .unwrap_or_default();
1175
1176 Some((buffer, snapshot, cursor_point))
1177 })
1178 .log_err()
1179 .flatten()
1180 else {
1181 return Task::ready(anyhow::Ok(None));
1182 };
1183
1184 cx.spawn(async move |cx| {
1185 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1186 active_buffer,
1187 &snapshot,
1188 Default::default(),
1189 cursor_point,
1190 &project,
1191 cx,
1192 )
1193 .await?
1194 else {
1195 return anyhow::Ok(None);
1196 };
1197
1198 let Some(prediction_result) = this
1199 .update(cx, |this, cx| {
1200 this.request_prediction(
1201 &project,
1202 &jump_buffer,
1203 jump_position,
1204 PredictEditsRequestTrigger::Diagnostics,
1205 cx,
1206 )
1207 })?
1208 .await?
1209 else {
1210 return anyhow::Ok(None);
1211 };
1212
1213 this.update(cx, |this, cx| {
1214 Some((
1215 if this
1216 .get_or_init_project(&project, cx)
1217 .current_prediction
1218 .is_none()
1219 {
1220 prediction_result
1221 } else {
1222 EditPredictionResult {
1223 id: prediction_result.id,
1224 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1225 }
1226 },
1227 PredictionRequestedBy::DiagnosticsUpdate,
1228 ))
1229 })
1230 })
1231 });
1232 }
1233
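    /// Whether edit predictions are enabled at `position`, based on language
    /// settings, per-file settings, and `edit_predictions_disabled_in` scopes.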
1234 fn predictions_enabled_at(
1235 snapshot: &BufferSnapshot,
1236 position: Option<language::Anchor>,
1237 cx: &App,
1238 ) -> bool {
1239 let file = snapshot.file();
1240 let all_settings = all_language_settings(file, cx);
1241 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1242 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1243 {
1244 return false;
1245 }
1246
1247 if let Some(last_position) = position {
1248 let settings = snapshot.settings_at(last_position, cx);
1249
1250 if !settings.edit_predictions_disabled_in.is_empty()
1251 && let Some(scope) = snapshot.language_scope_at(last_position)
1252 && let Some(scope_name) = scope.override_name()
1253 && settings
1254 .edit_predictions_disabled_in
1255 .iter()
1256 .any(|s| s == scope_name)
1257 {
1258 return false;
1259 }
1260 }
1261
1262 true
1263 }
1264
1265 #[cfg(not(test))]
1266 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1267 #[cfg(test)]
1268 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1269
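    /// Enqueues a prediction refresh, throttled by `THROTTLE_TIMEOUT` for
    /// consecutive refreshes triggered by the same entity. At most two
    /// predictions are pending per project; queueing a third replaces and
    /// cancels the most recently queued one.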
1270 fn queue_prediction_refresh(
1271 &mut self,
1272 project: Entity<Project>,
1273 throttle_entity: EntityId,
1274 cx: &mut Context<Self>,
1275 do_refresh: impl FnOnce(
1276 WeakEntity<Self>,
1277 &mut AsyncApp,
1278 )
1279 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1280 + 'static,
1281 ) {
1282 let project_state = self.get_or_init_project(&project, cx);
1283 let pending_prediction_id = project_state.next_pending_prediction_id;
1284 project_state.next_pending_prediction_id += 1;
1285 let last_request = project_state.last_prediction_refresh;
1286
1287 let task = cx.spawn(async move |this, cx| {
1288 if let Some((last_entity, last_timestamp)) = last_request
1289 && throttle_entity == last_entity
1290 && let Some(timeout) =
1291 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1292 {
1293 cx.background_executor().timer(timeout).await;
1294 }
1295
1296 // If this task was cancelled before the throttle timeout expired,
1297 // do not perform a request.
1298 let mut is_cancelled = true;
1299 this.update(cx, |this, cx| {
1300 let project_state = this.get_or_init_project(&project, cx);
1301 if !project_state
1302 .cancelled_predictions
1303 .remove(&pending_prediction_id)
1304 {
1305 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1306 is_cancelled = false;
1307 }
1308 })
1309 .ok();
1310 if is_cancelled {
1311 return None;
1312 }
1313
1314 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1315 let new_prediction_id = new_prediction_result
1316 .as_ref()
1317 .map(|(prediction, _)| prediction.id.clone());
1318
1319 // When a prediction completes, remove it from the pending list, and cancel
1320 // any pending predictions that were enqueued before it.
1321 this.update(cx, |this, cx| {
1322 let project_state = this.get_or_init_project(&project, cx);
1323
1324 let is_cancelled = project_state
1325 .cancelled_predictions
1326 .remove(&pending_prediction_id);
1327
1328 let new_current_prediction = if !is_cancelled
1329 && let Some((prediction_result, requested_by)) = new_prediction_result
1330 {
1331 match prediction_result.prediction {
1332 Ok(prediction) => {
1333 let new_prediction = CurrentEditPrediction {
1334 requested_by,
1335 prediction,
1336 was_shown: false,
1337 };
1338
1339 if let Some(current_prediction) =
1340 project_state.current_prediction.as_ref()
1341 {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1343 {
1344 this.reject_current_prediction(
1345 EditPredictionRejectReason::Replaced,
1346 &project,
1347 );
1348
1349 Some(new_prediction)
1350 } else {
1351 this.reject_prediction(
1352 new_prediction.prediction.id,
1353 EditPredictionRejectReason::CurrentPreferred,
1354 false,
1355 );
1356 None
1357 }
1358 } else {
1359 Some(new_prediction)
1360 }
1361 }
1362 Err(reject_reason) => {
1363 this.reject_prediction(prediction_result.id, reject_reason, false);
1364 None
1365 }
1366 }
1367 } else {
1368 None
1369 };
1370
1371 let project_state = this.get_or_init_project(&project, cx);
1372
1373 if let Some(new_prediction) = new_current_prediction {
1374 project_state.current_prediction = Some(new_prediction);
1375 }
1376
1377 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1378 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1379 if pending_prediction.id == pending_prediction_id {
1380 pending_predictions.remove(ix);
1381 for pending_prediction in pending_predictions.drain(0..ix) {
1382 project_state.cancel_pending_prediction(pending_prediction, cx)
1383 }
1384 break;
1385 }
1386 }
1387 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1388 cx.notify();
1389 })
1390 .ok();
1391
1392 new_prediction_id
1393 });
1394
1395 if project_state.pending_predictions.len() <= 1 {
1396 project_state.pending_predictions.push(PendingPrediction {
1397 id: pending_prediction_id,
1398 task,
1399 });
1400 } else if project_state.pending_predictions.len() == 2 {
1401 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1402 project_state.pending_predictions.push(PendingPrediction {
1403 id: pending_prediction_id,
1404 task,
1405 });
1406 project_state.cancel_pending_prediction(pending_prediction, cx);
1407 }
1408 }
1409
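    /// Requests an edit prediction for `active_buffer` at `position`. When the
    /// model returns no prediction and recent edit events exist, a follow-up
    /// request may be made at the nearest relevant diagnostic location instead.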
1410 pub fn request_prediction(
1411 &mut self,
1412 project: &Entity<Project>,
1413 active_buffer: &Entity<Buffer>,
1414 position: language::Anchor,
1415 trigger: PredictEditsRequestTrigger,
1416 cx: &mut Context<Self>,
1417 ) -> Task<Result<Option<EditPredictionResult>>> {
1418 self.request_prediction_internal(
1419 project.clone(),
1420 active_buffer.clone(),
1421 position,
1422 trigger,
1423 cx.has_flag::<Zeta2FeatureFlag>(),
1424 cx,
1425 )
1426 }
1427
1428 fn request_prediction_internal(
1429 &mut self,
1430 project: Entity<Project>,
1431 active_buffer: Entity<Buffer>,
1432 position: language::Anchor,
1433 trigger: PredictEditsRequestTrigger,
1434 allow_jump: bool,
1435 cx: &mut Context<Self>,
1436 ) -> Task<Result<Option<EditPredictionResult>>> {
1437 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1438
1439 self.get_or_init_project(&project, cx);
1440 let project_state = self.projects.get(&project.entity_id()).unwrap();
1441 let events = project_state.events(cx);
1442 let has_events = !events.is_empty();
1443 let debug_tx = project_state.debug_tx.clone();
1444
1445 let snapshot = active_buffer.read(cx).snapshot();
1446 let cursor_point = position.to_point(&snapshot);
1447 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1448 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1449 let diagnostic_search_range =
1450 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1451
1452 let related_files = if self.use_context {
1453 self.context_for_project(&project, cx)
1454 } else {
1455 Vec::new().into()
1456 };
1457
1458 let inputs = EditPredictionModelInput {
1459 project: project.clone(),
1460 buffer: active_buffer.clone(),
1461 snapshot: snapshot.clone(),
1462 position,
1463 events,
1464 related_files,
1465 recent_paths: project_state.recent_paths.clone(),
1466 trigger,
1467 diagnostic_search_range: diagnostic_search_range.clone(),
1468 debug_tx,
1469 };
1470
1471 let task = match self.edit_prediction_model {
1472 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1473 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1474 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1475 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1476 };
1477
1478 cx.spawn(async move |this, cx| {
1479 let prediction = task.await?;
1480
1481 if prediction.is_none() && allow_jump {
1482 let cursor_point = position.to_point(&snapshot);
1483 if has_events
1484 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1485 active_buffer.clone(),
1486 &snapshot,
1487 diagnostic_search_range,
1488 cursor_point,
1489 &project,
1490 cx,
1491 )
1492 .await?
1493 {
1494 return this
1495 .update(cx, |this, cx| {
1496 this.request_prediction_internal(
1497 project,
1498 jump_buffer,
1499 jump_position,
1500 trigger,
1501 false,
1502 cx,
1503 )
1504 })?
1505 .await;
1506 }
1507
1508 return anyhow::Ok(None);
1509 }
1510
1511 Ok(prediction)
1512 })
1513 }
1514
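    /// Finds the next diagnostic to jump to: the nearest diagnostic in the
    /// active buffer outside the range already covered by the last request,
    /// or, failing that, the most severe diagnostic in another buffer whose
    /// path shares the most parent directories with the active buffer's path.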
1515 async fn next_diagnostic_location(
1516 active_buffer: Entity<Buffer>,
1517 active_buffer_snapshot: &BufferSnapshot,
1518 active_buffer_diagnostic_search_range: Range<Point>,
1519 active_buffer_cursor_point: Point,
1520 project: &Entity<Project>,
1521 cx: &mut AsyncApp,
1522 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1523 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1524 let mut jump_location = active_buffer_snapshot
1525 .diagnostic_groups(None)
1526 .into_iter()
1527 .filter_map(|(_, group)| {
1528 let range = &group.entries[group.primary_ix]
1529 .range
1530 .to_point(&active_buffer_snapshot);
1531 if range.overlaps(&active_buffer_diagnostic_search_range) {
1532 None
1533 } else {
1534 Some(range.start)
1535 }
1536 })
1537 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1538 .map(|position| {
1539 (
1540 active_buffer.clone(),
1541 active_buffer_snapshot.anchor_before(position),
1542 )
1543 });
1544
1545 if jump_location.is_none() {
1546 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1547 let file = buffer.file()?;
1548
1549 Some(ProjectPath {
1550 worktree_id: file.worktree_id(cx),
1551 path: file.path().clone(),
1552 })
1553 })?;
1554
1555 let buffer_task = project.update(cx, |project, cx| {
1556 let (path, _, _) = project
1557 .diagnostic_summaries(false, cx)
1558 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1559 .max_by_key(|(path, _, _)| {
                        // Find the buffer with diagnostics whose path shares the most parent directories
1561 path.path
1562 .components()
1563 .zip(
1564 active_buffer_path
1565 .as_ref()
1566 .map(|p| p.path.components())
1567 .unwrap_or_default(),
1568 )
1569 .take_while(|(a, b)| a == b)
1570 .count()
1571 })?;
1572
1573 Some(project.open_buffer(path, cx))
1574 })?;
1575
1576 if let Some(buffer_task) = buffer_task {
1577 let closest_buffer = buffer_task.await?;
1578
1579 jump_location = closest_buffer
1580 .read_with(cx, |buffer, _cx| {
1581 buffer
1582 .buffer_diagnostics(None)
1583 .into_iter()
1584 .min_by_key(|entry| entry.diagnostic.severity)
1585 .map(|entry| entry.range.start)
1586 })?
1587 .map(|position| (closest_buffer, position));
1588 }
1589 }
1590
1591 anyhow::Ok(jump_location)
1592 }
1593
1594 async fn send_raw_llm_request(
1595 request: open_ai::Request,
1596 client: Arc<Client>,
1597 llm_token: LlmApiToken,
1598 app_version: Version,
1599 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1600 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1601 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1602 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1603 http_client::Url::parse(&predict_edits_url)?
1604 } else {
1605 client
1606 .http_client()
1607 .build_zed_llm_url("/predict_edits/raw", &[])?
1608 };
1609
1610 #[cfg(feature = "cli-support")]
1611 let cache_key = if let Some(cache) = eval_cache {
1612 use collections::FxHasher;
1613 use std::hash::{Hash, Hasher};
1614
1615 let mut hasher = FxHasher::default();
1616 url.hash(&mut hasher);
1617 let request_str = serde_json::to_string_pretty(&request)?;
1618 request_str.hash(&mut hasher);
1619 let hash = hasher.finish();
1620
1621 let key = (eval_cache_kind, hash);
1622 if let Some(response_str) = cache.read(key) {
1623 return Ok((serde_json::from_str(&response_str)?, None));
1624 }
1625
1626 Some((cache, request_str, key))
1627 } else {
1628 None
1629 };
1630
1631 let (response, usage) = Self::send_api_request(
1632 |builder| {
1633 let req = builder
1634 .uri(url.as_ref())
1635 .body(serde_json::to_string(&request)?.into());
1636 Ok(req?)
1637 },
1638 client,
1639 llm_token,
1640 app_version,
1641 )
1642 .await?;
1643
1644 #[cfg(feature = "cli-support")]
1645 if let Some((cache, request, key)) = cache_key {
1646 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1647 }
1648
1649 Ok((response, usage))
1650 }
1651
1652 fn handle_api_response<T>(
1653 this: &WeakEntity<Self>,
1654 response: Result<(T, Option<EditPredictionUsage>)>,
1655 cx: &mut gpui::AsyncApp,
1656 ) -> Result<T> {
1657 match response {
1658 Ok((data, usage)) => {
1659 if let Some(usage) = usage {
1660 this.update(cx, |this, cx| {
1661 this.user_store.update(cx, |user_store, cx| {
1662 user_store.update_edit_prediction_usage(usage, cx);
1663 });
1664 })
1665 .ok();
1666 }
1667 Ok(data)
1668 }
1669 Err(err) => {
1670 if err.is::<ZedUpdateRequiredError>() {
1671 cx.update(|cx| {
1672 this.update(cx, |this, _cx| {
1673 this.update_required = true;
1674 })
1675 .ok();
1676
1677 let error_message: SharedString = err.to_string().into();
1678 show_app_notification(
1679 NotificationId::unique::<ZedUpdateRequiredError>(),
1680 cx,
1681 move |cx| {
1682 cx.new(|cx| {
1683 ErrorMessagePrompt::new(error_message.clone(), cx)
1684 .with_link_button("Update Zed", "https://zed.dev/releases")
1685 })
1686 },
1687 );
1688 })
1689 .ok();
1690 }
1691 Err(err)
1692 }
1693 }
1694 }
1695
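    /// Sends an authenticated POST request to the LLM backend, retrying once
    /// with a refreshed token when the current one has expired, and failing
    /// with [`ZedUpdateRequiredError`] when the server reports a newer minimum
    /// required Zed version.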
1696 async fn send_api_request<Res>(
1697 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1698 client: Arc<Client>,
1699 llm_token: LlmApiToken,
1700 app_version: Version,
1701 ) -> Result<(Res, Option<EditPredictionUsage>)>
1702 where
1703 Res: DeserializeOwned,
1704 {
1705 let http_client = client.http_client();
1706 let mut token = llm_token.acquire(&client).await?;
1707 let mut did_retry = false;
1708
1709 loop {
1710 let request_builder = http_client::Request::builder().method(Method::POST);
1711
1712 let request = build(
1713 request_builder
1714 .header("Content-Type", "application/json")
1715 .header("Authorization", format!("Bearer {}", token))
1716 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1717 )?;
1718
1719 let mut response = http_client.send(request).await?;
1720
1721 if let Some(minimum_required_version) = response
1722 .headers()
1723 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1724 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1725 {
1726 anyhow::ensure!(
1727 app_version >= minimum_required_version,
1728 ZedUpdateRequiredError {
1729 minimum_version: minimum_required_version
1730 }
1731 );
1732 }
1733
1734 if response.status().is_success() {
1735 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1736
1737 let mut body = Vec::new();
1738 response.body_mut().read_to_end(&mut body).await?;
1739 return Ok((serde_json::from_slice(&body)?, usage));
1740 } else if !did_retry
1741 && response
1742 .headers()
1743 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1744 .is_some()
1745 {
1746 did_retry = true;
1747 token = llm_token.refresh(&client).await?;
1748 } else {
1749 let mut body = String::new();
1750 response.body_mut().read_to_string(&mut body).await?;
1751 anyhow::bail!(
1752 "Request failed with status: {:?}\nBody: {}",
1753 response.status(),
1754 body
1755 );
1756 }
1757 }
1758 }
1759
1760 pub fn refresh_context(
1761 &mut self,
1762 project: &Entity<Project>,
1763 buffer: &Entity<language::Buffer>,
1764 cursor_position: language::Anchor,
1765 cx: &mut Context<Self>,
1766 ) {
1767 if self.use_context {
1768 self.get_or_init_project(project, cx)
1769 .context
1770 .update(cx, |store, cx| {
1771 store.refresh(buffer.clone(), cursor_position, cx);
1772 });
1773 }
1774 }
1775
1776 #[cfg(feature = "cli-support")]
1777 pub fn set_context_for_buffer(
1778 &mut self,
1779 project: &Entity<Project>,
1780 related_files: Vec<RelatedFile>,
1781 cx: &mut Context<Self>,
1782 ) {
1783 self.get_or_init_project(project, cx)
1784 .context
1785 .update(cx, |store, _| {
1786 store.set_related_files(related_files);
1787 });
1788 }
1789
1790 fn is_file_open_source(
1791 &self,
1792 project: &Entity<Project>,
1793 file: &Arc<dyn File>,
1794 cx: &App,
1795 ) -> bool {
1796 if !file.is_local() || file.is_private() {
1797 return false;
1798 }
1799 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1800 return false;
1801 };
1802 project_state
1803 .license_detection_watchers
1804 .get(&file.worktree_id(cx))
1805 .as_ref()
1806 .is_some_and(|watcher| watcher.is_project_open_source())
1807 }
1808
1809 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1810 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1811 }
1812
1813 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1814 if !self.data_collection_choice.is_enabled() {
1815 return false;
1816 }
1817 events.iter().all(|event| {
1818 matches!(
1819 event.as_ref(),
1820 zeta_prompt::Event::BufferChange {
1821 in_open_source_repo: true,
1822 ..
1823 }
1824 )
1825 })
1826 }
1827
1828 fn load_data_collection_choice() -> DataCollectionChoice {
1829 let choice = KEY_VALUE_STORE
1830 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1831 .log_err()
1832 .flatten();
1833
1834 match choice.as_deref() {
1835 Some("true") => DataCollectionChoice::Enabled,
1836 Some("false") => DataCollectionChoice::Disabled,
1837 Some(_) => {
1838 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1839 DataCollectionChoice::NotAnswered
1840 }
1841 None => DataCollectionChoice::NotAnswered,
1842 }
1843 }
1844
1845 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1846 self.data_collection_choice = self.data_collection_choice.toggle();
1847 let new_choice = self.data_collection_choice;
1848 db::write_and_log(cx, move || {
1849 KEY_VALUE_STORE.write_kvp(
1850 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1851 new_choice.is_enabled().to_string(),
1852 )
1853 });
1854 }
1855
1856 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1857 self.shown_predictions.iter()
1858 }
1859
1860 pub fn shown_completions_len(&self) -> usize {
1861 self.shown_predictions.len()
1862 }
1863
1864 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1865 self.rated_predictions.contains(id)
1866 }
1867
1868 pub fn rate_prediction(
1869 &mut self,
1870 prediction: &EditPrediction,
1871 rating: EditPredictionRating,
1872 feedback: String,
1873 cx: &mut Context<Self>,
1874 ) {
1875 self.rated_predictions.insert(prediction.id.clone());
1876 telemetry::event!(
1877 "Edit Prediction Rated",
1878 rating,
1879 inputs = prediction.inputs,
1880 output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1881 feedback
1882 );
1883 self.client.telemetry().flush_events().detach();
1884 cx.notify();
1885 }
1886
1887 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1888 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1889 && all_language_settings(None, cx).edit_predictions.use_context;
1890 }
1891}
1892
1893#[derive(Error, Debug)]
1894#[error(
1895 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1896)]
1897pub struct ZedUpdateRequiredError {
1898 minimum_version: Version,
1899}
1900
1901#[cfg(feature = "cli-support")]
1902pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1903
1904#[cfg(feature = "cli-support")]
1905#[derive(Debug, Clone, Copy, PartialEq)]
1906pub enum EvalCacheEntryKind {
1907 Context,
1908 Search,
1909 Prediction,
1910}
1911
1912#[cfg(feature = "cli-support")]
1913impl std::fmt::Display for EvalCacheEntryKind {
1914 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1915 match self {
1916 EvalCacheEntryKind::Search => write!(f, "search"),
1917 EvalCacheEntryKind::Context => write!(f, "context"),
1918 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
1919 }
1920 }
1921}
1922
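/// Cache used by eval runs to record and replay raw model requests, keyed by
/// the kind of request and a hash of its URL and body.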
1923#[cfg(feature = "cli-support")]
1924pub trait EvalCache: Send + Sync {
1925 fn read(&self, key: EvalCacheKey) -> Option<String>;
1926 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
1927}
1928
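/// The user's choice about whether edit prediction data may be collected.
/// Data is only collected for files detected as belonging to open source
/// projects.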
1929#[derive(Debug, Clone, Copy)]
1930pub enum DataCollectionChoice {
1931 NotAnswered,
1932 Enabled,
1933 Disabled,
1934}
1935
1936impl DataCollectionChoice {
1937 pub fn is_enabled(self) -> bool {
1938 match self {
1939 Self::Enabled => true,
1940 Self::NotAnswered | Self::Disabled => false,
1941 }
1942 }
1943
1944 pub fn is_answered(self) -> bool {
1945 match self {
1946 Self::Enabled | Self::Disabled => true,
1947 Self::NotAnswered => false,
1948 }
1949 }
1950
1951 #[must_use]
1952 pub fn toggle(&self) -> DataCollectionChoice {
1953 match self {
1954 Self::Enabled => Self::Disabled,
1955 Self::Disabled => Self::Enabled,
1956 Self::NotAnswered => Self::Enabled,
1957 }
1958 }
1959}
1960
1961impl From<bool> for DataCollectionChoice {
1962 fn from(value: bool) -> Self {
1963 match value {
1964 true => DataCollectionChoice::Enabled,
1965 false => DataCollectionChoice::Disabled,
1966 }
1967 }
1968}
1969
1970struct ZedPredictUpsell;
1971
1972impl Dismissable for ZedPredictUpsell {
1973 const KEY: &'static str = "dismissed-edit-predict-upsell";
1974
1975 fn dismissed() -> bool {
        // For backwards compatibility with older versions of Zed, treat the
        // upsell as dismissed if the user already went through the previous
        // Edit Prediction onboarding, indicated by the data collection choice
        // that was written to the database when they clicked "Accept and Enable".
1980 if KEY_VALUE_STORE
1981 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1982 .log_err()
1983 .is_some_and(|s| s.is_some())
1984 {
1985 return true;
1986 }
1987
1988 KEY_VALUE_STORE
1989 .read_kvp(Self::KEY)
1990 .log_err()
1991 .is_some_and(|s| s.is_some())
1992 }
1993}
1994
1995pub fn should_show_upsell_modal() -> bool {
1996 !ZedPredictUpsell::dismissed()
1997}
1998
1999pub fn init(cx: &mut App) {
2000 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2001 workspace.register_action(
2002 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2003 ZedPredictModal::toggle(
2004 workspace,
2005 workspace.user_store().clone(),
2006 workspace.client().clone(),
2007 window,
2008 cx,
2009 )
2010 },
2011 );
2012
2013 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2014 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2015 settings
2016 .project
2017 .all_languages
2018 .features
2019 .get_or_insert_default()
2020 .edit_prediction_provider = Some(EditPredictionProvider::None)
2021 });
2022 });
2023 })
2024 .detach();
2025}