use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
    ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
    AsyncReadExt as _, FutureExt as _, StreamExt as _,
    channel::mpsc::{self, UnboundedReceiver},
    select_biased,
};
use gpui::BackgroundExecutor;
use gpui::{
    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
    http_client::{self, AsyncBody, Method},
    prelude::*,
};
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::{Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
use serde::de::DeserializeOwned;
use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use workspace::Workspace;

use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

mod cursor_excerpt;
mod license_detection;
pub mod mercury;
mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;

#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;

mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;

#[cfg(test)]
mod edit_prediction_tests;

use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;

actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
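/// Edits within this many lines of the previous edit in the same buffer are coalesced into a
/// single history event.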
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);

pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};

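/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, zeta2 requests default to a local Ollama
/// server (see `PREDICT_EDITS_URL` and `EDIT_PREDICTIONS_MODEL_ID`).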
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));

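/// Model identifier used for zeta2 requests; overridable via the `ZED_ZETA2_MODEL` environment
/// variable.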
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}

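/// Global store for edit predictions. It tracks per-project edit history and context, dispatches
/// prediction requests to the configured model, and reports accepted and rejected predictions
/// back to the server.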
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}

#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}

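/// Inputs handed to a prediction backend when requesting an edit prediction.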
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}

#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

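/// Per-project state: edit history events, registered buffers, license detection watchers, and
/// the current and pending predictions.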
struct ProjectState {
    events: VecDeque<Arc<zeta_prompt::Event>>,
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }

    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}

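/// The prediction currently held for a project, along with what triggered it and whether it has
/// been shown to the user.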
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}

impl PredictionRequestedBy {
    pub fn buffer_id(&self) -> Option<EntityId> {
        match self {
            PredictionRequestedBy::DiagnosticsUpdate => None,
            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
        }
    }
}

#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

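/// The most recent buffer change, kept un-finalized so that nearby follow-up edits can be
/// coalesced into the same event before it is added to the history.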
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    end_edit_anchor: Option<Anchor>,
}

impl LastEvent {
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<zeta_prompt::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(zeta_prompt::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}

fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
    if let Some(file) = snapshot.file() {
        file.full_path(cx).into()
    } else {
        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
    }
}

impl EditPredictionStore {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
                ep_store
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }

    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai
            .api_token
            .clone()
            .now_or_never()
            .flatten()
            .is_some()
    }

    pub fn has_mercury_api_token(&self) -> bool {
        self.mercury
            .api_token
            .clone()
            .now_or_never()
            .flatten()
            .is_some()
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }

    pub fn clear_history(&mut self) {
        for project_state in self.projects.values_mut() {
            project_state.events.clear();
        }
    }

    pub fn edit_history_for_project(
        &self,
        project: &Entity<Project>,
    ) -> Vec<Arc<zeta_prompt::Event>> {
        self.projects
            .get(&project.entity_id())
            .map(|project_state| project_state.events.iter().cloned().collect())
            .unwrap_or_default()
    }

    pub fn context_for_project<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Arc<[RelatedFile]> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files())
            .unwrap_or_else(|| vec![].into())
    }

    pub fn context_for_project_with_buffers<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files_with_buffers())
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }

    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }

    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }

    fn handle_excerpt_store_event(
        &mut self,
        project_entity_id: EntityId,
        event: &RelatedExcerptStoreEvent,
    ) {
        if let Some(project_state) = self.projects.get(&project_entity_id) {
            if let Some(debug_tx) = project_state.debug_tx.clone() {
                match event {
                    RelatedExcerptStoreEvent::StartedRefresh => {
                        debug_tx
                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
                                ContextRetrievalStartedDebugEvent {
                                    project_entity_id: project_entity_id,
                                    timestamp: Instant::now(),
                                    search_prompt: String::new(),
                                },
                            ))
                            .ok();
                    }
                    RelatedExcerptStoreEvent::FinishedRefresh {
                        cache_hit_count,
                        cache_miss_count,
                        mean_definition_latency,
                        max_definition_latency,
                    } => {
                        debug_tx
                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
                                ContextRetrievalFinishedDebugEvent {
                                    project_entity_id: project_entity_id,
                                    timestamp: Instant::now(),
                                    metadata: vec![
                                        (
                                            "Cache Hits",
                                            format!(
                                                "{}/{}",
                                                cache_hit_count,
                                                cache_hit_count + cache_miss_count
                                            )
                                            .into(),
                                        ),
                                        (
                                            "Max LSP Time",
                                            format!("{} ms", max_definition_latency.as_millis())
                                                .into(),
                                        ),
                                        (
                                            "Mean LSP Time",
                                            format!("{} ms", mean_definition_latency.as_millis())
                                                .into(),
                                        ),
                                    ],
                                },
                            ))
                            .ok();
                    }
                }
            }
        }
    }

    pub fn debug_info(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<DebugEvent> {
        let project_state = self.get_or_init_project(project, cx);
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        project_state.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }

    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

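    /// Records a buffer change in the project's edit history, coalescing it with the previous
    /// change when both touch the same buffer within `CHANGE_GROUPING_LINE_SPAN` lines.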
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }

    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

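    /// Drains rejection reports from the channel and flushes them to the server in batches: while
    /// the batch is under half the cap it waits for either the next item or the
    /// `REJECT_REQUEST_DEBOUNCE` timer, then sends at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries, keeping them on failure.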
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }

    fn reject_current_prediction(
        &mut self,
        reason: EditPredictionRejectReason,
        project: &Entity<Project>,
    ) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.pending_predictions.clear();
            if let Some(prediction) = project_state.current_prediction.take() {
                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
            }
        };
    }

    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
                if !current_prediction.was_shown {
                    current_prediction.was_shown = true;
                    self.shown_predictions
                        .push_front(current_prediction.prediction.clone());
                    if self.shown_predictions.len() > 50 {
                        let completion = self.shown_predictions.pop_back().unwrap();
                        self.rated_predictions.remove(&completion.id);
                    }
                }
            }
        }
    }

    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }

    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
        self.projects
            .get(&project.entity_id())
            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
    }

    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }

    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }

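    /// Returns whether edit predictions may be shown at this position, based on the language
    /// settings for the file and any `edit_predictions_disabled_in` scopes at the cursor.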
    fn predictions_enabled_at(
        snapshot: &BufferSnapshot,
        position: Option<language::Anchor>,
        cx: &App,
    ) -> bool {
        let file = snapshot.file();
        let all_settings = all_language_settings(file, cx);
        if !all_settings.show_edit_predictions(snapshot.language(), cx)
            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
        {
            return false;
        }

        if let Some(last_position) = position {
            let settings = snapshot.settings_at(last_position, cx);

            if !settings.edit_predictions_disabled_in.is_empty()
                && let Some(scope) = snapshot.language_scope_at(last_position)
                && let Some(scope_name) = scope.override_name()
                && settings
                    .edit_predictions_disabled_in
                    .iter()
                    .any(|s| s == scope_name)
            {
                return false;
            }
        }

        true
    }

    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;

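    /// Enqueues a prediction refresh for a project. Consecutive refreshes triggered by the same
    /// entity are throttled by `THROTTLE_TIMEOUT`, at most two refreshes are kept pending, and a
    /// completed refresh only replaces the current prediction when `should_replace_prediction`
    /// allows it.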
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.request_prediction_internal(
            project.clone(),
            active_buffer.clone(),
            position,
            trigger,
            cx.has_flag::<Zeta2FeatureFlag>(),
            cx,
        )
    }

    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }

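    /// Finds a diagnostic to "jump" to when no prediction is available at the cursor: the closest
    /// diagnostic in the active buffer outside the range already searched, or failing that, a
    /// diagnostic in another buffer whose path shares the most leading components with the
    /// active buffer's path.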
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

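    /// Sends an authenticated request to the LLM service, retrying once with a refreshed token if
    /// the server reports an expired LLM token and surfacing `ZedUpdateRequiredError` when the
    /// server requires a newer Zed version.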
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub fn refresh_context(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if self.use_context {
            self.get_or_init_project(project, cx)
                .context
                .update(cx, |store, cx| {
                    store.refresh(buffer.clone(), cursor_position, cx);
                });
        }
    }

    #[cfg(feature = "eval-support")]
    pub fn set_context_for_buffer(
        &mut self,
        project: &Entity<Project>,
        related_files: Vec<RelatedFile>,
        cx: &mut Context<Self>,
    ) {
        self.get_or_init_project(project, cx)
            .context
            .update(cx, |store, _| {
                store.set_related_files(related_files);
            });
    }

    fn is_file_open_source(
        &self,
        project: &Entity<Project>,
        file: &Arc<dyn File>,
        cx: &App,
    ) -> bool {
        if !file.is_local() || file.is_private() {
            return false;
        }
        let Some(project_state) = self.projects.get(&project.entity_id()) else {
            return false;
        };
        project_state
            .license_detection_watchers
            .get(&file.worktree_id(cx))
            .as_ref()
            .is_some_and(|watcher| watcher.is_project_open_source())
    }

    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
    }

    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
        if !self.data_collection_choice.is_enabled() {
            return false;
        }
        events.iter().all(|event| {
            matches!(
                event.as_ref(),
                zeta_prompt::Event::BufferChange {
                    in_open_source_repo: true,
                    ..
                }
            )
        })
    }

    fn load_data_collection_choice() -> DataCollectionChoice {
        let choice = KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .flatten();

        match choice.as_deref() {
            Some("true") => DataCollectionChoice::Enabled,
            Some("false") => DataCollectionChoice::Disabled,
            Some(_) => {
                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
                DataCollectionChoice::NotAnswered
            }
            None => DataCollectionChoice::NotAnswered,
        }
    }

    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
        self.data_collection_choice = self.data_collection_choice.toggle();
        let new_choice = self.data_collection_choice;
        db::write_and_log(cx, move || {
            KEY_VALUE_STORE.write_kvp(
                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                new_choice.is_enabled().to_string(),
            )
        });
    }

    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }

    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }

    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
            && all_language_settings(None, cx).edit_predictions.use_context;
    }
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}

#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}

struct ZedPredictUpsell;

impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    fn dismissed() -> bool {
        // To make this backwards compatible with older versions of Zed, we
        // check if the user has seen the previous Edit Prediction Onboarding
        // before, by checking the data collection choice which was written to
        // the database once the user clicked on "Accept and Enable"
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}

pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}

pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}