1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use workspace::Workspace;
38
39use std::ops::Range;
40use std::path::Path;
41use std::rc::Rc;
42use std::str::FromStr as _;
43use std::sync::{Arc, LazyLock};
44use std::time::{Duration, Instant};
45use std::{env, mem};
46use thiserror::Error;
47use util::{RangeExt as _, ResultExt as _};
48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
49
50mod cursor_excerpt;
51mod license_detection;
52pub mod mercury;
53mod onboarding_modal;
54pub mod open_ai_response;
55mod prediction;
56pub mod sweep_ai;
57
58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
59pub mod udiff;
60
61mod zed_edit_prediction_delegate;
62pub mod zeta1;
63pub mod zeta2;
64
65#[cfg(test)]
66mod edit_prediction_tests;
67
68use crate::license_detection::LicenseDetectionWatcher;
69use crate::mercury::Mercury;
70use crate::onboarding_modal::ZedPredictModal;
71pub use crate::prediction::EditPrediction;
72pub use crate::prediction::EditPredictionId;
73use crate::prediction::EditPredictionResult;
74pub use crate::sweep_ai::SweepAi;
75pub use telemetry_events::EditPredictionRating;
76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
77
78actions!(
79 edit_prediction,
80 [
81 /// Resets the edit prediction onboarding state.
82 ResetOnboarding,
83 /// Clears the edit prediction history.
84 ClearHistory,
85 ]
86);
87
88/// Maximum number of events to track.
89const EVENT_COUNT_MAX: usize = 6;
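/// Maximum number of lines between consecutive edits for them to be coalesced into a single event.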
90const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
91const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
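/// How long to wait for more rejections to accumulate before sending a batched report.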
92const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
93
94pub struct SweepFeatureFlag;
95
96impl FeatureFlag for SweepFeatureFlag {
97 const NAME: &str = "sweep-ai";
98}
99
100pub struct MercuryFeatureFlag;
101
102impl FeatureFlag for MercuryFeatureFlag {
103 const NAME: &str = "mercury";
104}
105
106pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
107 context: EditPredictionExcerptOptions {
108 max_bytes: 512,
109 min_bytes: 128,
110 target_before_cursor_over_total_bytes: 0.5,
111 },
112 prompt_format: PromptFormat::DEFAULT,
113};
114
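/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, zeta2 requests are sent to a local
/// Ollama server (unless `ZED_PREDICT_EDITS_URL` overrides the URL) with a local default model.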
115static USE_OLLAMA: LazyLock<bool> =
116 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
117
118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
119 match env::var("ZED_ZETA2_MODEL").as_deref() {
120 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
121 Ok(model) => model,
122 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
123 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
124 }
125 .to_string()
126});
127static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
128 env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
129 if *USE_OLLAMA {
130 Some("http://localhost:11434/v1/chat/completions".into())
131 } else {
132 None
133 }
134 })
135});
136
137pub struct Zeta2FeatureFlag;
138
139impl FeatureFlag for Zeta2FeatureFlag {
140 const NAME: &'static str = "zeta2";
141
142 fn enabled_for_staff() -> bool {
143 true
144 }
145}
146
147#[derive(Clone)]
148struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
149
150impl Global for EditPredictionStoreGlobal {}
151
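/// Central store backing edit predictions: tracks per-project edit history and retrieved
/// context, dispatches prediction requests to the configured model, and reports accepted
/// and rejected predictions.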
152pub struct EditPredictionStore {
153 client: Arc<Client>,
154 user_store: Entity<UserStore>,
155 llm_token: LlmApiToken,
156 _llm_token_subscription: Subscription,
157 projects: HashMap<EntityId, ProjectState>,
158 use_context: bool,
159 options: ZetaOptions,
160 update_required: bool,
161 #[cfg(feature = "cli-support")]
162 eval_cache: Option<Arc<dyn EvalCache>>,
163 edit_prediction_model: EditPredictionModel,
164 pub sweep_ai: SweepAi,
165 pub mercury: Mercury,
166 data_collection_choice: DataCollectionChoice,
167 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
168 shown_predictions: VecDeque<EditPrediction>,
169 rated_predictions: HashSet<EditPredictionId>,
170}
171
172#[derive(Copy, Clone, Default, PartialEq, Eq)]
173pub enum EditPredictionModel {
174 #[default]
175 Zeta1,
176 Zeta2,
177 Sweep,
178 Mercury,
179}
180
181pub struct EditPredictionModelInput {
182 project: Entity<Project>,
183 buffer: Entity<Buffer>,
184 snapshot: BufferSnapshot,
185 position: Anchor,
186 events: Vec<Arc<zeta_prompt::Event>>,
187 related_files: Arc<[RelatedFile]>,
188 recent_paths: VecDeque<ProjectPath>,
189 trigger: PredictEditsRequestTrigger,
190 diagnostic_search_range: Range<Point>,
191 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
192}
193
194#[derive(Debug, Clone, PartialEq)]
195pub struct ZetaOptions {
196 pub context: EditPredictionExcerptOptions,
197 pub prompt_format: predict_edits_v3::PromptFormat,
198}
199
200#[derive(Debug)]
201pub enum DebugEvent {
202 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
203 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
204 EditPredictionStarted(EditPredictionStartedDebugEvent),
205 EditPredictionFinished(EditPredictionFinishedDebugEvent),
206}
207
208#[derive(Debug)]
209pub struct ContextRetrievalStartedDebugEvent {
210 pub project_entity_id: EntityId,
211 pub timestamp: Instant,
212 pub search_prompt: String,
213}
214
215#[derive(Debug)]
216pub struct ContextRetrievalFinishedDebugEvent {
217 pub project_entity_id: EntityId,
218 pub timestamp: Instant,
219 pub metadata: Vec<(&'static str, SharedString)>,
220}
221
222#[derive(Debug)]
223pub struct EditPredictionStartedDebugEvent {
224 pub buffer: WeakEntity<Buffer>,
225 pub position: Anchor,
226 pub prompt: Option<String>,
227}
228
229#[derive(Debug)]
230pub struct EditPredictionFinishedDebugEvent {
231 pub buffer: WeakEntity<Buffer>,
232 pub position: Anchor,
233 pub model_output: Option<String>,
234}
235
236pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
237
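/// Per-project state: recent edit events, registered buffers, license detection, context
/// retrieval, and the current and pending predictions.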
238struct ProjectState {
239 events: VecDeque<Arc<zeta_prompt::Event>>,
240 last_event: Option<LastEvent>,
241 recent_paths: VecDeque<ProjectPath>,
242 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
243 current_prediction: Option<CurrentEditPrediction>,
244 next_pending_prediction_id: usize,
245 pending_predictions: ArrayVec<PendingPrediction, 2>,
246 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
247 last_prediction_refresh: Option<(EntityId, Instant)>,
248 cancelled_predictions: HashSet<usize>,
249 context: Entity<RelatedExcerptStore>,
250 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
251 _subscription: gpui::Subscription,
252}
253
254impl ProjectState {
255 pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
256 self.events
257 .iter()
258 .cloned()
259 .chain(
260 self.last_event
261 .as_ref()
262 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
263 )
264 .collect()
265 }
266
267 fn cancel_pending_prediction(
268 &mut self,
269 pending_prediction: PendingPrediction,
270 cx: &mut Context<EditPredictionStore>,
271 ) {
272 self.cancelled_predictions.insert(pending_prediction.id);
273
274 cx.spawn(async move |this, cx| {
275 let Some(prediction_id) = pending_prediction.task.await else {
276 return;
277 };
278
279 this.update(cx, |this, _cx| {
280 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
281 })
282 .ok();
283 })
284 .detach()
285 }
286
287 fn active_buffer(
288 &self,
289 project: &Entity<Project>,
290 cx: &App,
291 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
292 let project = project.read(cx);
293 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
294 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
295 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
296 Some((active_buffer, registered_buffer.last_position))
297 }
298}
299
300#[derive(Debug, Clone)]
301struct CurrentEditPrediction {
302 pub requested_by: PredictionRequestedBy,
303 pub prediction: EditPrediction,
304 pub was_shown: bool,
305}
306
307impl CurrentEditPrediction {
308 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
309 let Some(new_edits) = self
310 .prediction
311 .interpolate(&self.prediction.buffer.read(cx))
312 else {
313 return false;
314 };
315
316 if self.prediction.buffer != old_prediction.prediction.buffer {
317 return true;
318 }
319
320 let Some(old_edits) = old_prediction
321 .prediction
322 .interpolate(&old_prediction.prediction.buffer.read(cx))
323 else {
324 return true;
325 };
326
327 let requested_by_buffer_id = self.requested_by.buffer_id();
328
        // Reduce UI thrash: only replace a single-edit prediction when the new
        // edit covers the same range and extends the previous text.
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
332 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
333 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
334 && old_edits.len() == 1
335 && new_edits.len() == 1
336 {
337 let (old_range, old_text) = &old_edits[0];
338 let (new_range, new_text) = &new_edits[0];
339 new_range == old_range && new_text.starts_with(old_text.as_ref())
340 } else {
341 true
342 }
343 }
344}
345
346#[derive(Debug, Clone)]
347enum PredictionRequestedBy {
348 DiagnosticsUpdate,
349 Buffer(EntityId),
350}
351
352impl PredictionRequestedBy {
353 pub fn buffer_id(&self) -> Option<EntityId> {
354 match self {
355 PredictionRequestedBy::DiagnosticsUpdate => None,
356 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
357 }
358 }
359}
360
361#[derive(Debug)]
362struct PendingPrediction {
363 id: usize,
364 task: Task<Option<EditPredictionId>>,
365}
366
367/// A prediction from the perspective of a buffer.
368#[derive(Debug)]
369enum BufferEditPrediction<'a> {
370 Local { prediction: &'a EditPrediction },
371 Jump { prediction: &'a EditPrediction },
372}
373
374#[cfg(test)]
375impl std::ops::Deref for BufferEditPrediction<'_> {
376 type Target = EditPrediction;
377
378 fn deref(&self) -> &Self::Target {
379 match self {
380 BufferEditPrediction::Local { prediction } => prediction,
381 BufferEditPrediction::Jump { prediction } => prediction,
382 }
383 }
384}
385
386struct RegisteredBuffer {
387 snapshot: BufferSnapshot,
388 last_position: Option<Anchor>,
389 _subscriptions: [gpui::Subscription; 2],
390}
391
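/// The most recent buffer change, held out of the event list so that nearby follow-up edits
/// can be coalesced into it before it is finalized.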
392struct LastEvent {
393 old_snapshot: BufferSnapshot,
394 new_snapshot: BufferSnapshot,
395 end_edit_anchor: Option<Anchor>,
396}
397
398impl LastEvent {
399 pub fn finalize(
400 &self,
401 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
402 cx: &App,
403 ) -> Option<Arc<zeta_prompt::Event>> {
404 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
405 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
406
407 let file = self.new_snapshot.file();
408 let old_file = self.old_snapshot.file();
409
410 let in_open_source_repo = [file, old_file].iter().all(|file| {
411 file.is_some_and(|file| {
412 license_detection_watchers
413 .get(&file.worktree_id(cx))
414 .is_some_and(|watcher| watcher.is_project_open_source())
415 })
416 });
417
418 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
419
420 if path == old_path && diff.is_empty() {
421 None
422 } else {
423 Some(Arc::new(zeta_prompt::Event::BufferChange {
424 old_path,
425 path,
426 diff,
427 in_open_source_repo,
428 // TODO: Actually detect if this edit was predicted or not
429 predicted: false,
430 }))
431 }
432 }
433}
434
435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
436 if let Some(file) = snapshot.file() {
437 file.full_path(cx).into()
438 } else {
439 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
440 }
441}
442
443impl EditPredictionStore {
444 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
445 cx.try_global::<EditPredictionStoreGlobal>()
446 .map(|global| global.0.clone())
447 }
448
449 pub fn global(
450 client: &Arc<Client>,
451 user_store: &Entity<UserStore>,
452 cx: &mut App,
453 ) -> Entity<Self> {
454 cx.try_global::<EditPredictionStoreGlobal>()
455 .map(|global| global.0.clone())
456 .unwrap_or_else(|| {
457 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
458 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
459 ep_store
460 })
461 }
462
463 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
464 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
465 let data_collection_choice = Self::load_data_collection_choice();
466
467 let llm_token = LlmApiToken::default();
468
469 let (reject_tx, reject_rx) = mpsc::unbounded();
470 cx.background_spawn({
471 let client = client.clone();
472 let llm_token = llm_token.clone();
473 let app_version = AppVersion::global(cx);
474 let background_executor = cx.background_executor().clone();
475 async move {
476 Self::handle_rejected_predictions(
477 reject_rx,
478 client,
479 llm_token,
480 app_version,
481 background_executor,
482 )
483 .await
484 }
485 })
486 .detach();
487
488 let mut this = Self {
489 projects: HashMap::default(),
490 client,
491 user_store,
492 options: DEFAULT_OPTIONS,
493 use_context: false,
494 llm_token,
495 _llm_token_subscription: cx.subscribe(
496 &refresh_llm_token_listener,
497 |this, _listener, _event, cx| {
498 let client = this.client.clone();
499 let llm_token = this.llm_token.clone();
500 cx.spawn(async move |_this, _cx| {
501 llm_token.refresh(&client).await?;
502 anyhow::Ok(())
503 })
504 .detach_and_log_err(cx);
505 },
506 ),
507 update_required: false,
508 #[cfg(feature = "cli-support")]
509 eval_cache: None,
510 edit_prediction_model: EditPredictionModel::Zeta2,
511 sweep_ai: SweepAi::new(cx),
512 mercury: Mercury::new(cx),
513 data_collection_choice,
514 reject_predictions_tx: reject_tx,
515 rated_predictions: Default::default(),
516 shown_predictions: Default::default(),
517 };
518
519 this.configure_context_retrieval(cx);
520 let weak_this = cx.weak_entity();
521 cx.on_flags_ready(move |_, cx| {
522 weak_this
523 .update(cx, |this, cx| this.configure_context_retrieval(cx))
524 .ok();
525 })
526 .detach();
527 cx.observe_global::<SettingsStore>(|this, cx| {
528 this.configure_context_retrieval(cx);
529 })
530 .detach();
531
532 this
533 }
534
535 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
536 self.edit_prediction_model = model;
537 }
538
539 pub fn has_sweep_api_token(&self) -> bool {
540 self.sweep_ai
541 .api_token
542 .clone()
543 .now_or_never()
544 .flatten()
545 .is_some()
546 }
547
548 pub fn has_mercury_api_token(&self) -> bool {
549 self.mercury
550 .api_token
551 .clone()
552 .now_or_never()
553 .flatten()
554 .is_some()
555 }
556
557 #[cfg(feature = "cli-support")]
558 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
559 self.eval_cache = Some(cache);
560 }
561
562 pub fn options(&self) -> &ZetaOptions {
563 &self.options
564 }
565
566 pub fn set_options(&mut self, options: ZetaOptions) {
567 self.options = options;
568 }
569
570 pub fn set_use_context(&mut self, use_context: bool) {
571 self.use_context = use_context;
572 }
573
574 pub fn clear_history(&mut self) {
575 for project_state in self.projects.values_mut() {
576 project_state.events.clear();
577 }
578 }
579
580 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
581 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
582 project_state.events.clear();
583 }
584 }
585
586 pub fn edit_history_for_project(
587 &self,
588 project: &Entity<Project>,
589 cx: &App,
590 ) -> Vec<Arc<zeta_prompt::Event>> {
591 self.projects
592 .get(&project.entity_id())
593 .map(|project_state| project_state.events(cx))
594 .unwrap_or_default()
595 }
596
597 pub fn context_for_project<'a>(
598 &'a self,
599 project: &Entity<Project>,
600 cx: &'a App,
601 ) -> Arc<[RelatedFile]> {
602 self.projects
603 .get(&project.entity_id())
604 .map(|project| project.context.read(cx).related_files())
605 .unwrap_or_else(|| vec![].into())
606 }
607
608 pub fn context_for_project_with_buffers<'a>(
609 &'a self,
610 project: &Entity<Project>,
611 cx: &'a App,
612 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
613 self.projects
614 .get(&project.entity_id())
615 .map(|project| project.context.read(cx).related_files_with_buffers())
616 }
617
618 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
619 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
620 self.user_store.read(cx).edit_prediction_usage()
621 } else {
622 None
623 }
624 }
625
626 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
627 self.get_or_init_project(project, cx);
628 }
629
630 pub fn register_buffer(
631 &mut self,
632 buffer: &Entity<Buffer>,
633 project: &Entity<Project>,
634 cx: &mut Context<Self>,
635 ) {
636 let project_state = self.get_or_init_project(project, cx);
637 Self::register_buffer_impl(project_state, buffer, project, cx);
638 }
639
640 fn get_or_init_project(
641 &mut self,
642 project: &Entity<Project>,
643 cx: &mut Context<Self>,
644 ) -> &mut ProjectState {
645 let entity_id = project.entity_id();
646 self.projects
647 .entry(entity_id)
648 .or_insert_with(|| ProjectState {
649 context: {
650 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
651 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
652 this.handle_excerpt_store_event(entity_id, event);
653 })
654 .detach();
655 related_excerpt_store
656 },
657 events: VecDeque::new(),
658 last_event: None,
659 recent_paths: VecDeque::new(),
660 debug_tx: None,
661 registered_buffers: HashMap::default(),
662 current_prediction: None,
663 cancelled_predictions: HashSet::default(),
664 pending_predictions: ArrayVec::new(),
665 next_pending_prediction_id: 0,
666 last_prediction_refresh: None,
667 license_detection_watchers: HashMap::default(),
668 _subscription: cx.subscribe(&project, Self::handle_project_event),
669 })
670 }
671
672 pub fn remove_project(&mut self, project: &Entity<Project>) {
673 self.projects.remove(&project.entity_id());
674 }
675
676 fn handle_excerpt_store_event(
677 &mut self,
678 project_entity_id: EntityId,
679 event: &RelatedExcerptStoreEvent,
680 ) {
681 if let Some(project_state) = self.projects.get(&project_entity_id) {
682 if let Some(debug_tx) = project_state.debug_tx.clone() {
683 match event {
684 RelatedExcerptStoreEvent::StartedRefresh => {
685 debug_tx
686 .unbounded_send(DebugEvent::ContextRetrievalStarted(
687 ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
689 timestamp: Instant::now(),
690 search_prompt: String::new(),
691 },
692 ))
693 .ok();
694 }
695 RelatedExcerptStoreEvent::FinishedRefresh {
696 cache_hit_count,
697 cache_miss_count,
698 mean_definition_latency,
699 max_definition_latency,
700 } => {
701 debug_tx
702 .unbounded_send(DebugEvent::ContextRetrievalFinished(
703 ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
705 timestamp: Instant::now(),
706 metadata: vec![
707 (
708 "Cache Hits",
709 format!(
710 "{}/{}",
711 cache_hit_count,
712 cache_hit_count + cache_miss_count
713 )
714 .into(),
715 ),
716 (
717 "Max LSP Time",
718 format!("{} ms", max_definition_latency.as_millis())
719 .into(),
720 ),
721 (
722 "Mean LSP Time",
723 format!("{} ms", mean_definition_latency.as_millis())
724 .into(),
725 ),
726 ],
727 },
728 ))
729 .ok();
730 }
731 }
732 }
733 }
734 }
735
736 pub fn debug_info(
737 &mut self,
738 project: &Entity<Project>,
739 cx: &mut Context<Self>,
740 ) -> mpsc::UnboundedReceiver<DebugEvent> {
741 let project_state = self.get_or_init_project(project, cx);
742 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
743 project_state.debug_tx = Some(debug_watch_tx);
744 debug_watch_rx
745 }
746
747 fn handle_project_event(
748 &mut self,
749 project: Entity<Project>,
750 event: &project::Event,
751 cx: &mut Context<Self>,
752 ) {
753 // TODO [zeta2] init with recent paths
754 match event {
755 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
756 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
757 return;
758 };
759 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
760 if let Some(path) = path {
761 if let Some(ix) = project_state
762 .recent_paths
763 .iter()
764 .position(|probe| probe == &path)
765 {
766 project_state.recent_paths.remove(ix);
767 }
768 project_state.recent_paths.push_front(path);
769 }
770 }
771 project::Event::DiagnosticsUpdated { .. } => {
772 if cx.has_flag::<Zeta2FeatureFlag>() {
773 self.refresh_prediction_from_diagnostics(project, cx);
774 }
775 }
776 _ => (),
777 }
778 }
779
780 fn register_buffer_impl<'a>(
781 project_state: &'a mut ProjectState,
782 buffer: &Entity<Buffer>,
783 project: &Entity<Project>,
784 cx: &mut Context<Self>,
785 ) -> &'a mut RegisteredBuffer {
786 let buffer_id = buffer.entity_id();
787
788 if let Some(file) = buffer.read(cx).file() {
789 let worktree_id = file.worktree_id(cx);
790 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
791 project_state
792 .license_detection_watchers
793 .entry(worktree_id)
794 .or_insert_with(|| {
795 let project_entity_id = project.entity_id();
796 cx.observe_release(&worktree, move |this, _worktree, _cx| {
797 let Some(project_state) = this.projects.get_mut(&project_entity_id)
798 else {
799 return;
800 };
801 project_state
802 .license_detection_watchers
803 .remove(&worktree_id);
804 })
805 .detach();
806 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
807 });
808 }
809 }
810
811 match project_state.registered_buffers.entry(buffer_id) {
812 hash_map::Entry::Occupied(entry) => entry.into_mut(),
813 hash_map::Entry::Vacant(entry) => {
814 let snapshot = buffer.read(cx).snapshot();
815 let project_entity_id = project.entity_id();
816 entry.insert(RegisteredBuffer {
817 snapshot,
818 last_position: None,
819 _subscriptions: [
820 cx.subscribe(buffer, {
821 let project = project.downgrade();
822 move |this, buffer, event, cx| {
823 if let language::BufferEvent::Edited = event
824 && let Some(project) = project.upgrade()
825 {
826 this.report_changes_for_buffer(&buffer, &project, cx);
827 }
828 }
829 }),
830 cx.observe_release(buffer, move |this, _buffer, _cx| {
831 let Some(project_state) = this.projects.get_mut(&project_entity_id)
832 else {
833 return;
834 };
835 project_state.registered_buffers.remove(&buffer_id);
836 }),
837 ],
838 })
839 }
840 }
841 }
842
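    /// Records an edit to a registered buffer, coalescing it into the previous event when it
    /// continues the same buffer within `CHANGE_GROUPING_LINE_SPAN` lines, and evicting the
    /// oldest event so the history stays within `EVENT_COUNT_MAX` entries.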
843 fn report_changes_for_buffer(
844 &mut self,
845 buffer: &Entity<Buffer>,
846 project: &Entity<Project>,
847 cx: &mut Context<Self>,
848 ) {
849 let project_state = self.get_or_init_project(project, cx);
850 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
851
852 let new_snapshot = buffer.read(cx).snapshot();
853 if new_snapshot.version == registered_buffer.snapshot.version {
854 return;
855 }
856
857 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
858 let end_edit_anchor = new_snapshot
859 .anchored_edits_since::<Point>(&old_snapshot.version)
860 .last()
861 .map(|(_, range)| range.end);
862 let events = &mut project_state.events;
863
864 if let Some(LastEvent {
865 new_snapshot: last_new_snapshot,
866 end_edit_anchor: last_end_edit_anchor,
867 ..
868 }) = project_state.last_event.as_mut()
869 {
870 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
871 == last_new_snapshot.remote_id()
872 && old_snapshot.version == last_new_snapshot.version;
873
874 let should_coalesce = is_next_snapshot_of_same_buffer
875 && end_edit_anchor
876 .as_ref()
877 .zip(last_end_edit_anchor.as_ref())
878 .is_some_and(|(a, b)| {
879 let a = a.to_point(&new_snapshot);
880 let b = b.to_point(&new_snapshot);
881 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
882 });
883
884 if should_coalesce {
885 *last_end_edit_anchor = end_edit_anchor;
886 *last_new_snapshot = new_snapshot;
887 return;
888 }
889 }
890
891 if events.len() + 1 >= EVENT_COUNT_MAX {
892 events.pop_front();
893 }
894
895 if let Some(event) = project_state.last_event.take() {
896 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
897 }
898
899 project_state.last_event = Some(LastEvent {
900 old_snapshot,
901 new_snapshot,
902 end_edit_anchor,
903 });
904 }
905
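    /// Returns the current prediction from the perspective of `buffer`: `Local` if it targets
    /// this buffer, or `Jump` if the user would need to jump to another buffer to see it.
    /// Also records `position` as the buffer's last cursor position.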
906 fn prediction_at(
907 &mut self,
908 buffer: &Entity<Buffer>,
909 position: Option<language::Anchor>,
910 project: &Entity<Project>,
911 cx: &App,
912 ) -> Option<BufferEditPrediction<'_>> {
913 let project_state = self.projects.get_mut(&project.entity_id())?;
914 if let Some(position) = position
915 && let Some(buffer) = project_state
916 .registered_buffers
917 .get_mut(&buffer.entity_id())
918 {
919 buffer.last_position = Some(position);
920 }
921
922 let CurrentEditPrediction {
923 requested_by,
924 prediction,
925 ..
926 } = project_state.current_prediction.as_ref()?;
927
928 if prediction.targets_buffer(buffer.read(cx)) {
929 Some(BufferEditPrediction::Local { prediction })
930 } else {
931 let show_jump = match requested_by {
932 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
933 requested_by_buffer_id == &buffer.entity_id()
934 }
935 PredictionRequestedBy::DiagnosticsUpdate => true,
936 };
937
938 if show_jump {
939 Some(BufferEditPrediction::Jump { prediction })
940 } else {
941 None
942 }
943 }
944 }
945
946 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
947 match self.edit_prediction_model {
948 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
949 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
950 }
951
952 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
953 return;
954 };
955
956 let Some(prediction) = project_state.current_prediction.take() else {
957 return;
958 };
959 let request_id = prediction.prediction.id.to_string();
960 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
961 project_state.cancel_pending_prediction(pending_prediction, cx);
962 }
963
964 let client = self.client.clone();
965 let llm_token = self.llm_token.clone();
966 let app_version = AppVersion::global(cx);
967 cx.spawn(async move |this, cx| {
968 let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
969 http_client::Url::parse(&predict_edits_url)?
970 } else {
971 client
972 .http_client()
973 .build_zed_llm_url("/predict_edits/accept", &[])?
974 };
975
976 let response = cx
977 .background_spawn(Self::send_api_request::<()>(
978 move |builder| {
979 let req = builder.uri(url.as_ref()).body(
980 serde_json::to_string(&AcceptEditPredictionBody {
981 request_id: request_id.clone(),
982 })?
983 .into(),
984 );
985 Ok(req?)
986 },
987 client,
988 llm_token,
989 app_version,
990 ))
991 .await;
992
993 Self::handle_api_response(&this, response, cx)?;
994 anyhow::Ok(())
995 })
996 .detach_and_log_err(cx);
997 }
998
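    /// Sends rejected predictions to the server in batches: waits for `REJECT_REQUEST_DEBOUNCE`
    /// (or flushes early once half the per-request limit has accumulated), and keeps unsent
    /// rejections around to retry on the next flush if a request fails.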
999 async fn handle_rejected_predictions(
1000 rx: UnboundedReceiver<EditPredictionRejection>,
1001 client: Arc<Client>,
1002 llm_token: LlmApiToken,
1003 app_version: Version,
1004 background_executor: BackgroundExecutor,
1005 ) {
1006 let mut rx = std::pin::pin!(rx.peekable());
1007 let mut batched = Vec::new();
1008
1009 while let Some(rejection) = rx.next().await {
1010 batched.push(rejection);
1011
1012 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1013 select_biased! {
1014 next = rx.as_mut().peek().fuse() => {
1015 if next.is_some() {
1016 continue;
1017 }
1018 }
1019 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1020 }
1021 }
1022
1023 let url = client
1024 .http_client()
1025 .build_zed_llm_url("/predict_edits/reject", &[])
1026 .unwrap();
1027
            // Cap the flush size, since rejections may have accumulated after a failed request.
            let flush_count = batched
                .len()
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1032 let start = batched.len() - flush_count;
1033
1034 let body = RejectEditPredictionsBodyRef {
1035 rejections: &batched[start..],
1036 };
1037
1038 let result = Self::send_api_request::<()>(
1039 |builder| {
1040 let req = builder
1041 .uri(url.as_ref())
1042 .body(serde_json::to_string(&body)?.into());
1043 anyhow::Ok(req?)
1044 },
1045 client.clone(),
1046 llm_token.clone(),
1047 app_version.clone(),
1048 )
1049 .await;
1050
1051 if result.log_err().is_some() {
1052 batched.drain(start..);
1053 }
1054 }
1055 }
1056
1057 fn reject_current_prediction(
1058 &mut self,
1059 reason: EditPredictionRejectReason,
1060 project: &Entity<Project>,
1061 ) {
1062 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1063 project_state.pending_predictions.clear();
1064 if let Some(prediction) = project_state.current_prediction.take() {
1065 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1066 }
1067 };
1068 }
1069
1070 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1071 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1072 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1073 if !current_prediction.was_shown {
1074 current_prediction.was_shown = true;
1075 self.shown_predictions
1076 .push_front(current_prediction.prediction.clone());
1077 if self.shown_predictions.len() > 50 {
1078 let completion = self.shown_predictions.pop_back().unwrap();
1079 self.rated_predictions.remove(&completion.id);
1080 }
1081 }
1082 }
1083 }
1084 }
1085
1086 fn reject_prediction(
1087 &mut self,
1088 prediction_id: EditPredictionId,
1089 reason: EditPredictionRejectReason,
1090 was_shown: bool,
1091 ) {
1092 match self.edit_prediction_model {
1093 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1094 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1095 }
1096
1097 self.reject_predictions_tx
1098 .unbounded_send(EditPredictionRejection {
1099 request_id: prediction_id.to_string(),
1100 reason,
1101 was_shown,
1102 })
1103 .log_err();
1104 }
1105
1106 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1107 self.projects
1108 .get(&project.entity_id())
1109 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1110 }
1111
1112 pub fn refresh_prediction_from_buffer(
1113 &mut self,
1114 project: Entity<Project>,
1115 buffer: Entity<Buffer>,
1116 position: language::Anchor,
1117 cx: &mut Context<Self>,
1118 ) {
1119 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1120 let Some(request_task) = this
1121 .update(cx, |this, cx| {
1122 this.request_prediction(
1123 &project,
1124 &buffer,
1125 position,
1126 PredictEditsRequestTrigger::Other,
1127 cx,
1128 )
1129 })
1130 .log_err()
1131 else {
1132 return Task::ready(anyhow::Ok(None));
1133 };
1134
1135 cx.spawn(async move |_cx| {
1136 request_task.await.map(|prediction_result| {
1137 prediction_result.map(|prediction_result| {
1138 (
1139 prediction_result,
1140 PredictionRequestedBy::Buffer(buffer.entity_id()),
1141 )
1142 })
1143 })
1144 })
1145 })
1146 }
1147
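    /// Requests a prediction near the closest relevant diagnostic when there is no current
    /// prediction, keeping the result only if no buffer-requested prediction arrives first.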
1148 pub fn refresh_prediction_from_diagnostics(
1149 &mut self,
1150 project: Entity<Project>,
1151 cx: &mut Context<Self>,
1152 ) {
1153 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1154 return;
1155 };
1156
        // Don't replace an existing prediction with a diagnostics-triggered one.
1158 if project_state.current_prediction.is_some() {
1159 return;
1160 };
1161
1162 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1163 let Some((active_buffer, snapshot, cursor_point)) = this
1164 .read_with(cx, |this, cx| {
1165 let project_state = this.projects.get(&project.entity_id())?;
1166 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1167 let snapshot = buffer.read(cx).snapshot();
1168
1169 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1170 return None;
1171 }
1172
1173 let cursor_point = position
1174 .map(|pos| pos.to_point(&snapshot))
1175 .unwrap_or_default();
1176
1177 Some((buffer, snapshot, cursor_point))
1178 })
1179 .log_err()
1180 .flatten()
1181 else {
1182 return Task::ready(anyhow::Ok(None));
1183 };
1184
1185 cx.spawn(async move |cx| {
1186 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1187 active_buffer,
1188 &snapshot,
1189 Default::default(),
1190 cursor_point,
1191 &project,
1192 cx,
1193 )
1194 .await?
1195 else {
1196 return anyhow::Ok(None);
1197 };
1198
1199 let Some(prediction_result) = this
1200 .update(cx, |this, cx| {
1201 this.request_prediction(
1202 &project,
1203 &jump_buffer,
1204 jump_position,
1205 PredictEditsRequestTrigger::Diagnostics,
1206 cx,
1207 )
1208 })?
1209 .await?
1210 else {
1211 return anyhow::Ok(None);
1212 };
1213
1214 this.update(cx, |this, cx| {
1215 Some((
1216 if this
1217 .get_or_init_project(&project, cx)
1218 .current_prediction
1219 .is_none()
1220 {
1221 prediction_result
1222 } else {
1223 EditPredictionResult {
1224 id: prediction_result.id,
1225 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1226 }
1227 },
1228 PredictionRequestedBy::DiagnosticsUpdate,
1229 ))
1230 })
1231 })
1232 });
1233 }
1234
1235 fn predictions_enabled_at(
1236 snapshot: &BufferSnapshot,
1237 position: Option<language::Anchor>,
1238 cx: &App,
1239 ) -> bool {
1240 let file = snapshot.file();
1241 let all_settings = all_language_settings(file, cx);
1242 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1243 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1244 {
1245 return false;
1246 }
1247
1248 if let Some(last_position) = position {
1249 let settings = snapshot.settings_at(last_position, cx);
1250
1251 if !settings.edit_predictions_disabled_in.is_empty()
1252 && let Some(scope) = snapshot.language_scope_at(last_position)
1253 && let Some(scope_name) = scope.override_name()
1254 && settings
1255 .edit_predictions_disabled_in
1256 .iter()
1257 .any(|s| s == scope_name)
1258 {
1259 return false;
1260 }
1261 }
1262
1263 true
1264 }
1265
1266 #[cfg(not(test))]
1267 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1268 #[cfg(test)]
1269 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1270
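    /// Queues a prediction refresh, throttling consecutive refreshes for the same entity by
    /// `THROTTLE_TIMEOUT`. At most two refreshes can be pending at once; queueing a third
    /// replaces (and cancels) the most recently queued one.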
1271 fn queue_prediction_refresh(
1272 &mut self,
1273 project: Entity<Project>,
1274 throttle_entity: EntityId,
1275 cx: &mut Context<Self>,
1276 do_refresh: impl FnOnce(
1277 WeakEntity<Self>,
1278 &mut AsyncApp,
1279 )
1280 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1281 + 'static,
1282 ) {
1283 let project_state = self.get_or_init_project(&project, cx);
1284 let pending_prediction_id = project_state.next_pending_prediction_id;
1285 project_state.next_pending_prediction_id += 1;
1286 let last_request = project_state.last_prediction_refresh;
1287
1288 let task = cx.spawn(async move |this, cx| {
1289 if let Some((last_entity, last_timestamp)) = last_request
1290 && throttle_entity == last_entity
1291 && let Some(timeout) =
1292 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1293 {
1294 cx.background_executor().timer(timeout).await;
1295 }
1296
1297 // If this task was cancelled before the throttle timeout expired,
1298 // do not perform a request.
1299 let mut is_cancelled = true;
1300 this.update(cx, |this, cx| {
1301 let project_state = this.get_or_init_project(&project, cx);
1302 if !project_state
1303 .cancelled_predictions
1304 .remove(&pending_prediction_id)
1305 {
1306 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1307 is_cancelled = false;
1308 }
1309 })
1310 .ok();
1311 if is_cancelled {
1312 return None;
1313 }
1314
1315 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1316 let new_prediction_id = new_prediction_result
1317 .as_ref()
1318 .map(|(prediction, _)| prediction.id.clone());
1319
1320 // When a prediction completes, remove it from the pending list, and cancel
1321 // any pending predictions that were enqueued before it.
1322 this.update(cx, |this, cx| {
1323 let project_state = this.get_or_init_project(&project, cx);
1324
1325 let is_cancelled = project_state
1326 .cancelled_predictions
1327 .remove(&pending_prediction_id);
1328
1329 let new_current_prediction = if !is_cancelled
1330 && let Some((prediction_result, requested_by)) = new_prediction_result
1331 {
1332 match prediction_result.prediction {
1333 Ok(prediction) => {
1334 let new_prediction = CurrentEditPrediction {
1335 requested_by,
1336 prediction,
1337 was_shown: false,
1338 };
1339
1340 if let Some(current_prediction) =
1341 project_state.current_prediction.as_ref()
1342 {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1344 {
1345 this.reject_current_prediction(
1346 EditPredictionRejectReason::Replaced,
1347 &project,
1348 );
1349
1350 Some(new_prediction)
1351 } else {
1352 this.reject_prediction(
1353 new_prediction.prediction.id,
1354 EditPredictionRejectReason::CurrentPreferred,
1355 false,
1356 );
1357 None
1358 }
1359 } else {
1360 Some(new_prediction)
1361 }
1362 }
1363 Err(reject_reason) => {
1364 this.reject_prediction(prediction_result.id, reject_reason, false);
1365 None
1366 }
1367 }
1368 } else {
1369 None
1370 };
1371
1372 let project_state = this.get_or_init_project(&project, cx);
1373
1374 if let Some(new_prediction) = new_current_prediction {
1375 project_state.current_prediction = Some(new_prediction);
1376 }
1377
1378 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1379 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1380 if pending_prediction.id == pending_prediction_id {
1381 pending_predictions.remove(ix);
1382 for pending_prediction in pending_predictions.drain(0..ix) {
1383 project_state.cancel_pending_prediction(pending_prediction, cx)
1384 }
1385 break;
1386 }
1387 }
1388 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1389 cx.notify();
1390 })
1391 .ok();
1392
1393 new_prediction_id
1394 });
1395
1396 if project_state.pending_predictions.len() <= 1 {
1397 project_state.pending_predictions.push(PendingPrediction {
1398 id: pending_prediction_id,
1399 task,
1400 });
1401 } else if project_state.pending_predictions.len() == 2 {
1402 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1403 project_state.pending_predictions.push(PendingPrediction {
1404 id: pending_prediction_id,
1405 task,
1406 });
1407 project_state.cancel_pending_prediction(pending_prediction, cx);
1408 }
1409 }
1410
1411 pub fn request_prediction(
1412 &mut self,
1413 project: &Entity<Project>,
1414 active_buffer: &Entity<Buffer>,
1415 position: language::Anchor,
1416 trigger: PredictEditsRequestTrigger,
1417 cx: &mut Context<Self>,
1418 ) -> Task<Result<Option<EditPredictionResult>>> {
1419 self.request_prediction_internal(
1420 project.clone(),
1421 active_buffer.clone(),
1422 position,
1423 trigger,
1424 cx.has_flag::<Zeta2FeatureFlag>(),
1425 cx,
1426 )
1427 }
1428
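    /// Requests a prediction at `position`. If the model returns nothing, there is edit
    /// history, and `allow_jump` is set, retries once at the nearest diagnostic outside the
    /// range already searched.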
1429 fn request_prediction_internal(
1430 &mut self,
1431 project: Entity<Project>,
1432 active_buffer: Entity<Buffer>,
1433 position: language::Anchor,
1434 trigger: PredictEditsRequestTrigger,
1435 allow_jump: bool,
1436 cx: &mut Context<Self>,
1437 ) -> Task<Result<Option<EditPredictionResult>>> {
1438 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1439
1440 self.get_or_init_project(&project, cx);
1441 let project_state = self.projects.get(&project.entity_id()).unwrap();
1442 let events = project_state.events(cx);
1443 let has_events = !events.is_empty();
1444 let debug_tx = project_state.debug_tx.clone();
1445
1446 let snapshot = active_buffer.read(cx).snapshot();
1447 let cursor_point = position.to_point(&snapshot);
1448 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1449 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1450 let diagnostic_search_range =
1451 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1452
1453 let related_files = if self.use_context {
1454 self.context_for_project(&project, cx)
1455 } else {
1456 Vec::new().into()
1457 };
1458
1459 let inputs = EditPredictionModelInput {
1460 project: project.clone(),
1461 buffer: active_buffer.clone(),
1462 snapshot: snapshot.clone(),
1463 position,
1464 events,
1465 related_files,
1466 recent_paths: project_state.recent_paths.clone(),
1467 trigger,
1468 diagnostic_search_range: diagnostic_search_range.clone(),
1469 debug_tx,
1470 };
1471
1472 let task = match self.edit_prediction_model {
1473 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1474 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1475 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1476 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1477 };
1478
1479 cx.spawn(async move |this, cx| {
1480 let prediction = task.await?;
1481
1482 if prediction.is_none() && allow_jump {
1483 let cursor_point = position.to_point(&snapshot);
1484 if has_events
1485 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1486 active_buffer.clone(),
1487 &snapshot,
1488 diagnostic_search_range,
1489 cursor_point,
1490 &project,
1491 cx,
1492 )
1493 .await?
1494 {
1495 return this
1496 .update(cx, |this, cx| {
1497 this.request_prediction_internal(
1498 project,
1499 jump_buffer,
1500 jump_position,
1501 trigger,
1502 false,
1503 cx,
1504 )
1505 })?
1506 .await;
1507 }
1508
1509 return anyhow::Ok(None);
1510 }
1511
1512 Ok(prediction)
1513 })
1514 }
1515
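    /// Finds a location to jump to for a diagnostics-driven prediction: the closest diagnostic
    /// in the active buffer outside the already-searched range, or failing that, a diagnostic
    /// in the buffer whose path shares the most components with the active buffer's path.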
1516 async fn next_diagnostic_location(
1517 active_buffer: Entity<Buffer>,
1518 active_buffer_snapshot: &BufferSnapshot,
1519 active_buffer_diagnostic_search_range: Range<Point>,
1520 active_buffer_cursor_point: Point,
1521 project: &Entity<Project>,
1522 cx: &mut AsyncApp,
1523 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough to be included in the last request.
1525 let mut jump_location = active_buffer_snapshot
1526 .diagnostic_groups(None)
1527 .into_iter()
1528 .filter_map(|(_, group)| {
1529 let range = &group.entries[group.primary_ix]
1530 .range
1531 .to_point(&active_buffer_snapshot);
1532 if range.overlaps(&active_buffer_diagnostic_search_range) {
1533 None
1534 } else {
1535 Some(range.start)
1536 }
1537 })
1538 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1539 .map(|position| {
1540 (
1541 active_buffer.clone(),
1542 active_buffer_snapshot.anchor_before(position),
1543 )
1544 });
1545
1546 if jump_location.is_none() {
1547 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1548 let file = buffer.file()?;
1549
1550 Some(ProjectPath {
1551 worktree_id: file.worktree_id(cx),
1552 path: file.path().clone(),
1553 })
1554 })?;
1555
1556 let buffer_task = project.update(cx, |project, cx| {
1557 let (path, _, _) = project
1558 .diagnostic_summaries(false, cx)
1559 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1560 .max_by_key(|(path, _, _)| {
                    // Find the buffer with diagnostics whose path shares the most parent directories with the active buffer.
1562 path.path
1563 .components()
1564 .zip(
1565 active_buffer_path
1566 .as_ref()
1567 .map(|p| p.path.components())
1568 .unwrap_or_default(),
1569 )
1570 .take_while(|(a, b)| a == b)
1571 .count()
1572 })?;
1573
1574 Some(project.open_buffer(path, cx))
1575 })?;
1576
1577 if let Some(buffer_task) = buffer_task {
1578 let closest_buffer = buffer_task.await?;
1579
1580 jump_location = closest_buffer
1581 .read_with(cx, |buffer, _cx| {
1582 buffer
1583 .buffer_diagnostics(None)
1584 .into_iter()
1585 .min_by_key(|entry| entry.diagnostic.severity)
1586 .map(|entry| entry.range.start)
1587 })?
1588 .map(|position| (closest_buffer, position));
1589 }
1590 }
1591
1592 anyhow::Ok(jump_location)
1593 }
1594
1595 async fn send_raw_llm_request(
1596 request: open_ai::Request,
1597 client: Arc<Client>,
1598 llm_token: LlmApiToken,
1599 app_version: Version,
1600 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1601 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1602 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1603 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1604 http_client::Url::parse(&predict_edits_url)?
1605 } else {
1606 client
1607 .http_client()
1608 .build_zed_llm_url("/predict_edits/raw", &[])?
1609 };
1610
1611 #[cfg(feature = "cli-support")]
1612 let cache_key = if let Some(cache) = eval_cache {
1613 use collections::FxHasher;
1614 use std::hash::{Hash, Hasher};
1615
1616 let mut hasher = FxHasher::default();
1617 url.hash(&mut hasher);
1618 let request_str = serde_json::to_string_pretty(&request)?;
1619 request_str.hash(&mut hasher);
1620 let hash = hasher.finish();
1621
1622 let key = (eval_cache_kind, hash);
1623 if let Some(response_str) = cache.read(key) {
1624 return Ok((serde_json::from_str(&response_str)?, None));
1625 }
1626
1627 Some((cache, request_str, key))
1628 } else {
1629 None
1630 };
1631
1632 let (response, usage) = Self::send_api_request(
1633 |builder| {
1634 let req = builder
1635 .uri(url.as_ref())
1636 .body(serde_json::to_string(&request)?.into());
1637 Ok(req?)
1638 },
1639 client,
1640 llm_token,
1641 app_version,
1642 )
1643 .await?;
1644
1645 #[cfg(feature = "cli-support")]
1646 if let Some((cache, request, key)) = cache_key {
1647 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1648 }
1649
1650 Ok((response, usage))
1651 }
1652
1653 fn handle_api_response<T>(
1654 this: &WeakEntity<Self>,
1655 response: Result<(T, Option<EditPredictionUsage>)>,
1656 cx: &mut gpui::AsyncApp,
1657 ) -> Result<T> {
1658 match response {
1659 Ok((data, usage)) => {
1660 if let Some(usage) = usage {
1661 this.update(cx, |this, cx| {
1662 this.user_store.update(cx, |user_store, cx| {
1663 user_store.update_edit_prediction_usage(usage, cx);
1664 });
1665 })
1666 .ok();
1667 }
1668 Ok(data)
1669 }
1670 Err(err) => {
1671 if err.is::<ZedUpdateRequiredError>() {
1672 cx.update(|cx| {
1673 this.update(cx, |this, _cx| {
1674 this.update_required = true;
1675 })
1676 .ok();
1677
1678 let error_message: SharedString = err.to_string().into();
1679 show_app_notification(
1680 NotificationId::unique::<ZedUpdateRequiredError>(),
1681 cx,
1682 move |cx| {
1683 cx.new(|cx| {
1684 ErrorMessagePrompt::new(error_message.clone(), cx)
1685 .with_link_button("Update Zed", "https://zed.dev/releases")
1686 })
1687 },
1688 );
1689 })
1690 .ok();
1691 }
1692 Err(err)
1693 }
1694 }
1695 }
1696
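    /// Sends an authenticated request to the LLM service, refreshing the token and retrying
    /// once if it has expired, and failing with `ZedUpdateRequiredError` when the server
    /// requires a newer Zed version.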
1697 async fn send_api_request<Res>(
1698 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1699 client: Arc<Client>,
1700 llm_token: LlmApiToken,
1701 app_version: Version,
1702 ) -> Result<(Res, Option<EditPredictionUsage>)>
1703 where
1704 Res: DeserializeOwned,
1705 {
1706 let http_client = client.http_client();
1707 let mut token = llm_token.acquire(&client).await?;
1708 let mut did_retry = false;
1709
1710 loop {
1711 let request_builder = http_client::Request::builder().method(Method::POST);
1712
1713 let request = build(
1714 request_builder
1715 .header("Content-Type", "application/json")
1716 .header("Authorization", format!("Bearer {}", token))
1717 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1718 )?;
1719
1720 let mut response = http_client.send(request).await?;
1721
1722 if let Some(minimum_required_version) = response
1723 .headers()
1724 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1725 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1726 {
1727 anyhow::ensure!(
1728 app_version >= minimum_required_version,
1729 ZedUpdateRequiredError {
1730 minimum_version: minimum_required_version
1731 }
1732 );
1733 }
1734
1735 if response.status().is_success() {
1736 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1737
1738 let mut body = Vec::new();
1739 response.body_mut().read_to_end(&mut body).await?;
1740 return Ok((serde_json::from_slice(&body)?, usage));
1741 } else if !did_retry
1742 && response
1743 .headers()
1744 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1745 .is_some()
1746 {
1747 did_retry = true;
1748 token = llm_token.refresh(&client).await?;
1749 } else {
1750 let mut body = String::new();
1751 response.body_mut().read_to_string(&mut body).await?;
1752 anyhow::bail!(
1753 "Request failed with status: {:?}\nBody: {}",
1754 response.status(),
1755 body
1756 );
1757 }
1758 }
1759 }
1760
1761 pub fn refresh_context(
1762 &mut self,
1763 project: &Entity<Project>,
1764 buffer: &Entity<language::Buffer>,
1765 cursor_position: language::Anchor,
1766 cx: &mut Context<Self>,
1767 ) {
1768 if self.use_context {
1769 self.get_or_init_project(project, cx)
1770 .context
1771 .update(cx, |store, cx| {
1772 store.refresh(buffer.clone(), cursor_position, cx);
1773 });
1774 }
1775 }
1776
1777 #[cfg(feature = "cli-support")]
1778 pub fn set_context_for_buffer(
1779 &mut self,
1780 project: &Entity<Project>,
1781 related_files: Vec<RelatedFile>,
1782 cx: &mut Context<Self>,
1783 ) {
1784 self.get_or_init_project(project, cx)
1785 .context
1786 .update(cx, |store, _| {
1787 store.set_related_files(related_files);
1788 });
1789 }
1790
1791 fn is_file_open_source(
1792 &self,
1793 project: &Entity<Project>,
1794 file: &Arc<dyn File>,
1795 cx: &App,
1796 ) -> bool {
1797 if !file.is_local() || file.is_private() {
1798 return false;
1799 }
1800 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1801 return false;
1802 };
1803 project_state
1804 .license_detection_watchers
1805 .get(&file.worktree_id(cx))
1806 .as_ref()
1807 .is_some_and(|watcher| watcher.is_project_open_source())
1808 }
1809
1810 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1811 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1812 }
1813
1814 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1815 if !self.data_collection_choice.is_enabled() {
1816 return false;
1817 }
1818 events.iter().all(|event| {
1819 matches!(
1820 event.as_ref(),
1821 zeta_prompt::Event::BufferChange {
1822 in_open_source_repo: true,
1823 ..
1824 }
1825 )
1826 })
1827 }
1828
1829 fn load_data_collection_choice() -> DataCollectionChoice {
1830 let choice = KEY_VALUE_STORE
1831 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1832 .log_err()
1833 .flatten();
1834
1835 match choice.as_deref() {
1836 Some("true") => DataCollectionChoice::Enabled,
1837 Some("false") => DataCollectionChoice::Disabled,
1838 Some(_) => {
1839 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1840 DataCollectionChoice::NotAnswered
1841 }
1842 None => DataCollectionChoice::NotAnswered,
1843 }
1844 }
1845
1846 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1847 self.data_collection_choice = self.data_collection_choice.toggle();
1848 let new_choice = self.data_collection_choice;
1849 db::write_and_log(cx, move || {
1850 KEY_VALUE_STORE.write_kvp(
1851 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1852 new_choice.is_enabled().to_string(),
1853 )
1854 });
1855 }
1856
1857 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1858 self.shown_predictions.iter()
1859 }
1860
1861 pub fn shown_completions_len(&self) -> usize {
1862 self.shown_predictions.len()
1863 }
1864
1865 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1866 self.rated_predictions.contains(id)
1867 }
1868
1869 pub fn rate_prediction(
1870 &mut self,
1871 prediction: &EditPrediction,
1872 rating: EditPredictionRating,
1873 feedback: String,
1874 cx: &mut Context<Self>,
1875 ) {
1876 self.rated_predictions.insert(prediction.id.clone());
1877 telemetry::event!(
1878 "Edit Prediction Rated",
1879 rating,
1880 inputs = prediction.inputs,
1881 output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1882 feedback
1883 );
1884 self.client.telemetry().flush_events().detach();
1885 cx.notify();
1886 }
1887
1888 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1889 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1890 && all_language_settings(None, cx).edit_predictions.use_context;
1891 }
1892}
1893
1894#[derive(Error, Debug)]
1895#[error(
1896 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1897)]
1898pub struct ZedUpdateRequiredError {
1899 minimum_version: Version,
1900}
1901
1902#[cfg(feature = "cli-support")]
1903pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1904
1905#[cfg(feature = "cli-support")]
1906#[derive(Debug, Clone, Copy, PartialEq)]
1907pub enum EvalCacheEntryKind {
1908 Context,
1909 Search,
1910 Prediction,
1911}
1912
1913#[cfg(feature = "cli-support")]
1914impl std::fmt::Display for EvalCacheEntryKind {
1915 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1916 match self {
1917 EvalCacheEntryKind::Search => write!(f, "search"),
1918 EvalCacheEntryKind::Context => write!(f, "context"),
1919 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
1920 }
1921 }
1922}
1923
1924#[cfg(feature = "cli-support")]
1925pub trait EvalCache: Send + Sync {
1926 fn read(&self, key: EvalCacheKey) -> Option<String>;
1927 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
1928}
1929
1930#[derive(Debug, Clone, Copy)]
1931pub enum DataCollectionChoice {
1932 NotAnswered,
1933 Enabled,
1934 Disabled,
1935}
1936
1937impl DataCollectionChoice {
1938 pub fn is_enabled(self) -> bool {
1939 match self {
1940 Self::Enabled => true,
1941 Self::NotAnswered | Self::Disabled => false,
1942 }
1943 }
1944
1945 pub fn is_answered(self) -> bool {
1946 match self {
1947 Self::Enabled | Self::Disabled => true,
1948 Self::NotAnswered => false,
1949 }
1950 }
1951
1952 #[must_use]
1953 pub fn toggle(&self) -> DataCollectionChoice {
1954 match self {
1955 Self::Enabled => Self::Disabled,
1956 Self::Disabled => Self::Enabled,
1957 Self::NotAnswered => Self::Enabled,
1958 }
1959 }
1960}
1961
1962impl From<bool> for DataCollectionChoice {
1963 fn from(value: bool) -> Self {
1964 match value {
1965 true => DataCollectionChoice::Enabled,
1966 false => DataCollectionChoice::Disabled,
1967 }
1968 }
1969}
1970
1971struct ZedPredictUpsell;
1972
1973impl Dismissable for ZedPredictUpsell {
1974 const KEY: &'static str = "dismissed-edit-predict-upsell";
1975
1976 fn dismissed() -> bool {
        // For backwards compatibility with older versions of Zed, also treat the
        // upsell as dismissed if the user completed the previous Edit Prediction
        // onboarding, which wrote the data collection choice to the database when
        // they clicked "Accept and Enable".
1981 if KEY_VALUE_STORE
1982 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1983 .log_err()
1984 .is_some_and(|s| s.is_some())
1985 {
1986 return true;
1987 }
1988
1989 KEY_VALUE_STORE
1990 .read_kvp(Self::KEY)
1991 .log_err()
1992 .is_some_and(|s| s.is_some())
1993 }
1994}
1995
1996pub fn should_show_upsell_modal() -> bool {
1997 !ZedPredictUpsell::dismissed()
1998}
1999
2000pub fn init(cx: &mut App) {
2001 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2002 workspace.register_action(
2003 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2004 ZedPredictModal::toggle(
2005 workspace,
2006 workspace.user_store().clone(),
2007 workspace.client().clone(),
2008 window,
2009 cx,
2010 )
2011 },
2012 );
2013
2014 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2015 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2016 settings
2017 .project
2018 .all_languages
2019 .features
2020 .get_or_insert_default()
2021 .edit_prediction_provider = Some(EditPredictionProvider::None)
2022 });
2023 });
2024 })
2025 .detach();
2026}