1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::http_client::Url;
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::language_settings::all_language_settings;
29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
30use language::{BufferSnapshot, OffsetRangeExt};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use project::{Project, ProjectPath, WorktreeId};
33use release_channel::AppVersion;
34use semver::Version;
35use serde::de::DeserializeOwned;
36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
37use std::collections::{VecDeque, hash_map};
38use workspace::Workspace;
39
40use std::ops::Range;
41use std::path::Path;
42use std::rc::Rc;
43use std::str::FromStr as _;
44use std::sync::{Arc, LazyLock};
45use std::time::{Duration, Instant};
46use std::{env, mem};
47use thiserror::Error;
48use util::{RangeExt as _, ResultExt as _};
49use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
50
51pub mod cursor_excerpt;
52pub mod example_spec;
53mod license_detection;
54pub mod mercury;
55mod onboarding_modal;
56pub mod open_ai_response;
57mod prediction;
58pub mod sweep_ai;
59
60#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
61pub mod udiff;
62
63mod zed_edit_prediction_delegate;
64pub mod zeta1;
65pub mod zeta2;
66
67#[cfg(test)]
68mod edit_prediction_tests;
69
70use crate::license_detection::LicenseDetectionWatcher;
71use crate::mercury::Mercury;
72use crate::onboarding_modal::ZedPredictModal;
73pub use crate::prediction::EditPrediction;
74pub use crate::prediction::EditPredictionId;
75use crate::prediction::EditPredictionResult;
76pub use crate::sweep_ai::SweepAi;
77pub use language_model::ApiKeyState;
78pub use telemetry_events::EditPredictionRating;
79pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
80
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
90
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into one event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Gap since the last edit that counts as an editing pause inside a coalesced event.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value store key under which the user's data-collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long rejection reports may sit in the batch before being flushed.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
97
/// Feature flag gating the Sweep AI edit prediction provider.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag gating the Mercury edit prediction provider.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
109
/// Default excerpt sizing and prompt-format options used until overridden
/// via [`EditPredictionStore::set_options`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
118
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes
/// predictions to a local Ollama server (see `custom_predict_edits_url`).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
121
122static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
123 match env::var("ZED_ZETA2_MODEL").as_deref() {
124 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
125 Ok(model) => model,
126 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
127 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
128 }
129 .to_string()
130});
131
/// Feature flag gating the Zeta2 model; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
141
/// GPUI global holding the app-wide [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
146
/// App-wide store coordinating edit predictions across projects and the
/// available prediction backends (Zeta1/Zeta2, Sweep, Mercury).
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project tracking state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Whether retrieved context is included in prediction requests.
    use_context: bool,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server requires a newer Zed
    // version (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    // Which backend serves predictions; defaults to Zeta2 in `new`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sink for rejection reports, batched by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Endpoint override from `ZED_PREDICT_EDITS_URL` or Ollama mode.
    custom_predict_edits_url: Option<Arc<Url>>,
}
167
/// Which backend model serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
176
/// Inputs gathered for a single edit prediction request.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Cursor position the prediction is anchored to.
    position: Anchor,
    // Recent edit-history events for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting request lifecycle debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
189
/// Tunable options for Zeta prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Sizing options for the excerpt captured around the cursor.
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
195
/// Events emitted over the debug channel installed by
/// [`EditPredictionStore::debug_info`].
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when the related-excerpt store starts a context refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when a context refresh finishes, carrying display metadata
/// (cache hit ratio and LSP latency figures).
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable `(label, value)` pairs for display.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when an edit prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when an edit prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
233
/// Per-project state tracked by the [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit-history events, oldest first (capped at EVENT_COUNT_MAX).
    events: VecDeque<Arc<zeta_prompt::Event>>,
    // The still-accumulating most recent change; finalized lazily.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // In-flight prediction requests (at most two at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their task completed.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree, used to decide open-source status.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
249
impl ProjectState {
    /// Returns the project's edit history: all finalized events followed by
    /// the still-accumulating `last_event`, finalized on the fly (skipped
    /// when it produced no observable change).
    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Like [`Self::events`], but when an editing pause was recorded inside
    /// `last_event`, that event is split into two at the pause boundary.
    pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    /// Marks `pending_prediction` as cancelled. If its request still resolves
    /// to a prediction, that prediction is reported as rejected (`Canceled`).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }

    /// Returns the project's active buffer (when it is registered with this
    /// store) along with the last known cursor position in it.
    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}
308
/// The prediction currently offered for a project, plus how it was requested
/// and whether it has been displayed to the user.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}
315
impl CurrentEditPrediction {
    /// Decides whether this (newer) prediction should replace `old_prediction`.
    ///
    /// Never replaces when this prediction no longer interpolates against its
    /// buffer's current contents. Always replaces when the buffers differ or
    /// the old prediction no longer interpolates. When both are single-edit
    /// predictions requested by the same buffer, replaces only if the new
    /// edit targets the same range and extends the old edit's text.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
354
/// What triggered the current prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than a specific buffer.
    DiagnosticsUpdate,
    /// Triggered from the buffer with the given entity id.
    Buffer(EntityId),
}
360
361impl PredictionRequestedBy {
362 pub fn buffer_id(&self) -> Option<EntityId> {
363 match self {
364 PredictionRequestedBy::DiagnosticsUpdate => None,
365 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
366 }
367 }
368}
369
/// An in-flight prediction request; the task resolves to the id of the
/// prediction it produced, if any.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
375
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
382
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        // Both variants wrap the same reference type, so a single
        // or-pattern covers them.
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
394
/// Per-buffer bookkeeping: the last snapshot we diffed against, its file,
/// and the most recent cursor position observed in it.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    // Edited + release subscriptions; dropped when the buffer unregisters.
    _subscriptions: [gpui::Subscription; 2],
}
401
/// A buffer change still being accumulated (coalesced) before being turned
/// into a `zeta_prompt::Event` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // End of the most recent edit; used to decide whether a following edit
    // is close enough (in rows) to coalesce into this event.
    end_edit_anchor: Option<Anchor>,
    // Snapshot taken at the last detected pause in typing, if any; allows
    // splitting this event at the pause boundary.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
412
413impl LastEvent {
414 pub fn finalize(
415 &self,
416 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
417 cx: &App,
418 ) -> Option<Arc<zeta_prompt::Event>> {
419 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
420 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
421
422 let in_open_source_repo =
423 [self.new_file.as_ref(), self.old_file.as_ref()]
424 .iter()
425 .all(|file| {
426 file.is_some_and(|file| {
427 license_detection_watchers
428 .get(&file.worktree_id(cx))
429 .is_some_and(|watcher| watcher.is_project_open_source())
430 })
431 });
432
433 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
434
435 if path == old_path && diff.is_empty() {
436 None
437 } else {
438 Some(Arc::new(zeta_prompt::Event::BufferChange {
439 old_path,
440 path,
441 diff,
442 in_open_source_repo,
443 // TODO: Actually detect if this edit was predicted or not
444 predicted: false,
445 }))
446 }
447 }
448
449 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
450 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
451 return (self.clone(), None);
452 };
453
454 let before = LastEvent {
455 old_snapshot: self.old_snapshot.clone(),
456 new_snapshot: boundary_snapshot.clone(),
457 old_file: self.old_file.clone(),
458 new_file: self.new_file.clone(),
459 end_edit_anchor: self.end_edit_anchor,
460 snapshot_after_last_editing_pause: None,
461 last_edit_time: self.last_edit_time,
462 };
463
464 let after = LastEvent {
465 old_snapshot: boundary_snapshot.clone(),
466 new_snapshot: self.new_snapshot.clone(),
467 old_file: self.old_file.clone(),
468 new_file: self.new_file.clone(),
469 end_edit_anchor: self.end_edit_anchor,
470 snapshot_after_last_editing_pause: None,
471 last_edit_time: self.last_edit_time,
472 };
473
474 (before, Some(after))
475 }
476}
477
478fn buffer_path_with_id_fallback(
479 file: Option<&Arc<dyn File>>,
480 snapshot: &TextBufferSnapshot,
481 cx: &App,
482) -> Arc<Path> {
483 if let Some(file) = file {
484 file.full_path(cx).into()
485 } else {
486 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
487 }
488}
489
490impl EditPredictionStore {
491 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
492 cx.try_global::<EditPredictionStoreGlobal>()
493 .map(|global| global.0.clone())
494 }
495
496 pub fn global(
497 client: &Arc<Client>,
498 user_store: &Entity<UserStore>,
499 cx: &mut App,
500 ) -> Entity<Self> {
501 cx.try_global::<EditPredictionStoreGlobal>()
502 .map(|global| global.0.clone())
503 .unwrap_or_else(|| {
504 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
505 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
506 ep_store
507 })
508 }
509
    /// Creates the store: spawns the background task that batches rejection
    /// reports, wires up LLM-token refresh, and configures context retrieval
    /// (re-running configuration when feature flags arrive or settings change).
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections arrive from many call sites; funnel them through a
        // channel to a single background task that batches server requests.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Re-fetch the LLM token whenever the refresh listener fires.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // `ZED_PREDICT_EDITS_URL` overrides the endpoint; otherwise
            // Ollama mode points at the default local server.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        // Reconfigure when feature flags become available or settings change.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
595
    /// Test-only override of the predict-edits endpoint URL.
    #[cfg(test)]
    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
        self.custom_predict_edits_url = Some(url.into());
    }
600
    /// Selects which backend model serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
604
    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Whether a Mercury API token is configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
612
    /// Installs an eval cache (CLI eval support only).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
617
    /// The current prediction/prompt options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction/prompt options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Enables or disables inclusion of retrieved context in requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
629
630 pub fn clear_history(&mut self) {
631 for project_state in self.projects.values_mut() {
632 project_state.events.clear();
633 }
634 }
635
636 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
637 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
638 project_state.events.clear();
639 }
640 }
641
642 pub fn edit_history_for_project(
643 &self,
644 project: &Entity<Project>,
645 cx: &App,
646 ) -> Vec<Arc<zeta_prompt::Event>> {
647 self.projects
648 .get(&project.entity_id())
649 .map(|project_state| project_state.events(cx))
650 .unwrap_or_default()
651 }
652
653 pub fn edit_history_for_project_with_pause_split_last_event(
654 &self,
655 project: &Entity<Project>,
656 cx: &App,
657 ) -> Vec<Arc<zeta_prompt::Event>> {
658 self.projects
659 .get(&project.entity_id())
660 .map(|project_state| project_state.events_split_by_pause(cx))
661 .unwrap_or_default()
662 }
663
664 pub fn context_for_project<'a>(
665 &'a self,
666 project: &Entity<Project>,
667 cx: &'a App,
668 ) -> Arc<[RelatedFile]> {
669 self.projects
670 .get(&project.entity_id())
671 .map(|project| project.context.read(cx).related_files())
672 .unwrap_or_else(|| vec![].into())
673 }
674
675 pub fn context_for_project_with_buffers<'a>(
676 &'a self,
677 project: &Entity<Project>,
678 cx: &'a App,
679 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
680 self.projects
681 .get(&project.entity_id())
682 .map(|project| project.context.read(cx).related_files_with_buffers())
683 }
684
685 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
686 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
687 self.user_store.read(cx).edit_prediction_usage()
688 } else {
689 None
690 }
691 }
692
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, initializing the project's
    /// state as needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
706
707 fn get_or_init_project(
708 &mut self,
709 project: &Entity<Project>,
710 cx: &mut Context<Self>,
711 ) -> &mut ProjectState {
712 let entity_id = project.entity_id();
713 self.projects
714 .entry(entity_id)
715 .or_insert_with(|| ProjectState {
716 context: {
717 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
718 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
719 this.handle_excerpt_store_event(entity_id, event);
720 })
721 .detach();
722 related_excerpt_store
723 },
724 events: VecDeque::new(),
725 last_event: None,
726 recent_paths: VecDeque::new(),
727 debug_tx: None,
728 registered_buffers: HashMap::default(),
729 current_prediction: None,
730 cancelled_predictions: HashSet::default(),
731 pending_predictions: ArrayVec::new(),
732 next_pending_prediction_id: 0,
733 last_prediction_refresh: None,
734 license_detection_watchers: HashMap::default(),
735 _subscription: cx.subscribe(&project, Self::handle_project_event),
736 })
737 }
738
    /// Drops all state tracked for `project`.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
742
743 fn handle_excerpt_store_event(
744 &mut self,
745 project_entity_id: EntityId,
746 event: &RelatedExcerptStoreEvent,
747 ) {
748 if let Some(project_state) = self.projects.get(&project_entity_id) {
749 if let Some(debug_tx) = project_state.debug_tx.clone() {
750 match event {
751 RelatedExcerptStoreEvent::StartedRefresh => {
752 debug_tx
753 .unbounded_send(DebugEvent::ContextRetrievalStarted(
754 ContextRetrievalStartedDebugEvent {
755 project_entity_id: project_entity_id,
756 timestamp: Instant::now(),
757 search_prompt: String::new(),
758 },
759 ))
760 .ok();
761 }
762 RelatedExcerptStoreEvent::FinishedRefresh {
763 cache_hit_count,
764 cache_miss_count,
765 mean_definition_latency,
766 max_definition_latency,
767 } => {
768 debug_tx
769 .unbounded_send(DebugEvent::ContextRetrievalFinished(
770 ContextRetrievalFinishedDebugEvent {
771 project_entity_id: project_entity_id,
772 timestamp: Instant::now(),
773 metadata: vec![
774 (
775 "Cache Hits",
776 format!(
777 "{}/{}",
778 cache_hit_count,
779 cache_hit_count + cache_miss_count
780 )
781 .into(),
782 ),
783 (
784 "Max LSP Time",
785 format!("{} ms", max_definition_latency.as_millis())
786 .into(),
787 ),
788 (
789 "Mean LSP Time",
790 format!("{} ms", mean_definition_latency.as_millis())
791 .into(),
792 ),
793 ],
794 },
795 ))
796 .ok();
797 }
798 }
799 }
800 }
801 }
802
803 pub fn debug_info(
804 &mut self,
805 project: &Entity<Project>,
806 cx: &mut Context<Self>,
807 ) -> mpsc::UnboundedReceiver<DebugEvent> {
808 let project_state = self.get_or_init_project(project, cx);
809 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
810 project_state.debug_tx = Some(debug_watch_tx);
811 debug_watch_rx
812 }
813
814 fn handle_project_event(
815 &mut self,
816 project: Entity<Project>,
817 event: &project::Event,
818 cx: &mut Context<Self>,
819 ) {
820 // TODO [zeta2] init with recent paths
821 match event {
822 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
823 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
824 return;
825 };
826 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
827 if let Some(path) = path {
828 if let Some(ix) = project_state
829 .recent_paths
830 .iter()
831 .position(|probe| probe == &path)
832 {
833 project_state.recent_paths.remove(ix);
834 }
835 project_state.recent_paths.push_front(path);
836 }
837 }
838 project::Event::DiagnosticsUpdated { .. } => {
839 if cx.has_flag::<Zeta2FeatureFlag>() {
840 self.refresh_prediction_from_diagnostics(project, cx);
841 }
842 }
843 _ => (),
844 }
845 }
846
    /// Ensures `buffer` is registered with `project_state` and returns its
    /// bookkeeping entry. Lazily creates a license-detection watcher for the
    /// buffer's worktree, and wires subscriptions so edits are recorded and
    /// released buffers/worktrees are cleaned up.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        // Drop the watcher when its worktree is released.
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Feed buffer edits into the project's edit history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
912
    /// Records an edit to `buffer` in the project's history.
    ///
    /// Consecutive edits to the same buffer within `CHANGE_GROUPING_LINE_SPAN`
    /// rows of the previous edit are coalesced into the accumulating
    /// `last_event`; otherwise that event is finalized and a new one started.
    /// The finalized history is capped at `EVENT_COUNT_MAX` entries.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // No-op when this buffer version was already recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit in this batch; used to decide whether
        // the next edit is close enough (in rows) to coalesce.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // Only coalesce when this change directly continues the snapshot
            // chain of the accumulating event.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If typing paused long enough before this edit, remember the
                // pre-edit snapshot so the event can later be split there.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Make room for the event we are about to finalize.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
988
    /// Returns the current prediction as seen from `buffer`, recording
    /// `position` (when given) as the buffer's latest cursor position.
    ///
    /// Yields `Local` when the prediction edits this buffer, `Jump` when it
    /// targets another buffer but this buffer requested it (or diagnostics
    /// triggered it), and `None` otherwise.
    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        // Track the most recent cursor position for this buffer.
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the
            // prediction, or from any buffer for diagnostics-triggered ones.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
1028
    /// Accepts and consumes the current prediction for `project`, cancelling
    /// any still-pending requests and reporting the acceptance to the backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
        // Only Zed-hosted models report acceptances, and not when a custom
        // predict endpoint is set without a matching custom accept endpoint.
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
                    return;
                }
            }
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any in-flight requests.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // Custom accept endpoints skip authentication.
            let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
                (http_client::Url::parse(&accept_edits_url)?, false)
            } else {
                (
                    client
                        .http_client()
                        .build_zed_llm_url("/predict_edits/accept", &[])?,
                    true,
                )
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                    require_auth,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
1090
    /// Background loop that batches rejection reports and sends them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// A flush happens once half the per-request cap has accumulated, when no
    /// further item arrives within `REJECT_REQUEST_DEBOUNCE`, or when the
    /// channel closes. Failed flushes keep their items for the next attempt;
    /// only the newest cap-full of items is ever sent in one request.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Below half the cap: keep batching while more items arrive,
            // until the debounce timer fires or the channel is exhausted.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Drop the flushed items only on success; otherwise retain them
            // for the next flush attempt.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1149
1150 fn reject_current_prediction(
1151 &mut self,
1152 reason: EditPredictionRejectReason,
1153 project: &Entity<Project>,
1154 ) {
1155 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1156 project_state.pending_predictions.clear();
1157 if let Some(prediction) = project_state.current_prediction.take() {
1158 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1159 }
1160 };
1161 }
1162
1163 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1164 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1165 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1166 if !current_prediction.was_shown {
1167 current_prediction.was_shown = true;
1168 self.shown_predictions
1169 .push_front(current_prediction.prediction.clone());
1170 if self.shown_predictions.len() > 50 {
1171 let completion = self.shown_predictions.pop_back().unwrap();
1172 self.rated_predictions.remove(&completion.id);
1173 }
1174 }
1175 }
1176 }
1177 }
1178
1179 fn reject_prediction(
1180 &mut self,
1181 prediction_id: EditPredictionId,
1182 reason: EditPredictionRejectReason,
1183 was_shown: bool,
1184 ) {
1185 match self.edit_prediction_model {
1186 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1187 if self.custom_predict_edits_url.is_some() {
1188 return;
1189 }
1190 }
1191 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1192 }
1193
1194 self.reject_predictions_tx
1195 .unbounded_send(EditPredictionRejection {
1196 request_id: prediction_id.to_string(),
1197 reason,
1198 was_shown,
1199 })
1200 .log_err();
1201 }
1202
1203 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1204 self.projects
1205 .get(&project.entity_id())
1206 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1207 }
1208
    /// Queues a throttled prediction refresh anchored at `position` in
    /// `buffer`, attributing the resulting prediction to that buffer.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Throttled per-buffer (entity id) so rapid edits in one buffer don't
        // spam requests.
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                // Store was dropped; resolve with no prediction.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1244
    /// Queues a prediction refresh driven by a diagnostics update, jumping to
    /// a nearby diagnostic location. Skipped entirely when a prediction
    /// already exists (buffer-driven predictions are preferred).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        // Throttled per-project (not per-buffer) since the trigger is global.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    // Respect settings at the active cursor before jumping.
                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // Find a diagnostic worth jumping to; bail if none exists.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // If a prediction appeared while we were waiting,
                        // keep it and mark this one as superseded instead.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1331
    /// Whether edit predictions are enabled for `snapshot` at `position`,
    /// considering global/language settings, per-file enablement, and any
    /// scope overrides listed in `edit_predictions_disabled_in`.
    fn predictions_enabled_at(
        snapshot: &BufferSnapshot,
        position: Option<language::Anchor>,
        cx: &App,
    ) -> bool {
        let file = snapshot.file();
        let all_settings = all_language_settings(file, cx);
        // Language-level switch, plus the per-file opt-out.
        if !all_settings.show_edit_predictions(snapshot.language(), cx)
            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
        {
            return false;
        }

        if let Some(last_position) = position {
            let settings = snapshot.settings_at(last_position, cx);

            // Disabled if the syntax scope override at the cursor is named in
            // `edit_predictions_disabled_in`.
            if !settings.edit_predictions_disabled_in.is_empty()
                && let Some(scope) = snapshot.language_scope_at(last_position)
                && let Some(scope_name) = scope.override_name()
                && settings
                    .edit_predictions_disabled_in
                    .iter()
                    .any(|s| s == scope_name)
            {
                return false;
            }
        }

        true
    }
1362
    // Minimum delay between successive prediction refreshes triggered by the
    // same entity (see `queue_prediction_refresh`); zero under test so tests
    // don't have to wait out the debounce.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1367
1368 fn queue_prediction_refresh(
1369 &mut self,
1370 project: Entity<Project>,
1371 throttle_entity: EntityId,
1372 cx: &mut Context<Self>,
1373 do_refresh: impl FnOnce(
1374 WeakEntity<Self>,
1375 &mut AsyncApp,
1376 )
1377 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1378 + 'static,
1379 ) {
1380 let project_state = self.get_or_init_project(&project, cx);
1381 let pending_prediction_id = project_state.next_pending_prediction_id;
1382 project_state.next_pending_prediction_id += 1;
1383 let last_request = project_state.last_prediction_refresh;
1384
1385 let task = cx.spawn(async move |this, cx| {
1386 if let Some((last_entity, last_timestamp)) = last_request
1387 && throttle_entity == last_entity
1388 && let Some(timeout) =
1389 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1390 {
1391 cx.background_executor().timer(timeout).await;
1392 }
1393
1394 // If this task was cancelled before the throttle timeout expired,
1395 // do not perform a request.
1396 let mut is_cancelled = true;
1397 this.update(cx, |this, cx| {
1398 let project_state = this.get_or_init_project(&project, cx);
1399 if !project_state
1400 .cancelled_predictions
1401 .remove(&pending_prediction_id)
1402 {
1403 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1404 is_cancelled = false;
1405 }
1406 })
1407 .ok();
1408 if is_cancelled {
1409 return None;
1410 }
1411
1412 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1413 let new_prediction_id = new_prediction_result
1414 .as_ref()
1415 .map(|(prediction, _)| prediction.id.clone());
1416
1417 // When a prediction completes, remove it from the pending list, and cancel
1418 // any pending predictions that were enqueued before it.
1419 this.update(cx, |this, cx| {
1420 let project_state = this.get_or_init_project(&project, cx);
1421
1422 let is_cancelled = project_state
1423 .cancelled_predictions
1424 .remove(&pending_prediction_id);
1425
1426 let new_current_prediction = if !is_cancelled
1427 && let Some((prediction_result, requested_by)) = new_prediction_result
1428 {
1429 match prediction_result.prediction {
1430 Ok(prediction) => {
1431 let new_prediction = CurrentEditPrediction {
1432 requested_by,
1433 prediction,
1434 was_shown: false,
1435 };
1436
1437 if let Some(current_prediction) =
1438 project_state.current_prediction.as_ref()
1439 {
1440 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1441 {
1442 this.reject_current_prediction(
1443 EditPredictionRejectReason::Replaced,
1444 &project,
1445 );
1446
1447 Some(new_prediction)
1448 } else {
1449 this.reject_prediction(
1450 new_prediction.prediction.id,
1451 EditPredictionRejectReason::CurrentPreferred,
1452 false,
1453 );
1454 None
1455 }
1456 } else {
1457 Some(new_prediction)
1458 }
1459 }
1460 Err(reject_reason) => {
1461 this.reject_prediction(prediction_result.id, reject_reason, false);
1462 None
1463 }
1464 }
1465 } else {
1466 None
1467 };
1468
1469 let project_state = this.get_or_init_project(&project, cx);
1470
1471 if let Some(new_prediction) = new_current_prediction {
1472 project_state.current_prediction = Some(new_prediction);
1473 }
1474
1475 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1476 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1477 if pending_prediction.id == pending_prediction_id {
1478 pending_predictions.remove(ix);
1479 for pending_prediction in pending_predictions.drain(0..ix) {
1480 project_state.cancel_pending_prediction(pending_prediction, cx)
1481 }
1482 break;
1483 }
1484 }
1485 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1486 cx.notify();
1487 })
1488 .ok();
1489
1490 new_prediction_id
1491 });
1492
1493 if project_state.pending_predictions.len() <= 1 {
1494 project_state.pending_predictions.push(PendingPrediction {
1495 id: pending_prediction_id,
1496 task,
1497 });
1498 } else if project_state.pending_predictions.len() == 2 {
1499 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1500 project_state.pending_predictions.push(PendingPrediction {
1501 id: pending_prediction_id,
1502 task,
1503 });
1504 project_state.cancel_pending_prediction(pending_prediction, cx);
1505 }
1506 }
1507
1508 pub fn request_prediction(
1509 &mut self,
1510 project: &Entity<Project>,
1511 active_buffer: &Entity<Buffer>,
1512 position: language::Anchor,
1513 trigger: PredictEditsRequestTrigger,
1514 cx: &mut Context<Self>,
1515 ) -> Task<Result<Option<EditPredictionResult>>> {
1516 self.request_prediction_internal(
1517 project.clone(),
1518 active_buffer.clone(),
1519 position,
1520 trigger,
1521 cx.has_flag::<Zeta2FeatureFlag>(),
1522 cx,
1523 )
1524 }
1525
    /// Shared implementation behind [`Self::request_prediction`].
    ///
    /// Builds the model input bundle and dispatches to the active backend.
    /// When the model returns no prediction and `allow_jump` is set, retries
    /// once at the nearest diagnostic outside the already-searched range
    /// (with `allow_jump` disabled to prevent unbounded recursion).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows scanned above/below the cursor when collecting diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Related-file context is only gathered when enabled (Zeta2 + setting).
        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        // All backends accept the same input bundle.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there is recent edit activity.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry at the jump target; `allow_jump = false` bounds
                    // the recursion to a single extra attempt.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1612
    /// Finds a diagnostic location to "jump" a prediction request to.
    ///
    /// Prefers the diagnostic in the active buffer closest (by row) to the
    /// cursor that lies outside `active_buffer_diagnostic_search_range`, i.e.
    /// one that wasn't already covered by the previous request. Failing that,
    /// opens the other buffer with diagnostics whose path shares the most
    /// leading components with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Fall back to another buffer in the project with diagnostics.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // NOTE(review): `min_by_key` on severity assumes lower values
                // are more severe — confirm against the severity ordering.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1691
    /// Sends an OpenAI-style request through the `/predict_edits/raw`
    /// endpoint.
    ///
    /// With the `cli-support` feature, responses can be cached keyed by a
    /// hash of the URL plus the serialized request body, so repeated eval
    /// runs avoid re-issuing identical requests.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/raw", &[])?;

        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key: hash of the endpoint URL and pretty-printed request.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            // Cache hit: skip the network entirely (no usage info available).
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        // Persist the fresh response for future cache hits.
        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1746
    /// Post-processes an API response: on success, records any usage info in
    /// the user store; on a `ZedUpdateRequiredError`, flags the store as
    /// requiring an update and shows a notification with an update link. The
    /// original error is always propagated to the caller.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        // `unique` keeps a single update-required notification
                        // rather than stacking one per failed request.
                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1790
    /// Sends a JSON POST built by `build`, attaching the Zed version header
    /// and (optionally) an LLM auth token.
    ///
    /// Enforces the server's minimum-required-version header by returning
    /// [`ZedUpdateRequiredError`], and retries exactly once with a refreshed
    /// token when the server reports the LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        // With `require_auth`, failing to acquire a token is an error;
        // otherwise we proceed unauthenticated.
        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server advertises the minimum client version it supports;
            // refuse to proceed on older clients.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and retry once.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1866
1867 pub fn refresh_context(
1868 &mut self,
1869 project: &Entity<Project>,
1870 buffer: &Entity<language::Buffer>,
1871 cursor_position: language::Anchor,
1872 cx: &mut Context<Self>,
1873 ) {
1874 if self.use_context {
1875 self.get_or_init_project(project, cx)
1876 .context
1877 .update(cx, |store, cx| {
1878 store.refresh(buffer.clone(), cursor_position, cx);
1879 });
1880 }
1881 }
1882
    /// Directly sets the related files used as prediction context for
    /// `project`, bypassing normal context retrieval (CLI/eval support only).
    #[cfg(feature = "cli-support")]
    pub fn set_context_for_buffer(
        &mut self,
        project: &Entity<Project>,
        related_files: Vec<RelatedFile>,
        cx: &mut Context<Self>,
    ) {
        self.get_or_init_project(project, cx)
            .context
            .update(cx, |store, _| {
                store.set_related_files(related_files);
            });
    }
1896
1897 fn is_file_open_source(
1898 &self,
1899 project: &Entity<Project>,
1900 file: &Arc<dyn File>,
1901 cx: &App,
1902 ) -> bool {
1903 if !file.is_local() || file.is_private() {
1904 return false;
1905 }
1906 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1907 return false;
1908 };
1909 project_state
1910 .license_detection_watchers
1911 .get(&file.worktree_id(cx))
1912 .as_ref()
1913 .is_some_and(|watcher| watcher.is_project_open_source())
1914 }
1915
1916 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1917 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1918 }
1919
1920 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1921 if !self.data_collection_choice.is_enabled() {
1922 return false;
1923 }
1924 events.iter().all(|event| {
1925 matches!(
1926 event.as_ref(),
1927 zeta_prompt::Event::BufferChange {
1928 in_open_source_repo: true,
1929 ..
1930 }
1931 )
1932 })
1933 }
1934
1935 fn load_data_collection_choice() -> DataCollectionChoice {
1936 let choice = KEY_VALUE_STORE
1937 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1938 .log_err()
1939 .flatten();
1940
1941 match choice.as_deref() {
1942 Some("true") => DataCollectionChoice::Enabled,
1943 Some("false") => DataCollectionChoice::Disabled,
1944 Some(_) => {
1945 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1946 DataCollectionChoice::NotAnswered
1947 }
1948 None => DataCollectionChoice::NotAnswered,
1949 }
1950 }
1951
1952 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1953 self.data_collection_choice = self.data_collection_choice.toggle();
1954 let new_choice = self.data_collection_choice;
1955 db::write_and_log(cx, move || {
1956 KEY_VALUE_STORE.write_kvp(
1957 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1958 new_choice.is_enabled().to_string(),
1959 )
1960 });
1961 }
1962
    /// Iterates over the retained history of shown predictions, most recent
    /// first.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
1966
    /// Number of predictions currently retained in the shown history.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
1970
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1974
    /// Records a user rating for `prediction`, reporting it via telemetry and
    /// flushing telemetry events immediately so the feedback isn't lost.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1993
    /// Recomputes whether context retrieval is active: requires both the
    /// Zeta2 feature flag and the `use_context` edit-prediction setting.
    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
            && all_language_settings(None, cx).edit_predictions.use_context;
    }
1998}
1999
/// Error returned when the server reports that the running Zed version is
/// below the minimum required for edit predictions (see `send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2007
/// Cache key for eval runs: the entry kind plus a hash of the request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2010
/// Kinds of cached eval entries (see [`EvalCache`]).
///
/// Added `Eq` and `Hash` derives: this type is half of the `EvalCacheKey`
/// tuple used as a cache key, and deriving `PartialEq` without `Eq`/`Hash`
/// prevents keys from being used in hashed collections. Both derives are
/// backward compatible.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2018
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2029
/// Read/write interface used to cache expensive requests across eval runs.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the response) for `key`; `input` is the request text.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2035
/// The user's choice about data collection for edit predictions, persisted in
/// the key-value store (see `load_data_collection_choice`).
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // The user has not answered the prompt yet.
    NotAnswered,
    // The user explicitly opted in.
    Enabled,
    // The user explicitly opted out.
    Disabled,
}
2042
2043impl DataCollectionChoice {
2044 pub fn is_enabled(self) -> bool {
2045 match self {
2046 Self::Enabled => true,
2047 Self::NotAnswered | Self::Disabled => false,
2048 }
2049 }
2050
2051 pub fn is_answered(self) -> bool {
2052 match self {
2053 Self::Enabled | Self::Disabled => true,
2054 Self::NotAnswered => false,
2055 }
2056 }
2057
2058 #[must_use]
2059 pub fn toggle(&self) -> DataCollectionChoice {
2060 match self {
2061 Self::Enabled => Self::Disabled,
2062 Self::Disabled => Self::Enabled,
2063 Self::NotAnswered => Self::Enabled,
2064 }
2065 }
2066}
2067
2068impl From<bool> for DataCollectionChoice {
2069 fn from(value: bool) -> Self {
2070 match value {
2071 true => DataCollectionChoice::Enabled,
2072 false => DataCollectionChoice::Disabled,
2073 }
2074 }
2075}
2076
/// Marker type used to track dismissal of the edit-prediction upsell.
struct ZedPredictUpsell;
2078
impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    /// Whether the upsell was dismissed, treating completion of the legacy
    /// onboarding flow as an implicit dismissal.
    fn dismissed() -> bool {
        // To make this backwards compatible with older versions of Zed, we
        // check if the user has seen the previous Edit Prediction Onboarding
        // before, by checking the data collection choice which was written to
        // the database once the user clicked on "Accept and Enable"
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}
2101
/// Whether the edit-prediction upsell modal should still be presented.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2105
/// Registers the edit-prediction onboarding actions on every new workspace.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Opens the onboarding modal.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Clears the configured provider so onboarding can run again.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}