1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use workspace::Workspace;
38
39use std::ops::Range;
40use std::path::Path;
41use std::rc::Rc;
42use std::str::FromStr as _;
43use std::sync::{Arc, LazyLock};
44use std::time::{Duration, Instant};
45use std::{env, mem};
46use thiserror::Error;
47use util::{RangeExt as _, ResultExt as _};
48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
49
50pub mod cursor_excerpt;
51pub mod example_spec;
52mod license_detection;
53pub mod mercury;
54mod onboarding_modal;
55pub mod open_ai_response;
56mod prediction;
57pub mod sweep_ai;
58
59#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
60pub mod udiff;
61
62mod zed_edit_prediction_delegate;
63pub mod zeta1;
64pub mod zeta2;
65
66#[cfg(test)]
67mod edit_prediction_tests;
68
69use crate::license_detection::LicenseDetectionWatcher;
70use crate::mercury::Mercury;
71use crate::onboarding_modal::ZedPredictModal;
72pub use crate::prediction::EditPrediction;
73pub use crate::prediction::EditPredictionId;
74use crate::prediction::EditPredictionResult;
75pub use crate::sweep_ai::SweepAi;
76pub use language_model::ApiKeyState;
77pub use telemetry_events::EditPredictionRating;
78pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
79
// Workspace actions registered under the `edit_prediction` action namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
89
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into the
/// same history event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Minimum pause between edits that marks an editing-pause boundary inside a
/// coalesced event (see `LastEvent::split_by_pause`).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the user's data-collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long rejection reports are batched before being flushed to the server.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);

/// Gates the Sweep AI edit prediction provider.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Gates the Mercury edit prediction provider.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}

/// Default options used until callers override them via `set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
117
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value, routing zeta2
/// requests through a local Ollama server (see `PREDICT_EDITS_URL`).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
120
121static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
122 match env::var("ZED_ZETA2_MODEL").as_deref() {
123 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
124 Ok(model) => model,
125 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
126 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
127 }
128 .to_string()
129});
130static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
131 env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
132 if *USE_OLLAMA {
133 Some("http://localhost:11434/v1/chat/completions".into())
134 } else {
135 None
136 }
137 })
138});
139
/// Gates the zeta2 edit prediction provider; enabled for staff by default.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// Global handle to the single [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
154
/// Central state for edit predictions: tracks per-project edit history and
/// in-flight/current predictions, and dispatches requests to the selected
/// backend model.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token for the LLM service; refreshed via `RefreshLlmTokenListener`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by project entity id; entries are created lazily in
    // `get_or_init_project` and dropped in `remove_project`.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports a newer minimum
    // required Zed version — the write site is outside this chunk; confirm.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Rejections sent here are batched and flushed by
    // `handle_rejected_predictions` on the background executor.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Most recently shown first; bounded (see `did_show_current_prediction`).
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}

/// Which backend produces edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}

/// Everything a model needs to produce a prediction for a buffer position.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    // Recent edit-history events, oldest first.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}

/// Tunable options for prediction context gathering and prompt format.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
202
/// Events forwarded to debug listeners attached via
/// [`EditPredictionStore::debug_info`].
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload for [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Human-readable (label, value) pairs, e.g. cache-hit ratio and LSP
    // latency stats (see `handle_excerpt_store_event`).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload for [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Payload for [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
240
/// Per-project bookkeeping owned by [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit-history events, oldest first; bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    // The still-open (coalescing) event; finalized lazily.
    last_event: Option<LastEvent>,
    // Most recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two prediction requests in flight at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // When present, debug events are forwarded to this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their task completed.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree, used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
256
impl ProjectState {
    /// Returns the finalized event history plus the still-open `last_event`
    /// (finalized on the fly), oldest first.
    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Like [`Self::events`], but if the open event recorded an editing
    /// pause, it is split into two events at that boundary.
    pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    /// Marks `pending_prediction` as cancelled and, once its task resolves
    /// to a prediction id, reports that prediction as rejected (Canceled).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }

    /// Returns the project's active buffer (and the last cursor position seen
    /// in it), provided that buffer has been registered with the store.
    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}
315
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}

impl CurrentEditPrediction {
    /// Decides whether `self` (a newly arrived prediction) should replace
    /// `old_prediction`. Returns `false` when `self` no longer interpolates
    /// against the current buffer contents; otherwise prefers replacement,
    /// except when both are a single edit at the same range and the new text
    /// merely extends the old one (to avoid UI thrash).
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // A prediction in a different buffer always wins.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
361
/// What triggered the request that produced the current prediction.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
367
368impl PredictionRequestedBy {
369 pub fn buffer_id(&self) -> Option<EntityId> {
370 match self {
371 PredictionRequestedBy::DiagnosticsUpdate => None,
372 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
373 }
374 }
375}
376
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Id unique within the project (see `next_pending_prediction_id`).
    id: usize,
    // Resolves to the resulting prediction's id, or `None` if none resulted.
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies to this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; surface it as a jump.
    Jump { prediction: &'a EditPrediction },
}
389
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    // Both variants carry the same reference, so a single or-pattern
    // extracts it irrefutably.
    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
401
/// Per-buffer state tracked while a buffer is registered with the store.
struct RegisteredBuffer {
    // File as of the last reported change; compared to detect path changes.
    file: Option<Arc<dyn File>>,
    // Snapshot as of the last reported change.
    snapshot: TextBufferSnapshot,
    // Last cursor position observed via `prediction_at`.
    last_position: Option<Anchor>,
    // Edit subscription and release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
408
/// The still-open edit event: a pair of snapshots whose diff becomes a
/// `zeta_prompt::Event::BufferChange` when finalized.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // End of the most recent edit; used for proximity-based coalescing.
    end_edit_anchor: Option<Anchor>,
    // Snapshot captured when an editing pause was detected, so the event can
    // later be split at that boundary (see `split_by_pause`).
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
419
impl LastEvent {
    /// Converts the accumulated snapshots into a `BufferChange` event.
    /// Returns `None` when nothing changed (same path, empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<zeta_prompt::Event>> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open source only when both old and new files belong to worktrees
        // whose watcher reports an open-source project; untitled buffers have
        // no file and therefore never qualify.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(zeta_prompt::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }

    /// Splits the event at the last recorded editing pause. When no pause
    /// boundary was recorded, returns `self` unchanged and `None`.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        // First half: original snapshot up to the pause boundary.
        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            end_edit_anchor: self.end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        // Second half: pause boundary up to the latest snapshot.
        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            end_edit_anchor: self.end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
484
485fn buffer_path_with_id_fallback(
486 file: Option<&Arc<dyn File>>,
487 snapshot: &TextBufferSnapshot,
488 cx: &App,
489) -> Arc<Path> {
490 if let Some(file) = file {
491 file.full_path(cx).into()
492 } else {
493 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
494 }
495}
496
497impl EditPredictionStore {
498 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
499 cx.try_global::<EditPredictionStoreGlobal>()
500 .map(|global| global.0.clone())
501 }
502
503 pub fn global(
504 client: &Arc<Client>,
505 user_store: &Entity<UserStore>,
506 cx: &mut App,
507 ) -> Entity<Self> {
508 cx.try_global::<EditPredictionStoreGlobal>()
509 .map(|global| global.0.clone())
510 .unwrap_or_else(|| {
511 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
512 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
513 ep_store
514 })
515 }
516
    /// Creates the store: spawns the background task that batches rejection
    /// reports, subscribes to LLM token refreshes, and (re)configures context
    /// retrieval when settings or feature flags change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections flow through this channel to a background loop that
        // batches and flushes them (see `handle_rejected_predictions`).
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the token whenever the listener fires.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        // Configure now, and again once feature flags arrive or settings change.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
588
    /// Selects which backend model produces predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep API token is configured.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Whether a Mercury API token is configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Installs a response cache (cli-support builds only).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Enables or disables use of retrieved context in requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
617
618 pub fn clear_history(&mut self) {
619 for project_state in self.projects.values_mut() {
620 project_state.events.clear();
621 }
622 }
623
624 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
625 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
626 project_state.events.clear();
627 }
628 }
629
630 pub fn edit_history_for_project(
631 &self,
632 project: &Entity<Project>,
633 cx: &App,
634 ) -> Vec<Arc<zeta_prompt::Event>> {
635 self.projects
636 .get(&project.entity_id())
637 .map(|project_state| project_state.events(cx))
638 .unwrap_or_default()
639 }
640
641 pub fn edit_history_for_project_with_pause_split_last_event(
642 &self,
643 project: &Entity<Project>,
644 cx: &App,
645 ) -> Vec<Arc<zeta_prompt::Event>> {
646 self.projects
647 .get(&project.entity_id())
648 .map(|project_state| project_state.events_split_by_pause(cx))
649 .unwrap_or_default()
650 }
651
652 pub fn context_for_project<'a>(
653 &'a self,
654 project: &Entity<Project>,
655 cx: &'a App,
656 ) -> Arc<[RelatedFile]> {
657 self.projects
658 .get(&project.entity_id())
659 .map(|project| project.context.read(cx).related_files())
660 .unwrap_or_else(|| vec![].into())
661 }
662
663 pub fn context_for_project_with_buffers<'a>(
664 &'a self,
665 project: &Entity<Project>,
666 cx: &'a App,
667 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
668 self.projects
669 .get(&project.entity_id())
670 .map(|project| project.context.read(cx).related_files_with_buffers())
671 }
672
673 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
674 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
675 self.user_store.read(cx).edit_prediction_usage()
676 } else {
677 None
678 }
679 }
680
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    /// Starts tracking `buffer` (edits and release) within `project`'s state.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
694
    /// Returns the state for `project`, creating it — and subscribing to the
    /// project's events and its related-excerpt store — on first access.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward excerpt-store events (for the debug channel).
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
726
    /// Drops all state for `project` (history, predictions, subscriptions).
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
730
731 fn handle_excerpt_store_event(
732 &mut self,
733 project_entity_id: EntityId,
734 event: &RelatedExcerptStoreEvent,
735 ) {
736 if let Some(project_state) = self.projects.get(&project_entity_id) {
737 if let Some(debug_tx) = project_state.debug_tx.clone() {
738 match event {
739 RelatedExcerptStoreEvent::StartedRefresh => {
740 debug_tx
741 .unbounded_send(DebugEvent::ContextRetrievalStarted(
742 ContextRetrievalStartedDebugEvent {
743 project_entity_id: project_entity_id,
744 timestamp: Instant::now(),
745 search_prompt: String::new(),
746 },
747 ))
748 .ok();
749 }
750 RelatedExcerptStoreEvent::FinishedRefresh {
751 cache_hit_count,
752 cache_miss_count,
753 mean_definition_latency,
754 max_definition_latency,
755 } => {
756 debug_tx
757 .unbounded_send(DebugEvent::ContextRetrievalFinished(
758 ContextRetrievalFinishedDebugEvent {
759 project_entity_id: project_entity_id,
760 timestamp: Instant::now(),
761 metadata: vec![
762 (
763 "Cache Hits",
764 format!(
765 "{}/{}",
766 cache_hit_count,
767 cache_hit_count + cache_miss_count
768 )
769 .into(),
770 ),
771 (
772 "Max LSP Time",
773 format!("{} ms", max_definition_latency.as_millis())
774 .into(),
775 ),
776 (
777 "Mean LSP Time",
778 format!("{} ms", mean_definition_latency.as_millis())
779 .into(),
780 ),
781 ],
782 },
783 ))
784 .ok();
785 }
786 }
787 }
788 }
789 }
790
    /// Attaches a debug listener to `project` and returns the receiving end
    /// of the channel. Replaces any previously attached listener.
    pub fn debug_info(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<DebugEvent> {
        let project_state = self.get_or_init_project(project, cx);
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        project_state.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
801
    /// Reacts to project events: maintains the recently-active-path list and
    /// refreshes predictions when diagnostics change (zeta2 flag only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move-to-front with deduplication.
                    // NOTE(review): the list is never truncated here — confirm
                    // its growth is bounded elsewhere.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
834
    /// Returns the `RegisteredBuffer` for `buffer`, creating it on first use.
    /// Also lazily creates a license-detection watcher for the buffer's
    /// worktree, and wires up subscriptions so edits are reported and state is
    /// cleaned up when the buffer or worktree entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Ensure a license watcher exists for the buffer's worktree; it is
        // removed again when the worktree entity is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Record every edit into the project's history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Deregister when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
900
    /// Records a buffer change into the project's edit history. Consecutive
    /// edits to the same buffer landing within `CHANGE_GROUPING_LINE_SPAN`
    /// rows of each other are coalesced into the open `last_event`; otherwise
    /// the open event is finalized and a new one is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing to record if the buffer hasn't changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the latest edit, used for proximity-based coalescing.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The open event can only absorb this change if the change
            // directly continues the same buffer's snapshot chain...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and the edit lands within a few rows of the previous one.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused long enough before this edit, remember
                // the pre-edit snapshot so the event can later be split there.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Keep the history bounded before appending the finalized open event.
        // NOTE(review): `len() + 1 >= EVENT_COUNT_MAX` retains at most
        // EVENT_COUNT_MAX - 1 finalized events plus the open one — confirm
        // this accounting is intentional.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        // Open a new event covering this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
976
    /// Returns the current prediction from the perspective of `buffer`,
    /// recording `position` as the buffer's last-known cursor position.
    ///
    /// A prediction targeting `buffer` is returned as `Local`; one targeting
    /// another buffer is returned as `Jump`, but only when it was requested by
    /// this buffer or by a diagnostics update.
    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
1016
    /// Accepts the current prediction for `project`: removes it from the
    /// project state, cancels any still-pending requests, and reports the
    /// acceptance to the server (Zeta models only).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Sweep/Mercury acceptances are not reported here.
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting makes any other in-flight requests moot; cancel them.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
1069
    /// Background loop that batches rejection reports and sends them to the
    /// server. A batch is flushed once it reaches half the per-request cap, or
    /// after `REJECT_REQUEST_DEBOUNCE` with no further rejections. Entries are
    /// removed from the batch only when a request succeeds, so failures are
    /// retried with the next flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Wait for either another rejection or the debounce timer;
                // keep accumulating while more items keep arriving.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the newest entries, respecting the per-request cap.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // On success drop the flushed entries; on failure keep them so
            // they are retried together with future rejections.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1127
1128 fn reject_current_prediction(
1129 &mut self,
1130 reason: EditPredictionRejectReason,
1131 project: &Entity<Project>,
1132 ) {
1133 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1134 project_state.pending_predictions.clear();
1135 if let Some(prediction) = project_state.current_prediction.take() {
1136 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1137 }
1138 };
1139 }
1140
1141 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1142 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1143 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1144 if !current_prediction.was_shown {
1145 current_prediction.was_shown = true;
1146 self.shown_predictions
1147 .push_front(current_prediction.prediction.clone());
1148 if self.shown_predictions.len() > 50 {
1149 let completion = self.shown_predictions.pop_back().unwrap();
1150 self.rated_predictions.remove(&completion.id);
1151 }
1152 }
1153 }
1154 }
1155 }
1156
1157 fn reject_prediction(
1158 &mut self,
1159 prediction_id: EditPredictionId,
1160 reason: EditPredictionRejectReason,
1161 was_shown: bool,
1162 ) {
1163 match self.edit_prediction_model {
1164 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1165 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1166 }
1167
1168 self.reject_predictions_tx
1169 .unbounded_send(EditPredictionRejection {
1170 request_id: prediction_id.to_string(),
1171 reason,
1172 was_shown,
1173 })
1174 .log_err();
1175 }
1176
1177 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1178 self.projects
1179 .get(&project.entity_id())
1180 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1181 }
1182
1183 pub fn refresh_prediction_from_buffer(
1184 &mut self,
1185 project: Entity<Project>,
1186 buffer: Entity<Buffer>,
1187 position: language::Anchor,
1188 cx: &mut Context<Self>,
1189 ) {
1190 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1191 let Some(request_task) = this
1192 .update(cx, |this, cx| {
1193 this.request_prediction(
1194 &project,
1195 &buffer,
1196 position,
1197 PredictEditsRequestTrigger::Other,
1198 cx,
1199 )
1200 })
1201 .log_err()
1202 else {
1203 return Task::ready(anyhow::Ok(None));
1204 };
1205
1206 cx.spawn(async move |_cx| {
1207 request_task.await.map(|prediction_result| {
1208 prediction_result.map(|prediction_result| {
1209 (
1210 prediction_result,
1211 PredictionRequestedBy::Buffer(buffer.entity_id()),
1212 )
1213 })
1214 })
1215 })
1216 })
1217 }
1218
    /// Requests a prediction anchored at a nearby diagnostic when no
    /// buffer-triggered prediction is currently available for the project.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Capture the active buffer, its snapshot, and the cursor point,
            // bailing out if predictions are disabled at that location.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // Find a diagnostic to anchor the request at — possibly in a
                // different buffer than the active one.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // If a buffer-triggered prediction appeared while we
                        // were waiting, it wins — report this result as
                        // rejected in favor of the current one.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1305
1306 fn predictions_enabled_at(
1307 snapshot: &BufferSnapshot,
1308 position: Option<language::Anchor>,
1309 cx: &App,
1310 ) -> bool {
1311 let file = snapshot.file();
1312 let all_settings = all_language_settings(file, cx);
1313 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1314 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1315 {
1316 return false;
1317 }
1318
1319 if let Some(last_position) = position {
1320 let settings = snapshot.settings_at(last_position, cx);
1321
1322 if !settings.edit_predictions_disabled_in.is_empty()
1323 && let Some(scope) = snapshot.language_scope_at(last_position)
1324 && let Some(scope_name) = scope.override_name()
1325 && settings
1326 .edit_predictions_disabled_in
1327 .iter()
1328 .any(|s| s == scope_name)
1329 {
1330 return false;
1331 }
1332 }
1333
1334 true
1335 }
1336
    // Minimum delay between consecutive refreshes triggered by the same
    // entity; zero under test so tests don't have to wait on timers.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1341
1342 fn queue_prediction_refresh(
1343 &mut self,
1344 project: Entity<Project>,
1345 throttle_entity: EntityId,
1346 cx: &mut Context<Self>,
1347 do_refresh: impl FnOnce(
1348 WeakEntity<Self>,
1349 &mut AsyncApp,
1350 )
1351 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1352 + 'static,
1353 ) {
1354 let project_state = self.get_or_init_project(&project, cx);
1355 let pending_prediction_id = project_state.next_pending_prediction_id;
1356 project_state.next_pending_prediction_id += 1;
1357 let last_request = project_state.last_prediction_refresh;
1358
1359 let task = cx.spawn(async move |this, cx| {
1360 if let Some((last_entity, last_timestamp)) = last_request
1361 && throttle_entity == last_entity
1362 && let Some(timeout) =
1363 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1364 {
1365 cx.background_executor().timer(timeout).await;
1366 }
1367
1368 // If this task was cancelled before the throttle timeout expired,
1369 // do not perform a request.
1370 let mut is_cancelled = true;
1371 this.update(cx, |this, cx| {
1372 let project_state = this.get_or_init_project(&project, cx);
1373 if !project_state
1374 .cancelled_predictions
1375 .remove(&pending_prediction_id)
1376 {
1377 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1378 is_cancelled = false;
1379 }
1380 })
1381 .ok();
1382 if is_cancelled {
1383 return None;
1384 }
1385
1386 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1387 let new_prediction_id = new_prediction_result
1388 .as_ref()
1389 .map(|(prediction, _)| prediction.id.clone());
1390
1391 // When a prediction completes, remove it from the pending list, and cancel
1392 // any pending predictions that were enqueued before it.
1393 this.update(cx, |this, cx| {
1394 let project_state = this.get_or_init_project(&project, cx);
1395
1396 let is_cancelled = project_state
1397 .cancelled_predictions
1398 .remove(&pending_prediction_id);
1399
1400 let new_current_prediction = if !is_cancelled
1401 && let Some((prediction_result, requested_by)) = new_prediction_result
1402 {
1403 match prediction_result.prediction {
1404 Ok(prediction) => {
1405 let new_prediction = CurrentEditPrediction {
1406 requested_by,
1407 prediction,
1408 was_shown: false,
1409 };
1410
1411 if let Some(current_prediction) =
1412 project_state.current_prediction.as_ref()
1413 {
1414 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1415 {
1416 this.reject_current_prediction(
1417 EditPredictionRejectReason::Replaced,
1418 &project,
1419 );
1420
1421 Some(new_prediction)
1422 } else {
1423 this.reject_prediction(
1424 new_prediction.prediction.id,
1425 EditPredictionRejectReason::CurrentPreferred,
1426 false,
1427 );
1428 None
1429 }
1430 } else {
1431 Some(new_prediction)
1432 }
1433 }
1434 Err(reject_reason) => {
1435 this.reject_prediction(prediction_result.id, reject_reason, false);
1436 None
1437 }
1438 }
1439 } else {
1440 None
1441 };
1442
1443 let project_state = this.get_or_init_project(&project, cx);
1444
1445 if let Some(new_prediction) = new_current_prediction {
1446 project_state.current_prediction = Some(new_prediction);
1447 }
1448
1449 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1450 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1451 if pending_prediction.id == pending_prediction_id {
1452 pending_predictions.remove(ix);
1453 for pending_prediction in pending_predictions.drain(0..ix) {
1454 project_state.cancel_pending_prediction(pending_prediction, cx)
1455 }
1456 break;
1457 }
1458 }
1459 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1460 cx.notify();
1461 })
1462 .ok();
1463
1464 new_prediction_id
1465 });
1466
1467 if project_state.pending_predictions.len() <= 1 {
1468 project_state.pending_predictions.push(PendingPrediction {
1469 id: pending_prediction_id,
1470 task,
1471 });
1472 } else if project_state.pending_predictions.len() == 2 {
1473 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1474 project_state.pending_predictions.push(PendingPrediction {
1475 id: pending_prediction_id,
1476 task,
1477 });
1478 project_state.cancel_pending_prediction(pending_prediction, cx);
1479 }
1480 }
1481
1482 pub fn request_prediction(
1483 &mut self,
1484 project: &Entity<Project>,
1485 active_buffer: &Entity<Buffer>,
1486 position: language::Anchor,
1487 trigger: PredictEditsRequestTrigger,
1488 cx: &mut Context<Self>,
1489 ) -> Task<Result<Option<EditPredictionResult>>> {
1490 self.request_prediction_internal(
1491 project.clone(),
1492 active_buffer.clone(),
1493 position,
1494 trigger,
1495 cx.has_flag::<Zeta2FeatureFlag>(),
1496 cx,
1497 )
1498 }
1499
    /// Performs a prediction request for `active_buffer` at `position` using
    /// the currently selected model.
    ///
    /// When the request yields no prediction, `allow_jump` is set, and there
    /// are recent edit events, a single follow-up request is made at the
    /// nearest diagnostic outside the already-searched line range (possibly
    /// in a different buffer), with jumping disabled to prevent recursion.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Lines above/below the cursor whose diagnostics count as already
        // covered by this request when deciding where to jump next.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Context retrieval can be disabled via settings/feature flag.
        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        // Dispatch to the configured prediction backend.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry once at the diagnostic location; `allow_jump` is
                    // false here so this cannot recurse further.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1586
    /// Finds a diagnostic to jump to for a follow-up prediction request.
    ///
    /// Prefers the diagnostic in the active buffer closest to the cursor (by
    /// row distance) that lies outside `active_buffer_diagnostic_search_range`
    /// — the range already covered by the previous request. Failing that,
    /// falls back to opening the other buffer with diagnostics whose path
    /// shares the most leading components with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Jump to the diagnostic with the lowest severity value —
                // presumably errors sort first; confirm against the
                // diagnostic severity ordering used by `language`.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1665
    /// Sends an OpenAI-style request to the raw prediction endpoint.
    ///
    /// The endpoint can be overridden via `PREDICT_EDITS_URL`. When built
    /// with `cli-support`, responses may be served from — and recorded into —
    /// an [`EvalCache`] keyed by a hash of the URL and serialized request.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Cache lookup: a hit returns the stored response without any network
        // traffic (and without usage data).
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the fresh response for future cache hits.
        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1723
1724 fn handle_api_response<T>(
1725 this: &WeakEntity<Self>,
1726 response: Result<(T, Option<EditPredictionUsage>)>,
1727 cx: &mut gpui::AsyncApp,
1728 ) -> Result<T> {
1729 match response {
1730 Ok((data, usage)) => {
1731 if let Some(usage) = usage {
1732 this.update(cx, |this, cx| {
1733 this.user_store.update(cx, |user_store, cx| {
1734 user_store.update_edit_prediction_usage(usage, cx);
1735 });
1736 })
1737 .ok();
1738 }
1739 Ok(data)
1740 }
1741 Err(err) => {
1742 if err.is::<ZedUpdateRequiredError>() {
1743 cx.update(|cx| {
1744 this.update(cx, |this, _cx| {
1745 this.update_required = true;
1746 })
1747 .ok();
1748
1749 let error_message: SharedString = err.to_string().into();
1750 show_app_notification(
1751 NotificationId::unique::<ZedUpdateRequiredError>(),
1752 cx,
1753 move |cx| {
1754 cx.new(|cx| {
1755 ErrorMessagePrompt::new(error_message.clone(), cx)
1756 .with_link_button("Update Zed", "https://zed.dev/releases")
1757 })
1758 },
1759 );
1760 })
1761 .ok();
1762 }
1763 Err(err)
1764 }
1765 }
1766 }
1767
    /// Sends an authenticated POST request to the LLM service and deserializes
    /// the JSON response body into `Res`.
    ///
    /// `build` receives a builder preloaded with content-type, authorization,
    /// and app-version headers and must supply the URI and body. An
    /// expired-token response triggers exactly one token refresh and retry.
    /// Fails with [`ZedUpdateRequiredError`] when the server reports this
    /// client is below the minimum supported version.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may declare a minimum client version on any response,
            // regardless of status code.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage accounting is optional and carried in response headers.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh and loop around for one retry.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1831
1832 pub fn refresh_context(
1833 &mut self,
1834 project: &Entity<Project>,
1835 buffer: &Entity<language::Buffer>,
1836 cursor_position: language::Anchor,
1837 cx: &mut Context<Self>,
1838 ) {
1839 if self.use_context {
1840 self.get_or_init_project(project, cx)
1841 .context
1842 .update(cx, |store, cx| {
1843 store.refresh(buffer.clone(), cursor_position, cx);
1844 });
1845 }
1846 }
1847
1848 #[cfg(feature = "cli-support")]
1849 pub fn set_context_for_buffer(
1850 &mut self,
1851 project: &Entity<Project>,
1852 related_files: Vec<RelatedFile>,
1853 cx: &mut Context<Self>,
1854 ) {
1855 self.get_or_init_project(project, cx)
1856 .context
1857 .update(cx, |store, _| {
1858 store.set_related_files(related_files);
1859 });
1860 }
1861
1862 fn is_file_open_source(
1863 &self,
1864 project: &Entity<Project>,
1865 file: &Arc<dyn File>,
1866 cx: &App,
1867 ) -> bool {
1868 if !file.is_local() || file.is_private() {
1869 return false;
1870 }
1871 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1872 return false;
1873 };
1874 project_state
1875 .license_detection_watchers
1876 .get(&file.worktree_id(cx))
1877 .as_ref()
1878 .is_some_and(|watcher| watcher.is_project_open_source())
1879 }
1880
1881 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1882 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1883 }
1884
1885 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1886 if !self.data_collection_choice.is_enabled() {
1887 return false;
1888 }
1889 events.iter().all(|event| {
1890 matches!(
1891 event.as_ref(),
1892 zeta_prompt::Event::BufferChange {
1893 in_open_source_repo: true,
1894 ..
1895 }
1896 )
1897 })
1898 }
1899
1900 fn load_data_collection_choice() -> DataCollectionChoice {
1901 let choice = KEY_VALUE_STORE
1902 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1903 .log_err()
1904 .flatten();
1905
1906 match choice.as_deref() {
1907 Some("true") => DataCollectionChoice::Enabled,
1908 Some("false") => DataCollectionChoice::Disabled,
1909 Some(_) => {
1910 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1911 DataCollectionChoice::NotAnswered
1912 }
1913 None => DataCollectionChoice::NotAnswered,
1914 }
1915 }
1916
1917 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1918 self.data_collection_choice = self.data_collection_choice.toggle();
1919 let new_choice = self.data_collection_choice;
1920 db::write_and_log(cx, move || {
1921 KEY_VALUE_STORE.write_kvp(
1922 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1923 new_choice.is_enabled().to_string(),
1924 )
1925 });
1926 }
1927
    /// Iterates over predictions that have been shown to the user, most
    /// recent first (new entries are pushed to the front of the deque).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    // NOTE(review): named "completions" while the rest of this file says
    // "predictions"; renaming would break callers, so it is left as-is.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1939
    /// Records the user's rating for `prediction` and reports it, together
    /// with the prediction's inputs, resulting diff, and the feedback text,
    /// via telemetry.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush telemetry right away rather than waiting for a periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1958
1959 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1960 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1961 && all_language_settings(None, cx).edit_predictions.use_context;
1962 }
1963}
1964
/// Error returned when the server reports that this client version is below
/// the minimum required to use edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
1972
/// Cache key used when replaying LLM traffic: the entry kind plus a hash of
/// the request URL and serialized body.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of request a cache entry corresponds to.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

/// Storage backend for caching LLM request/response pairs (used by
/// `send_raw_llm_request` when built with `cli-support`).
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2000
/// The user's answer to the data-collection opt-in prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True when the user has made any explicit choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered prompt toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2041
2042struct ZedPredictUpsell;
2043
2044impl Dismissable for ZedPredictUpsell {
2045 const KEY: &'static str = "dismissed-edit-predict-upsell";
2046
2047 fn dismissed() -> bool {
2048 // To make this backwards compatible with older versions of Zed, we
2049 // check if the user has seen the previous Edit Prediction Onboarding
2050 // before, by checking the data collection choice which was written to
2051 // the database once the user clicked on "Accept and Enable"
2052 if KEY_VALUE_STORE
2053 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2054 .log_err()
2055 .is_some_and(|s| s.is_some())
2056 {
2057 return true;
2058 }
2059
2060 KEY_VALUE_STORE
2061 .read_kvp(Self::KEY)
2062 .log_err()
2063 .is_some_and(|s| s.is_some())
2064 }
2065}
2066
/// Whether the edit-prediction upsell modal should be shown (i.e. the user
/// has never dismissed it, directly or via the older onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2070
/// Registers edit-prediction workspace actions on every new workspace.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Opens the edit-prediction onboarding modal.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resets the configured provider to None in the settings file —
        // presumably so onboarding can run again; confirm with the
        // ResetOnboarding action's callers.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}