1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use workspace::Workspace;
38
39use std::ops::Range;
40use std::path::Path;
41use std::rc::Rc;
42use std::str::FromStr as _;
43use std::sync::{Arc, LazyLock};
44use std::time::{Duration, Instant};
45use std::{env, mem};
46use thiserror::Error;
47use util::{RangeExt as _, ResultExt as _};
48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
49
50mod cursor_excerpt;
51mod license_detection;
52pub mod mercury;
53mod onboarding_modal;
54pub mod open_ai_response;
55mod prediction;
56pub mod sweep_ai;
57
58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
59pub mod udiff;
60
61mod zed_edit_prediction_delegate;
62pub mod zeta1;
63pub mod zeta2;
64
65#[cfg(test)]
66mod edit_prediction_tests;
67
68use crate::license_detection::LicenseDetectionWatcher;
69use crate::mercury::Mercury;
70use crate::onboarding_modal::ZedPredictModal;
71pub use crate::prediction::EditPrediction;
72pub use crate::prediction::EditPredictionId;
73use crate::prediction::EditPredictionResult;
74pub use crate::sweep_ai::SweepAi;
75pub use language_model::ApiKeyState;
76pub use telemetry_events::EditPredictionRating;
77pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
78
79actions!(
80 edit_prediction,
81 [
82 /// Resets the edit prediction onboarding state.
83 ResetOnboarding,
84 /// Clears the edit prediction history.
85 ClearHistory,
86 ]
87);
88
89/// Maximum number of events to track.
90const EVENT_COUNT_MAX: usize = 6;
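/// Edits within this many lines of the previous edit are grouped into a single history event.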
91const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
92const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
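/// How long to wait for additional rejections before flushing a batch to the server.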
93const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
94
95pub struct SweepFeatureFlag;
96
97impl FeatureFlag for SweepFeatureFlag {
98 const NAME: &str = "sweep-ai";
99}
100
101pub struct MercuryFeatureFlag;
102
103impl FeatureFlag for MercuryFeatureFlag {
104 const NAME: &str = "mercury";
105}
106
107pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
108 context: EditPredictionExcerptOptions {
109 max_bytes: 512,
110 min_bytes: 128,
111 target_before_cursor_over_total_bytes: 0.5,
112 },
113 prompt_format: PromptFormat::DEFAULT,
114};
115
116static USE_OLLAMA: LazyLock<bool> =
117 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
118
119static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
120 match env::var("ZED_ZETA2_MODEL").as_deref() {
121 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
122 Ok(model) => model,
123 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
124 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
125 }
126 .to_string()
127});
128static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
129 env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
130 if *USE_OLLAMA {
131 Some("http://localhost:11434/v1/chat/completions".into())
132 } else {
133 None
134 }
135 })
136});
137
138pub struct Zeta2FeatureFlag;
139
140impl FeatureFlag for Zeta2FeatureFlag {
141 const NAME: &'static str = "zeta2";
142
143 fn enabled_for_staff() -> bool {
144 true
145 }
146}
147
148#[derive(Clone)]
149struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
150
151impl Global for EditPredictionStoreGlobal {}
152
153pub struct EditPredictionStore {
154 client: Arc<Client>,
155 user_store: Entity<UserStore>,
156 llm_token: LlmApiToken,
157 _llm_token_subscription: Subscription,
158 projects: HashMap<EntityId, ProjectState>,
159 use_context: bool,
160 options: ZetaOptions,
161 update_required: bool,
162 #[cfg(feature = "cli-support")]
163 eval_cache: Option<Arc<dyn EvalCache>>,
164 edit_prediction_model: EditPredictionModel,
165 pub sweep_ai: SweepAi,
166 pub mercury: Mercury,
167 data_collection_choice: DataCollectionChoice,
168 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
169 shown_predictions: VecDeque<EditPrediction>,
170 rated_predictions: HashSet<EditPredictionId>,
171}
172
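/// Which backend produces edit predictions.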
173#[derive(Copy, Clone, Default, PartialEq, Eq)]
174pub enum EditPredictionModel {
175 #[default]
176 Zeta1,
177 Zeta2,
178 Sweep,
179 Mercury,
180}
181
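/// The inputs handed to a prediction backend for a single request.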
182pub struct EditPredictionModelInput {
183 project: Entity<Project>,
184 buffer: Entity<Buffer>,
185 snapshot: BufferSnapshot,
186 position: Anchor,
187 events: Vec<Arc<zeta_prompt::Event>>,
188 related_files: Arc<[RelatedFile]>,
189 recent_paths: VecDeque<ProjectPath>,
190 trigger: PredictEditsRequestTrigger,
191 diagnostic_search_range: Range<Point>,
192 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
193}
194
195#[derive(Debug, Clone, PartialEq)]
196pub struct ZetaOptions {
197 pub context: EditPredictionExcerptOptions,
198 pub prompt_format: predict_edits_v3::PromptFormat,
199}
200
201#[derive(Debug)]
202pub enum DebugEvent {
203 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
204 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
205 EditPredictionStarted(EditPredictionStartedDebugEvent),
206 EditPredictionFinished(EditPredictionFinishedDebugEvent),
207}
208
209#[derive(Debug)]
210pub struct ContextRetrievalStartedDebugEvent {
211 pub project_entity_id: EntityId,
212 pub timestamp: Instant,
213 pub search_prompt: String,
214}
215
216#[derive(Debug)]
217pub struct ContextRetrievalFinishedDebugEvent {
218 pub project_entity_id: EntityId,
219 pub timestamp: Instant,
220 pub metadata: Vec<(&'static str, SharedString)>,
221}
222
223#[derive(Debug)]
224pub struct EditPredictionStartedDebugEvent {
225 pub buffer: WeakEntity<Buffer>,
226 pub position: Anchor,
227 pub prompt: Option<String>,
228}
229
230#[derive(Debug)]
231pub struct EditPredictionFinishedDebugEvent {
232 pub buffer: WeakEntity<Buffer>,
233 pub position: Anchor,
234 pub model_output: Option<String>,
235}
236
237pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
238
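/// Per-project state: recent edit events, registered buffers, license detection watchers, and
/// the current and pending predictions.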
239struct ProjectState {
240 events: VecDeque<Arc<zeta_prompt::Event>>,
241 last_event: Option<LastEvent>,
242 recent_paths: VecDeque<ProjectPath>,
243 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
244 current_prediction: Option<CurrentEditPrediction>,
245 next_pending_prediction_id: usize,
246 pending_predictions: ArrayVec<PendingPrediction, 2>,
247 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
248 last_prediction_refresh: Option<(EntityId, Instant)>,
249 cancelled_predictions: HashSet<usize>,
250 context: Entity<RelatedExcerptStore>,
251 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
252 _subscription: gpui::Subscription,
253}
254
255impl ProjectState {
256 pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
257 self.events
258 .iter()
259 .cloned()
260 .chain(
261 self.last_event
262 .as_ref()
263 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
264 )
265 .collect()
266 }
267
268 fn cancel_pending_prediction(
269 &mut self,
270 pending_prediction: PendingPrediction,
271 cx: &mut Context<EditPredictionStore>,
272 ) {
273 self.cancelled_predictions.insert(pending_prediction.id);
274
275 cx.spawn(async move |this, cx| {
276 let Some(prediction_id) = pending_prediction.task.await else {
277 return;
278 };
279
280 this.update(cx, |this, _cx| {
281 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
282 })
283 .ok();
284 })
285 .detach()
286 }
287
288 fn active_buffer(
289 &self,
290 project: &Entity<Project>,
291 cx: &App,
292 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
293 let project = project.read(cx);
294 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
295 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
296 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
297 Some((active_buffer, registered_buffer.last_position))
298 }
299}
300
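/// The prediction currently held for a project, along with what requested it and whether it
/// has been shown to the user.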
301#[derive(Debug, Clone)]
302struct CurrentEditPrediction {
303 pub requested_by: PredictionRequestedBy,
304 pub prediction: EditPrediction,
305 pub was_shown: bool,
306}
307
308impl CurrentEditPrediction {
309 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
310 let Some(new_edits) = self
311 .prediction
312 .interpolate(&self.prediction.buffer.read(cx))
313 else {
314 return false;
315 };
316
317 if self.prediction.buffer != old_prediction.prediction.buffer {
318 return true;
319 }
320
321 let Some(old_edits) = old_prediction
322 .prediction
323 .interpolate(&old_prediction.prediction.buffer.read(cx))
324 else {
325 return true;
326 };
327
328 let requested_by_buffer_id = self.requested_by.buffer_id();
329
330 // This reduces the occurrence of UI thrash from replacing edits
331 //
332 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
333 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
334 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
335 && old_edits.len() == 1
336 && new_edits.len() == 1
337 {
338 let (old_range, old_text) = &old_edits[0];
339 let (new_range, new_text) = &new_edits[0];
340 new_range == old_range && new_text.starts_with(old_text.as_ref())
341 } else {
342 true
343 }
344 }
345}
346
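/// What triggered a prediction request.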
347#[derive(Debug, Clone)]
348enum PredictionRequestedBy {
349 DiagnosticsUpdate,
350 Buffer(EntityId),
351}
352
353impl PredictionRequestedBy {
354 pub fn buffer_id(&self) -> Option<EntityId> {
355 match self {
356 PredictionRequestedBy::DiagnosticsUpdate => None,
357 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
358 }
359 }
360}
361
362#[derive(Debug)]
363struct PendingPrediction {
364 id: usize,
365 task: Task<Option<EditPredictionId>>,
366}
367
368/// A prediction from the perspective of a buffer.
369#[derive(Debug)]
370enum BufferEditPrediction<'a> {
371 Local { prediction: &'a EditPrediction },
372 Jump { prediction: &'a EditPrediction },
373}
374
375#[cfg(test)]
376impl std::ops::Deref for BufferEditPrediction<'_> {
377 type Target = EditPrediction;
378
379 fn deref(&self) -> &Self::Target {
380 match self {
381 BufferEditPrediction::Local { prediction } => prediction,
382 BufferEditPrediction::Jump { prediction } => prediction,
383 }
384 }
385}
386
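/// A buffer that reports edits to the store, along with its last known snapshot and the most
/// recently observed cursor position.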
387struct RegisteredBuffer {
388 snapshot: BufferSnapshot,
389 last_position: Option<Anchor>,
390 _subscriptions: [gpui::Subscription; 2],
391}
392
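/// The most recent buffer change, held back from the event list so that nearby follow-up edits
/// can be coalesced into it before it is finalized.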
393struct LastEvent {
394 old_snapshot: BufferSnapshot,
395 new_snapshot: BufferSnapshot,
396 end_edit_anchor: Option<Anchor>,
397}
398
399impl LastEvent {
400 pub fn finalize(
401 &self,
402 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
403 cx: &App,
404 ) -> Option<Arc<zeta_prompt::Event>> {
405 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
406 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
407
408 let file = self.new_snapshot.file();
409 let old_file = self.old_snapshot.file();
410
411 let in_open_source_repo = [file, old_file].iter().all(|file| {
412 file.is_some_and(|file| {
413 license_detection_watchers
414 .get(&file.worktree_id(cx))
415 .is_some_and(|watcher| watcher.is_project_open_source())
416 })
417 });
418
419 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
420
421 if path == old_path && diff.is_empty() {
422 None
423 } else {
424 Some(Arc::new(zeta_prompt::Event::BufferChange {
425 old_path,
426 path,
427 diff,
428 in_open_source_repo,
429 // TODO: Actually detect if this edit was predicted or not
430 predicted: false,
431 }))
432 }
433 }
434}
435
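/// Returns the buffer's full path, or a synthetic `untitled-<id>` path for buffers without a
/// file.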
436fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
437 if let Some(file) = snapshot.file() {
438 file.full_path(cx).into()
439 } else {
440 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
441 }
442}
443
444impl EditPredictionStore {
445 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
446 cx.try_global::<EditPredictionStoreGlobal>()
447 .map(|global| global.0.clone())
448 }
449
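    /// Returns the global [`EditPredictionStore`], creating and registering it on first use.
    ///
    /// Illustrative usage (a sketch; assumes a connected `client`, a `user_store`, and an open
    /// `project` are already in hand):
    ///
    /// ```ignore
    /// let store = EditPredictionStore::global(&client, &user_store, cx);
    /// store.update(cx, |store, cx| store.register_project(&project, cx));
    /// ```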
450 pub fn global(
451 client: &Arc<Client>,
452 user_store: &Entity<UserStore>,
453 cx: &mut App,
454 ) -> Entity<Self> {
455 cx.try_global::<EditPredictionStoreGlobal>()
456 .map(|global| global.0.clone())
457 .unwrap_or_else(|| {
458 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
459 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
460 ep_store
461 })
462 }
463
464 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
465 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
466 let data_collection_choice = Self::load_data_collection_choice();
467
468 let llm_token = LlmApiToken::default();
469
470 let (reject_tx, reject_rx) = mpsc::unbounded();
471 cx.background_spawn({
472 let client = client.clone();
473 let llm_token = llm_token.clone();
474 let app_version = AppVersion::global(cx);
475 let background_executor = cx.background_executor().clone();
476 async move {
477 Self::handle_rejected_predictions(
478 reject_rx,
479 client,
480 llm_token,
481 app_version,
482 background_executor,
483 )
484 .await
485 }
486 })
487 .detach();
488
489 let mut this = Self {
490 projects: HashMap::default(),
491 client,
492 user_store,
493 options: DEFAULT_OPTIONS,
494 use_context: false,
495 llm_token,
496 _llm_token_subscription: cx.subscribe(
497 &refresh_llm_token_listener,
498 |this, _listener, _event, cx| {
499 let client = this.client.clone();
500 let llm_token = this.llm_token.clone();
501 cx.spawn(async move |_this, _cx| {
502 llm_token.refresh(&client).await?;
503 anyhow::Ok(())
504 })
505 .detach_and_log_err(cx);
506 },
507 ),
508 update_required: false,
509 #[cfg(feature = "cli-support")]
510 eval_cache: None,
511 edit_prediction_model: EditPredictionModel::Zeta2,
512 sweep_ai: SweepAi::new(cx),
513 mercury: Mercury::new(cx),
514 data_collection_choice,
515 reject_predictions_tx: reject_tx,
516 rated_predictions: Default::default(),
517 shown_predictions: Default::default(),
518 };
519
520 this.configure_context_retrieval(cx);
521 let weak_this = cx.weak_entity();
522 cx.on_flags_ready(move |_, cx| {
523 weak_this
524 .update(cx, |this, cx| this.configure_context_retrieval(cx))
525 .ok();
526 })
527 .detach();
528 cx.observe_global::<SettingsStore>(|this, cx| {
529 this.configure_context_retrieval(cx);
530 })
531 .detach();
532
533 this
534 }
535
536 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
537 self.edit_prediction_model = model;
538 }
539
540 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
541 self.sweep_ai.api_token.read(cx).has_key()
542 }
543
544 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
545 self.mercury.api_token.read(cx).has_key()
546 }
547
548 #[cfg(feature = "cli-support")]
549 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
550 self.eval_cache = Some(cache);
551 }
552
553 pub fn options(&self) -> &ZetaOptions {
554 &self.options
555 }
556
557 pub fn set_options(&mut self, options: ZetaOptions) {
558 self.options = options;
559 }
560
561 pub fn set_use_context(&mut self, use_context: bool) {
562 self.use_context = use_context;
563 }
564
565 pub fn clear_history(&mut self) {
566 for project_state in self.projects.values_mut() {
567 project_state.events.clear();
568 }
569 }
570
571 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
572 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
573 project_state.events.clear();
574 }
575 }
576
577 pub fn edit_history_for_project(
578 &self,
579 project: &Entity<Project>,
580 cx: &App,
581 ) -> Vec<Arc<zeta_prompt::Event>> {
582 self.projects
583 .get(&project.entity_id())
584 .map(|project_state| project_state.events(cx))
585 .unwrap_or_default()
586 }
587
588 pub fn context_for_project<'a>(
589 &'a self,
590 project: &Entity<Project>,
591 cx: &'a App,
592 ) -> Arc<[RelatedFile]> {
593 self.projects
594 .get(&project.entity_id())
595 .map(|project| project.context.read(cx).related_files())
596 .unwrap_or_else(|| vec![].into())
597 }
598
599 pub fn context_for_project_with_buffers<'a>(
600 &'a self,
601 project: &Entity<Project>,
602 cx: &'a App,
603 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
604 self.projects
605 .get(&project.entity_id())
606 .map(|project| project.context.read(cx).related_files_with_buffers())
607 }
608
609 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
610 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
611 self.user_store.read(cx).edit_prediction_usage()
612 } else {
613 None
614 }
615 }
616
617 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
618 self.get_or_init_project(project, cx);
619 }
620
621 pub fn register_buffer(
622 &mut self,
623 buffer: &Entity<Buffer>,
624 project: &Entity<Project>,
625 cx: &mut Context<Self>,
626 ) {
627 let project_state = self.get_or_init_project(project, cx);
628 Self::register_buffer_impl(project_state, buffer, project, cx);
629 }
630
631 fn get_or_init_project(
632 &mut self,
633 project: &Entity<Project>,
634 cx: &mut Context<Self>,
635 ) -> &mut ProjectState {
636 let entity_id = project.entity_id();
637 self.projects
638 .entry(entity_id)
639 .or_insert_with(|| ProjectState {
640 context: {
641 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
642 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
643 this.handle_excerpt_store_event(entity_id, event);
644 })
645 .detach();
646 related_excerpt_store
647 },
648 events: VecDeque::new(),
649 last_event: None,
650 recent_paths: VecDeque::new(),
651 debug_tx: None,
652 registered_buffers: HashMap::default(),
653 current_prediction: None,
654 cancelled_predictions: HashSet::default(),
655 pending_predictions: ArrayVec::new(),
656 next_pending_prediction_id: 0,
657 last_prediction_refresh: None,
658 license_detection_watchers: HashMap::default(),
659 _subscription: cx.subscribe(&project, Self::handle_project_event),
660 })
661 }
662
663 pub fn remove_project(&mut self, project: &Entity<Project>) {
664 self.projects.remove(&project.entity_id());
665 }
666
667 fn handle_excerpt_store_event(
668 &mut self,
669 project_entity_id: EntityId,
670 event: &RelatedExcerptStoreEvent,
671 ) {
672 if let Some(project_state) = self.projects.get(&project_entity_id) {
673 if let Some(debug_tx) = project_state.debug_tx.clone() {
674 match event {
675 RelatedExcerptStoreEvent::StartedRefresh => {
676 debug_tx
677 .unbounded_send(DebugEvent::ContextRetrievalStarted(
678 ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
680 timestamp: Instant::now(),
681 search_prompt: String::new(),
682 },
683 ))
684 .ok();
685 }
686 RelatedExcerptStoreEvent::FinishedRefresh {
687 cache_hit_count,
688 cache_miss_count,
689 mean_definition_latency,
690 max_definition_latency,
691 } => {
692 debug_tx
693 .unbounded_send(DebugEvent::ContextRetrievalFinished(
694 ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
696 timestamp: Instant::now(),
697 metadata: vec![
698 (
699 "Cache Hits",
700 format!(
701 "{}/{}",
702 cache_hit_count,
703 cache_hit_count + cache_miss_count
704 )
705 .into(),
706 ),
707 (
708 "Max LSP Time",
709 format!("{} ms", max_definition_latency.as_millis())
710 .into(),
711 ),
712 (
713 "Mean LSP Time",
714 format!("{} ms", mean_definition_latency.as_millis())
715 .into(),
716 ),
717 ],
718 },
719 ))
720 .ok();
721 }
722 }
723 }
724 }
725 }
726
727 pub fn debug_info(
728 &mut self,
729 project: &Entity<Project>,
730 cx: &mut Context<Self>,
731 ) -> mpsc::UnboundedReceiver<DebugEvent> {
732 let project_state = self.get_or_init_project(project, cx);
733 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
734 project_state.debug_tx = Some(debug_watch_tx);
735 debug_watch_rx
736 }
737
738 fn handle_project_event(
739 &mut self,
740 project: Entity<Project>,
741 event: &project::Event,
742 cx: &mut Context<Self>,
743 ) {
744 // TODO [zeta2] init with recent paths
745 match event {
746 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
747 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
748 return;
749 };
750 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
751 if let Some(path) = path {
752 if let Some(ix) = project_state
753 .recent_paths
754 .iter()
755 .position(|probe| probe == &path)
756 {
757 project_state.recent_paths.remove(ix);
758 }
759 project_state.recent_paths.push_front(path);
760 }
761 }
762 project::Event::DiagnosticsUpdated { .. } => {
763 if cx.has_flag::<Zeta2FeatureFlag>() {
764 self.refresh_prediction_from_diagnostics(project, cx);
765 }
766 }
767 _ => (),
768 }
769 }
770
771 fn register_buffer_impl<'a>(
772 project_state: &'a mut ProjectState,
773 buffer: &Entity<Buffer>,
774 project: &Entity<Project>,
775 cx: &mut Context<Self>,
776 ) -> &'a mut RegisteredBuffer {
777 let buffer_id = buffer.entity_id();
778
779 if let Some(file) = buffer.read(cx).file() {
780 let worktree_id = file.worktree_id(cx);
781 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
782 project_state
783 .license_detection_watchers
784 .entry(worktree_id)
785 .or_insert_with(|| {
786 let project_entity_id = project.entity_id();
787 cx.observe_release(&worktree, move |this, _worktree, _cx| {
788 let Some(project_state) = this.projects.get_mut(&project_entity_id)
789 else {
790 return;
791 };
792 project_state
793 .license_detection_watchers
794 .remove(&worktree_id);
795 })
796 .detach();
797 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
798 });
799 }
800 }
801
802 match project_state.registered_buffers.entry(buffer_id) {
803 hash_map::Entry::Occupied(entry) => entry.into_mut(),
804 hash_map::Entry::Vacant(entry) => {
805 let snapshot = buffer.read(cx).snapshot();
806 let project_entity_id = project.entity_id();
807 entry.insert(RegisteredBuffer {
808 snapshot,
809 last_position: None,
810 _subscriptions: [
811 cx.subscribe(buffer, {
812 let project = project.downgrade();
813 move |this, buffer, event, cx| {
814 if let language::BufferEvent::Edited = event
815 && let Some(project) = project.upgrade()
816 {
817 this.report_changes_for_buffer(&buffer, &project, cx);
818 }
819 }
820 }),
821 cx.observe_release(buffer, move |this, _buffer, _cx| {
822 let Some(project_state) = this.projects.get_mut(&project_entity_id)
823 else {
824 return;
825 };
826 project_state.registered_buffers.remove(&buffer_id);
827 }),
828 ],
829 })
830 }
831 }
832 }
833
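    /// Records a buffer edit as a history event, coalescing it with the previous event when both
    /// touch the same buffer within [`CHANGE_GROUPING_LINE_SPAN`] lines. History is capped at
    /// [`EVENT_COUNT_MAX`] events.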
834 fn report_changes_for_buffer(
835 &mut self,
836 buffer: &Entity<Buffer>,
837 project: &Entity<Project>,
838 cx: &mut Context<Self>,
839 ) {
840 let project_state = self.get_or_init_project(project, cx);
841 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
842
843 let new_snapshot = buffer.read(cx).snapshot();
844 if new_snapshot.version == registered_buffer.snapshot.version {
845 return;
846 }
847
848 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
849 let end_edit_anchor = new_snapshot
850 .anchored_edits_since::<Point>(&old_snapshot.version)
851 .last()
852 .map(|(_, range)| range.end);
853 let events = &mut project_state.events;
854
855 if let Some(LastEvent {
856 new_snapshot: last_new_snapshot,
857 end_edit_anchor: last_end_edit_anchor,
858 ..
859 }) = project_state.last_event.as_mut()
860 {
861 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
862 == last_new_snapshot.remote_id()
863 && old_snapshot.version == last_new_snapshot.version;
864
865 let should_coalesce = is_next_snapshot_of_same_buffer
866 && end_edit_anchor
867 .as_ref()
868 .zip(last_end_edit_anchor.as_ref())
869 .is_some_and(|(a, b)| {
870 let a = a.to_point(&new_snapshot);
871 let b = b.to_point(&new_snapshot);
872 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
873 });
874
875 if should_coalesce {
876 *last_end_edit_anchor = end_edit_anchor;
877 *last_new_snapshot = new_snapshot;
878 return;
879 }
880 }
881
882 if events.len() + 1 >= EVENT_COUNT_MAX {
883 events.pop_front();
884 }
885
886 if let Some(event) = project_state.last_event.take() {
887 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
888 }
889
890 project_state.last_event = Some(LastEvent {
891 old_snapshot,
892 new_snapshot,
893 end_edit_anchor,
894 });
895 }
896
897 fn prediction_at(
898 &mut self,
899 buffer: &Entity<Buffer>,
900 position: Option<language::Anchor>,
901 project: &Entity<Project>,
902 cx: &App,
903 ) -> Option<BufferEditPrediction<'_>> {
904 let project_state = self.projects.get_mut(&project.entity_id())?;
905 if let Some(position) = position
906 && let Some(buffer) = project_state
907 .registered_buffers
908 .get_mut(&buffer.entity_id())
909 {
910 buffer.last_position = Some(position);
911 }
912
913 let CurrentEditPrediction {
914 requested_by,
915 prediction,
916 ..
917 } = project_state.current_prediction.as_ref()?;
918
919 if prediction.targets_buffer(buffer.read(cx)) {
920 Some(BufferEditPrediction::Local { prediction })
921 } else {
922 let show_jump = match requested_by {
923 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
924 requested_by_buffer_id == &buffer.entity_id()
925 }
926 PredictionRequestedBy::DiagnosticsUpdate => true,
927 };
928
929 if show_jump {
930 Some(BufferEditPrediction::Jump { prediction })
931 } else {
932 None
933 }
934 }
935 }
936
937 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
938 match self.edit_prediction_model {
939 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
940 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
941 }
942
943 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
944 return;
945 };
946
947 let Some(prediction) = project_state.current_prediction.take() else {
948 return;
949 };
950 let request_id = prediction.prediction.id.to_string();
951 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
952 project_state.cancel_pending_prediction(pending_prediction, cx);
953 }
954
955 let client = self.client.clone();
956 let llm_token = self.llm_token.clone();
957 let app_version = AppVersion::global(cx);
958 cx.spawn(async move |this, cx| {
959 let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
960 http_client::Url::parse(&predict_edits_url)?
961 } else {
962 client
963 .http_client()
964 .build_zed_llm_url("/predict_edits/accept", &[])?
965 };
966
967 let response = cx
968 .background_spawn(Self::send_api_request::<()>(
969 move |builder| {
970 let req = builder.uri(url.as_ref()).body(
971 serde_json::to_string(&AcceptEditPredictionBody {
972 request_id: request_id.clone(),
973 })?
974 .into(),
975 );
976 Ok(req?)
977 },
978 client,
979 llm_token,
980 app_version,
981 ))
982 .await;
983
984 Self::handle_api_response(&this, response, cx)?;
985 anyhow::Ok(())
986 })
987 .detach_and_log_err(cx);
988 }
989
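    /// Batches rejection reports and posts them to the `/predict_edits/reject` endpoint, waiting
    /// up to [`REJECT_REQUEST_DEBOUNCE`] for more entries and sending at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` per request. Entries from failed requests are
    /// kept and retried on the next flush.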
990 async fn handle_rejected_predictions(
991 rx: UnboundedReceiver<EditPredictionRejection>,
992 client: Arc<Client>,
993 llm_token: LlmApiToken,
994 app_version: Version,
995 background_executor: BackgroundExecutor,
996 ) {
997 let mut rx = std::pin::pin!(rx.peekable());
998 let mut batched = Vec::new();
999
1000 while let Some(rejection) = rx.next().await {
1001 batched.push(rejection);
1002
1003 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1004 select_biased! {
1005 next = rx.as_mut().peek().fuse() => {
1006 if next.is_some() {
1007 continue;
1008 }
1009 }
1010 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1011 }
1012 }
1013
1014 let url = client
1015 .http_client()
1016 .build_zed_llm_url("/predict_edits/reject", &[])
1017 .unwrap();
1018
1019 let flush_count = batched
1020 .len()
1021 // in case items have accumulated after failure
1022 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1023 let start = batched.len() - flush_count;
1024
1025 let body = RejectEditPredictionsBodyRef {
1026 rejections: &batched[start..],
1027 };
1028
1029 let result = Self::send_api_request::<()>(
1030 |builder| {
1031 let req = builder
1032 .uri(url.as_ref())
1033 .body(serde_json::to_string(&body)?.into());
1034 anyhow::Ok(req?)
1035 },
1036 client.clone(),
1037 llm_token.clone(),
1038 app_version.clone(),
1039 )
1040 .await;
1041
1042 if result.log_err().is_some() {
1043 batched.drain(start..);
1044 }
1045 }
1046 }
1047
1048 fn reject_current_prediction(
1049 &mut self,
1050 reason: EditPredictionRejectReason,
1051 project: &Entity<Project>,
1052 ) {
1053 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1054 project_state.pending_predictions.clear();
1055 if let Some(prediction) = project_state.current_prediction.take() {
1056 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1057 }
1058 };
1059 }
1060
1061 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1062 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1063 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1064 if !current_prediction.was_shown {
1065 current_prediction.was_shown = true;
1066 self.shown_predictions
1067 .push_front(current_prediction.prediction.clone());
1068 if self.shown_predictions.len() > 50 {
1069 let completion = self.shown_predictions.pop_back().unwrap();
1070 self.rated_predictions.remove(&completion.id);
1071 }
1072 }
1073 }
1074 }
1075 }
1076
1077 fn reject_prediction(
1078 &mut self,
1079 prediction_id: EditPredictionId,
1080 reason: EditPredictionRejectReason,
1081 was_shown: bool,
1082 ) {
1083 match self.edit_prediction_model {
1084 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1085 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1086 }
1087
1088 self.reject_predictions_tx
1089 .unbounded_send(EditPredictionRejection {
1090 request_id: prediction_id.to_string(),
1091 reason,
1092 was_shown,
1093 })
1094 .log_err();
1095 }
1096
1097 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1098 self.projects
1099 .get(&project.entity_id())
1100 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1101 }
1102
1103 pub fn refresh_prediction_from_buffer(
1104 &mut self,
1105 project: Entity<Project>,
1106 buffer: Entity<Buffer>,
1107 position: language::Anchor,
1108 cx: &mut Context<Self>,
1109 ) {
1110 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1111 let Some(request_task) = this
1112 .update(cx, |this, cx| {
1113 this.request_prediction(
1114 &project,
1115 &buffer,
1116 position,
1117 PredictEditsRequestTrigger::Other,
1118 cx,
1119 )
1120 })
1121 .log_err()
1122 else {
1123 return Task::ready(anyhow::Ok(None));
1124 };
1125
1126 cx.spawn(async move |_cx| {
1127 request_task.await.map(|prediction_result| {
1128 prediction_result.map(|prediction_result| {
1129 (
1130 prediction_result,
1131 PredictionRequestedBy::Buffer(buffer.entity_id()),
1132 )
1133 })
1134 })
1135 })
1136 })
1137 }
1138
1139 pub fn refresh_prediction_from_diagnostics(
1140 &mut self,
1141 project: Entity<Project>,
1142 cx: &mut Context<Self>,
1143 ) {
1144 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1145 return;
1146 };
1147
        // Prefer an existing prediction (requested from a buffer) over a diagnostics-triggered one
1149 if project_state.current_prediction.is_some() {
1150 return;
1151 };
1152
1153 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1154 let Some((active_buffer, snapshot, cursor_point)) = this
1155 .read_with(cx, |this, cx| {
1156 let project_state = this.projects.get(&project.entity_id())?;
1157 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1158 let snapshot = buffer.read(cx).snapshot();
1159
1160 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1161 return None;
1162 }
1163
1164 let cursor_point = position
1165 .map(|pos| pos.to_point(&snapshot))
1166 .unwrap_or_default();
1167
1168 Some((buffer, snapshot, cursor_point))
1169 })
1170 .log_err()
1171 .flatten()
1172 else {
1173 return Task::ready(anyhow::Ok(None));
1174 };
1175
1176 cx.spawn(async move |cx| {
1177 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1178 active_buffer,
1179 &snapshot,
1180 Default::default(),
1181 cursor_point,
1182 &project,
1183 cx,
1184 )
1185 .await?
1186 else {
1187 return anyhow::Ok(None);
1188 };
1189
1190 let Some(prediction_result) = this
1191 .update(cx, |this, cx| {
1192 this.request_prediction(
1193 &project,
1194 &jump_buffer,
1195 jump_position,
1196 PredictEditsRequestTrigger::Diagnostics,
1197 cx,
1198 )
1199 })?
1200 .await?
1201 else {
1202 return anyhow::Ok(None);
1203 };
1204
1205 this.update(cx, |this, cx| {
1206 Some((
1207 if this
1208 .get_or_init_project(&project, cx)
1209 .current_prediction
1210 .is_none()
1211 {
1212 prediction_result
1213 } else {
1214 EditPredictionResult {
1215 id: prediction_result.id,
1216 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1217 }
1218 },
1219 PredictionRequestedBy::DiagnosticsUpdate,
1220 ))
1221 })
1222 })
1223 });
1224 }
1225
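    /// Returns whether edit predictions are enabled for this buffer's file and language and, when
    /// a position is given, not disabled by a language scope at that position.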
1226 fn predictions_enabled_at(
1227 snapshot: &BufferSnapshot,
1228 position: Option<language::Anchor>,
1229 cx: &App,
1230 ) -> bool {
1231 let file = snapshot.file();
1232 let all_settings = all_language_settings(file, cx);
1233 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1234 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1235 {
1236 return false;
1237 }
1238
1239 if let Some(last_position) = position {
1240 let settings = snapshot.settings_at(last_position, cx);
1241
1242 if !settings.edit_predictions_disabled_in.is_empty()
1243 && let Some(scope) = snapshot.language_scope_at(last_position)
1244 && let Some(scope_name) = scope.override_name()
1245 && settings
1246 .edit_predictions_disabled_in
1247 .iter()
1248 .any(|s| s == scope_name)
1249 {
1250 return false;
1251 }
1252 }
1253
1254 true
1255 }
1256
1257 #[cfg(not(test))]
1258 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1259 #[cfg(test)]
1260 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1261
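    /// Enqueues a prediction refresh, throttling repeated requests for the same entity by
    /// [`Self::THROTTLE_TIMEOUT`] and keeping at most two pending requests per project; queuing a
    /// third replaces and cancels the most recently queued one.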
1262 fn queue_prediction_refresh(
1263 &mut self,
1264 project: Entity<Project>,
1265 throttle_entity: EntityId,
1266 cx: &mut Context<Self>,
1267 do_refresh: impl FnOnce(
1268 WeakEntity<Self>,
1269 &mut AsyncApp,
1270 )
1271 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1272 + 'static,
1273 ) {
1274 let project_state = self.get_or_init_project(&project, cx);
1275 let pending_prediction_id = project_state.next_pending_prediction_id;
1276 project_state.next_pending_prediction_id += 1;
1277 let last_request = project_state.last_prediction_refresh;
1278
1279 let task = cx.spawn(async move |this, cx| {
1280 if let Some((last_entity, last_timestamp)) = last_request
1281 && throttle_entity == last_entity
1282 && let Some(timeout) =
1283 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1284 {
1285 cx.background_executor().timer(timeout).await;
1286 }
1287
1288 // If this task was cancelled before the throttle timeout expired,
1289 // do not perform a request.
1290 let mut is_cancelled = true;
1291 this.update(cx, |this, cx| {
1292 let project_state = this.get_or_init_project(&project, cx);
1293 if !project_state
1294 .cancelled_predictions
1295 .remove(&pending_prediction_id)
1296 {
1297 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1298 is_cancelled = false;
1299 }
1300 })
1301 .ok();
1302 if is_cancelled {
1303 return None;
1304 }
1305
1306 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1307 let new_prediction_id = new_prediction_result
1308 .as_ref()
1309 .map(|(prediction, _)| prediction.id.clone());
1310
1311 // When a prediction completes, remove it from the pending list, and cancel
1312 // any pending predictions that were enqueued before it.
1313 this.update(cx, |this, cx| {
1314 let project_state = this.get_or_init_project(&project, cx);
1315
1316 let is_cancelled = project_state
1317 .cancelled_predictions
1318 .remove(&pending_prediction_id);
1319
1320 let new_current_prediction = if !is_cancelled
1321 && let Some((prediction_result, requested_by)) = new_prediction_result
1322 {
1323 match prediction_result.prediction {
1324 Ok(prediction) => {
1325 let new_prediction = CurrentEditPrediction {
1326 requested_by,
1327 prediction,
1328 was_shown: false,
1329 };
1330
1331 if let Some(current_prediction) =
1332 project_state.current_prediction.as_ref()
1333 {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1335 {
1336 this.reject_current_prediction(
1337 EditPredictionRejectReason::Replaced,
1338 &project,
1339 );
1340
1341 Some(new_prediction)
1342 } else {
1343 this.reject_prediction(
1344 new_prediction.prediction.id,
1345 EditPredictionRejectReason::CurrentPreferred,
1346 false,
1347 );
1348 None
1349 }
1350 } else {
1351 Some(new_prediction)
1352 }
1353 }
1354 Err(reject_reason) => {
1355 this.reject_prediction(prediction_result.id, reject_reason, false);
1356 None
1357 }
1358 }
1359 } else {
1360 None
1361 };
1362
1363 let project_state = this.get_or_init_project(&project, cx);
1364
1365 if let Some(new_prediction) = new_current_prediction {
1366 project_state.current_prediction = Some(new_prediction);
1367 }
1368
1369 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1370 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1371 if pending_prediction.id == pending_prediction_id {
1372 pending_predictions.remove(ix);
1373 for pending_prediction in pending_predictions.drain(0..ix) {
1374 project_state.cancel_pending_prediction(pending_prediction, cx)
1375 }
1376 break;
1377 }
1378 }
1379 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1380 cx.notify();
1381 })
1382 .ok();
1383
1384 new_prediction_id
1385 });
1386
1387 if project_state.pending_predictions.len() <= 1 {
1388 project_state.pending_predictions.push(PendingPrediction {
1389 id: pending_prediction_id,
1390 task,
1391 });
1392 } else if project_state.pending_predictions.len() == 2 {
1393 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1394 project_state.pending_predictions.push(PendingPrediction {
1395 id: pending_prediction_id,
1396 task,
1397 });
1398 project_state.cancel_pending_prediction(pending_prediction, cx);
1399 }
1400 }
1401
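    /// Requests an edit prediction at `position` using the currently selected model. When nothing
    /// is predicted, the project has edit history, and the Zeta2 flag is enabled, the request is
    /// retried once at the nearest diagnostic outside the searched range.
    ///
    /// Illustrative call site (a sketch; `store`, `project`, `buffer`, and `cursor` are assumed to
    /// be in scope):
    ///
    /// ```ignore
    /// let task = store.update(cx, |store, cx| {
    ///     store.request_prediction(&project, &buffer, cursor, PredictEditsRequestTrigger::Other, cx)
    /// });
    /// ```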
1402 pub fn request_prediction(
1403 &mut self,
1404 project: &Entity<Project>,
1405 active_buffer: &Entity<Buffer>,
1406 position: language::Anchor,
1407 trigger: PredictEditsRequestTrigger,
1408 cx: &mut Context<Self>,
1409 ) -> Task<Result<Option<EditPredictionResult>>> {
1410 self.request_prediction_internal(
1411 project.clone(),
1412 active_buffer.clone(),
1413 position,
1414 trigger,
1415 cx.has_flag::<Zeta2FeatureFlag>(),
1416 cx,
1417 )
1418 }
1419
1420 fn request_prediction_internal(
1421 &mut self,
1422 project: Entity<Project>,
1423 active_buffer: Entity<Buffer>,
1424 position: language::Anchor,
1425 trigger: PredictEditsRequestTrigger,
1426 allow_jump: bool,
1427 cx: &mut Context<Self>,
1428 ) -> Task<Result<Option<EditPredictionResult>>> {
1429 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1430
1431 self.get_or_init_project(&project, cx);
1432 let project_state = self.projects.get(&project.entity_id()).unwrap();
1433 let events = project_state.events(cx);
1434 let has_events = !events.is_empty();
1435 let debug_tx = project_state.debug_tx.clone();
1436
1437 let snapshot = active_buffer.read(cx).snapshot();
1438 let cursor_point = position.to_point(&snapshot);
1439 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1440 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1441 let diagnostic_search_range =
1442 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1443
1444 let related_files = if self.use_context {
1445 self.context_for_project(&project, cx)
1446 } else {
1447 Vec::new().into()
1448 };
1449
1450 let inputs = EditPredictionModelInput {
1451 project: project.clone(),
1452 buffer: active_buffer.clone(),
1453 snapshot: snapshot.clone(),
1454 position,
1455 events,
1456 related_files,
1457 recent_paths: project_state.recent_paths.clone(),
1458 trigger,
1459 diagnostic_search_range: diagnostic_search_range.clone(),
1460 debug_tx,
1461 };
1462
1463 let task = match self.edit_prediction_model {
1464 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1465 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1466 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1467 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1468 };
1469
1470 cx.spawn(async move |this, cx| {
1471 let prediction = task.await?;
1472
1473 if prediction.is_none() && allow_jump {
1474 let cursor_point = position.to_point(&snapshot);
1475 if has_events
1476 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1477 active_buffer.clone(),
1478 &snapshot,
1479 diagnostic_search_range,
1480 cursor_point,
1481 &project,
1482 cx,
1483 )
1484 .await?
1485 {
1486 return this
1487 .update(cx, |this, cx| {
1488 this.request_prediction_internal(
1489 project,
1490 jump_buffer,
1491 jump_position,
1492 trigger,
1493 false,
1494 cx,
1495 )
1496 })?
1497 .await;
1498 }
1499
1500 return anyhow::Ok(None);
1501 }
1502
1503 Ok(prediction)
1504 })
1505 }
1506
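    /// Finds the closest diagnostic to the cursor that falls outside the already-searched range;
    /// if the active buffer has none, falls back to the most severe diagnostic in the buffer whose
    /// path shares the most parent directories with the active buffer.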
1507 async fn next_diagnostic_location(
1508 active_buffer: Entity<Buffer>,
1509 active_buffer_snapshot: &BufferSnapshot,
1510 active_buffer_diagnostic_search_range: Range<Point>,
1511 active_buffer_cursor_point: Point,
1512 project: &Entity<Project>,
1513 cx: &mut AsyncApp,
1514 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1515 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1516 let mut jump_location = active_buffer_snapshot
1517 .diagnostic_groups(None)
1518 .into_iter()
1519 .filter_map(|(_, group)| {
1520 let range = &group.entries[group.primary_ix]
1521 .range
1522 .to_point(&active_buffer_snapshot);
1523 if range.overlaps(&active_buffer_diagnostic_search_range) {
1524 None
1525 } else {
1526 Some(range.start)
1527 }
1528 })
1529 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1530 .map(|position| {
1531 (
1532 active_buffer.clone(),
1533 active_buffer_snapshot.anchor_before(position),
1534 )
1535 });
1536
1537 if jump_location.is_none() {
1538 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1539 let file = buffer.file()?;
1540
1541 Some(ProjectPath {
1542 worktree_id: file.worktree_id(cx),
1543 path: file.path().clone(),
1544 })
1545 })?;
1546
1547 let buffer_task = project.update(cx, |project, cx| {
1548 let (path, _, _) = project
1549 .diagnostic_summaries(false, cx)
1550 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1551 .max_by_key(|(path, _, _)| {
                    // Find the buffer with diagnostics whose path shares the most parent directories
1553 path.path
1554 .components()
1555 .zip(
1556 active_buffer_path
1557 .as_ref()
1558 .map(|p| p.path.components())
1559 .unwrap_or_default(),
1560 )
1561 .take_while(|(a, b)| a == b)
1562 .count()
1563 })?;
1564
1565 Some(project.open_buffer(path, cx))
1566 })?;
1567
1568 if let Some(buffer_task) = buffer_task {
1569 let closest_buffer = buffer_task.await?;
1570
1571 jump_location = closest_buffer
1572 .read_with(cx, |buffer, _cx| {
1573 buffer
1574 .buffer_diagnostics(None)
1575 .into_iter()
1576 .min_by_key(|entry| entry.diagnostic.severity)
1577 .map(|entry| entry.range.start)
1578 })?
1579 .map(|position| (closest_buffer, position));
1580 }
1581 }
1582
1583 anyhow::Ok(jump_location)
1584 }
1585
1586 async fn send_raw_llm_request(
1587 request: open_ai::Request,
1588 client: Arc<Client>,
1589 llm_token: LlmApiToken,
1590 app_version: Version,
1591 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1592 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1593 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1594 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1595 http_client::Url::parse(&predict_edits_url)?
1596 } else {
1597 client
1598 .http_client()
1599 .build_zed_llm_url("/predict_edits/raw", &[])?
1600 };
1601
1602 #[cfg(feature = "cli-support")]
1603 let cache_key = if let Some(cache) = eval_cache {
1604 use collections::FxHasher;
1605 use std::hash::{Hash, Hasher};
1606
1607 let mut hasher = FxHasher::default();
1608 url.hash(&mut hasher);
1609 let request_str = serde_json::to_string_pretty(&request)?;
1610 request_str.hash(&mut hasher);
1611 let hash = hasher.finish();
1612
1613 let key = (eval_cache_kind, hash);
1614 if let Some(response_str) = cache.read(key) {
1615 return Ok((serde_json::from_str(&response_str)?, None));
1616 }
1617
1618 Some((cache, request_str, key))
1619 } else {
1620 None
1621 };
1622
1623 let (response, usage) = Self::send_api_request(
1624 |builder| {
1625 let req = builder
1626 .uri(url.as_ref())
1627 .body(serde_json::to_string(&request)?.into());
1628 Ok(req?)
1629 },
1630 client,
1631 llm_token,
1632 app_version,
1633 )
1634 .await?;
1635
1636 #[cfg(feature = "cli-support")]
1637 if let Some((cache, request, key)) = cache_key {
1638 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1639 }
1640
1641 Ok((response, usage))
1642 }
1643
1644 fn handle_api_response<T>(
1645 this: &WeakEntity<Self>,
1646 response: Result<(T, Option<EditPredictionUsage>)>,
1647 cx: &mut gpui::AsyncApp,
1648 ) -> Result<T> {
1649 match response {
1650 Ok((data, usage)) => {
1651 if let Some(usage) = usage {
1652 this.update(cx, |this, cx| {
1653 this.user_store.update(cx, |user_store, cx| {
1654 user_store.update_edit_prediction_usage(usage, cx);
1655 });
1656 })
1657 .ok();
1658 }
1659 Ok(data)
1660 }
1661 Err(err) => {
1662 if err.is::<ZedUpdateRequiredError>() {
1663 cx.update(|cx| {
1664 this.update(cx, |this, _cx| {
1665 this.update_required = true;
1666 })
1667 .ok();
1668
1669 let error_message: SharedString = err.to_string().into();
1670 show_app_notification(
1671 NotificationId::unique::<ZedUpdateRequiredError>(),
1672 cx,
1673 move |cx| {
1674 cx.new(|cx| {
1675 ErrorMessagePrompt::new(error_message.clone(), cx)
1676 .with_link_button("Update Zed", "https://zed.dev/releases")
1677 })
1678 },
1679 );
1680 })
1681 .ok();
1682 }
1683 Err(err)
1684 }
1685 }
1686 }
1687
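    /// Sends an authenticated request to the LLM endpoint, refreshing the token once if it has
    /// expired and surfacing a [`ZedUpdateRequiredError`] when the server requires a newer Zed
    /// version.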
1688 async fn send_api_request<Res>(
1689 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1690 client: Arc<Client>,
1691 llm_token: LlmApiToken,
1692 app_version: Version,
1693 ) -> Result<(Res, Option<EditPredictionUsage>)>
1694 where
1695 Res: DeserializeOwned,
1696 {
1697 let http_client = client.http_client();
1698 let mut token = llm_token.acquire(&client).await?;
1699 let mut did_retry = false;
1700
1701 loop {
1702 let request_builder = http_client::Request::builder().method(Method::POST);
1703
1704 let request = build(
1705 request_builder
1706 .header("Content-Type", "application/json")
1707 .header("Authorization", format!("Bearer {}", token))
1708 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1709 )?;
1710
1711 let mut response = http_client.send(request).await?;
1712
1713 if let Some(minimum_required_version) = response
1714 .headers()
1715 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1716 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1717 {
1718 anyhow::ensure!(
1719 app_version >= minimum_required_version,
1720 ZedUpdateRequiredError {
1721 minimum_version: minimum_required_version
1722 }
1723 );
1724 }
1725
1726 if response.status().is_success() {
1727 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1728
1729 let mut body = Vec::new();
1730 response.body_mut().read_to_end(&mut body).await?;
1731 return Ok((serde_json::from_slice(&body)?, usage));
1732 } else if !did_retry
1733 && response
1734 .headers()
1735 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1736 .is_some()
1737 {
1738 did_retry = true;
1739 token = llm_token.refresh(&client).await?;
1740 } else {
1741 let mut body = String::new();
1742 response.body_mut().read_to_string(&mut body).await?;
1743 anyhow::bail!(
1744 "Request failed with status: {:?}\nBody: {}",
1745 response.status(),
1746 body
1747 );
1748 }
1749 }
1750 }
1751
1752 pub fn refresh_context(
1753 &mut self,
1754 project: &Entity<Project>,
1755 buffer: &Entity<language::Buffer>,
1756 cursor_position: language::Anchor,
1757 cx: &mut Context<Self>,
1758 ) {
1759 if self.use_context {
1760 self.get_or_init_project(project, cx)
1761 .context
1762 .update(cx, |store, cx| {
1763 store.refresh(buffer.clone(), cursor_position, cx);
1764 });
1765 }
1766 }
1767
1768 #[cfg(feature = "cli-support")]
1769 pub fn set_context_for_buffer(
1770 &mut self,
1771 project: &Entity<Project>,
1772 related_files: Vec<RelatedFile>,
1773 cx: &mut Context<Self>,
1774 ) {
1775 self.get_or_init_project(project, cx)
1776 .context
1777 .update(cx, |store, _| {
1778 store.set_related_files(related_files);
1779 });
1780 }
1781
1782 fn is_file_open_source(
1783 &self,
1784 project: &Entity<Project>,
1785 file: &Arc<dyn File>,
1786 cx: &App,
1787 ) -> bool {
1788 if !file.is_local() || file.is_private() {
1789 return false;
1790 }
1791 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1792 return false;
1793 };
1794 project_state
1795 .license_detection_watchers
1796 .get(&file.worktree_id(cx))
1797 .as_ref()
1798 .is_some_and(|watcher| watcher.is_project_open_source())
1799 }
1800
1801 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1802 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1803 }
1804
1805 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1806 if !self.data_collection_choice.is_enabled() {
1807 return false;
1808 }
1809 events.iter().all(|event| {
1810 matches!(
1811 event.as_ref(),
1812 zeta_prompt::Event::BufferChange {
1813 in_open_source_repo: true,
1814 ..
1815 }
1816 )
1817 })
1818 }
1819
1820 fn load_data_collection_choice() -> DataCollectionChoice {
1821 let choice = KEY_VALUE_STORE
1822 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1823 .log_err()
1824 .flatten();
1825
1826 match choice.as_deref() {
1827 Some("true") => DataCollectionChoice::Enabled,
1828 Some("false") => DataCollectionChoice::Disabled,
1829 Some(_) => {
1830 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1831 DataCollectionChoice::NotAnswered
1832 }
1833 None => DataCollectionChoice::NotAnswered,
1834 }
1835 }
1836
1837 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1838 self.data_collection_choice = self.data_collection_choice.toggle();
1839 let new_choice = self.data_collection_choice;
1840 db::write_and_log(cx, move || {
1841 KEY_VALUE_STORE.write_kvp(
1842 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1843 new_choice.is_enabled().to_string(),
1844 )
1845 });
1846 }
1847
1848 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1849 self.shown_predictions.iter()
1850 }
1851
1852 pub fn shown_completions_len(&self) -> usize {
1853 self.shown_predictions.len()
1854 }
1855
1856 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1857 self.rated_predictions.contains(id)
1858 }
1859
1860 pub fn rate_prediction(
1861 &mut self,
1862 prediction: &EditPrediction,
1863 rating: EditPredictionRating,
1864 feedback: String,
1865 cx: &mut Context<Self>,
1866 ) {
1867 self.rated_predictions.insert(prediction.id.clone());
1868 telemetry::event!(
1869 "Edit Prediction Rated",
1870 rating,
1871 inputs = prediction.inputs,
1872 output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1873 feedback
1874 );
1875 self.client.telemetry().flush_events().detach();
1876 cx.notify();
1877 }
1878
1879 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1880 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1881 && all_language_settings(None, cx).edit_predictions.use_context;
1882 }
1883}
1884
1885#[derive(Error, Debug)]
1886#[error(
1887 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1888)]
1889pub struct ZedUpdateRequiredError {
1890 minimum_version: Version,
1891}
1892
1893#[cfg(feature = "cli-support")]
1894pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1895
1896#[cfg(feature = "cli-support")]
1897#[derive(Debug, Clone, Copy, PartialEq)]
1898pub enum EvalCacheEntryKind {
1899 Context,
1900 Search,
1901 Prediction,
1902}
1903
1904#[cfg(feature = "cli-support")]
1905impl std::fmt::Display for EvalCacheEntryKind {
1906 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1907 match self {
1908 EvalCacheEntryKind::Search => write!(f, "search"),
1909 EvalCacheEntryKind::Context => write!(f, "context"),
1910 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
1911 }
1912 }
1913}
1914
1915#[cfg(feature = "cli-support")]
1916pub trait EvalCache: Send + Sync {
1917 fn read(&self, key: EvalCacheKey) -> Option<String>;
1918 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
1919}
1920
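/// The user's choice about whether edit prediction data may be collected.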
1921#[derive(Debug, Clone, Copy)]
1922pub enum DataCollectionChoice {
1923 NotAnswered,
1924 Enabled,
1925 Disabled,
1926}
1927
1928impl DataCollectionChoice {
1929 pub fn is_enabled(self) -> bool {
1930 match self {
1931 Self::Enabled => true,
1932 Self::NotAnswered | Self::Disabled => false,
1933 }
1934 }
1935
1936 pub fn is_answered(self) -> bool {
1937 match self {
1938 Self::Enabled | Self::Disabled => true,
1939 Self::NotAnswered => false,
1940 }
1941 }
1942
1943 #[must_use]
1944 pub fn toggle(&self) -> DataCollectionChoice {
1945 match self {
1946 Self::Enabled => Self::Disabled,
1947 Self::Disabled => Self::Enabled,
1948 Self::NotAnswered => Self::Enabled,
1949 }
1950 }
1951}
1952
1953impl From<bool> for DataCollectionChoice {
1954 fn from(value: bool) -> Self {
1955 match value {
1956 true => DataCollectionChoice::Enabled,
1957 false => DataCollectionChoice::Disabled,
1958 }
1959 }
1960}
1961
1962struct ZedPredictUpsell;
1963
1964impl Dismissable for ZedPredictUpsell {
1965 const KEY: &'static str = "dismissed-edit-predict-upsell";
1966
1967 fn dismissed() -> bool {
        // For backwards compatibility with older versions of Zed, we treat the upsell as
        // dismissed if the user already went through the previous Edit Prediction onboarding,
        // indicated by the data collection choice that was written to the database when they
        // clicked "Accept and Enable".
1972 if KEY_VALUE_STORE
1973 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1974 .log_err()
1975 .is_some_and(|s| s.is_some())
1976 {
1977 return true;
1978 }
1979
1980 KEY_VALUE_STORE
1981 .read_kvp(Self::KEY)
1982 .log_err()
1983 .is_some_and(|s| s.is_some())
1984 }
1985}
1986
1987pub fn should_show_upsell_modal() -> bool {
1988 !ZedPredictUpsell::dismissed()
1989}
1990
1991pub fn init(cx: &mut App) {
1992 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
1993 workspace.register_action(
1994 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
1995 ZedPredictModal::toggle(
1996 workspace,
1997 workspace.user_store().clone(),
1998 workspace.client().clone(),
1999 window,
2000 cx,
2001 )
2002 },
2003 );
2004
2005 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2006 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2007 settings
2008 .project
2009 .all_languages
2010 .features
2011 .get_or_insert_default()
2012 .edit_prediction_provider = Some(EditPredictionProvider::None)
2013 });
2014 });
2015 })
2016 .detach();
2017}