1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::http_client::Url;
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::language_settings::all_language_settings;
29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
30use language::{BufferSnapshot, OffsetRangeExt};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use project::{Project, ProjectPath, WorktreeId};
33use release_channel::AppVersion;
34use semver::Version;
35use serde::de::DeserializeOwned;
36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
37use std::collections::{VecDeque, hash_map};
38use text::Edit;
39use workspace::Workspace;
40
41use std::ops::Range;
42use std::path::Path;
43use std::rc::Rc;
44use std::str::FromStr as _;
45use std::sync::{Arc, LazyLock};
46use std::time::{Duration, Instant};
47use std::{env, mem};
48use thiserror::Error;
49use util::{RangeExt as _, ResultExt as _};
50use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
51
52pub mod cursor_excerpt;
53pub mod example_spec;
54mod license_detection;
55pub mod mercury;
56mod onboarding_modal;
57pub mod open_ai_response;
58mod prediction;
59pub mod sweep_ai;
60
61pub mod udiff;
62
63mod capture_example;
64mod zed_edit_prediction_delegate;
65pub mod zeta1;
66pub mod zeta2;
67
68#[cfg(test)]
69mod edit_prediction_tests;
70
71use crate::license_detection::LicenseDetectionWatcher;
72use crate::mercury::Mercury;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76use crate::prediction::EditPredictionResult;
77pub use crate::sweep_ai::SweepAi;
78pub use capture_example::capture_example;
79pub use language_model::ApiKeyState;
80pub use telemetry_events::EditPredictionRating;
81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
82
83actions!(
84 edit_prediction,
85 [
86 /// Resets the edit prediction onboarding state.
87 ResetOnboarding,
88 /// Clears the edit prediction history.
89 ClearHistory,
90 ]
91);
92
/// Maximum number of edit events retained in a project's history.
94const EVENT_COUNT_MAX: usize = 6;
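/// Edits are coalesced into the most recent event when they end within this
/// many rows of that event's last edit.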
95const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
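/// A pause of at least this long between edits records a split point inside
/// the current event (see `LastEvent::split_by_pause`).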
96const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
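/// Key under which the user's data collection choice is persisted in the
/// key-value store.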
97const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
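/// How long to wait for additional rejections before flushing a partially
/// filled batch to the server.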
98const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
99
100pub struct SweepFeatureFlag;
101
102impl FeatureFlag for SweepFeatureFlag {
103 const NAME: &str = "sweep-ai";
104}
105
106pub struct MercuryFeatureFlag;
107
108impl FeatureFlag for MercuryFeatureFlag {
109 const NAME: &str = "mercury";
110}
111
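/// Default excerpt-sizing and prompt-format options used for Zeta predictions.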
112pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
113 context: EditPredictionExcerptOptions {
114 max_bytes: 512,
115 min_bytes: 128,
116 target_before_cursor_over_total_bytes: 0.5,
117 },
118 prompt_format: PromptFormat::DEFAULT,
119};
120
121static USE_OLLAMA: LazyLock<bool> =
122 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
123
124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
125 match env::var("ZED_ZETA2_MODEL").as_deref() {
126 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
127 Ok(model) => model,
128 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
129 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
130 }
131 .to_string()
132});
133
134pub struct Zeta2FeatureFlag;
135
136impl FeatureFlag for Zeta2FeatureFlag {
137 const NAME: &'static str = "zeta2";
138
139 fn enabled_for_staff() -> bool {
140 true
141 }
142}
143
144#[derive(Clone)]
145struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
146
147impl Global for EditPredictionStoreGlobal {}
148
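/// Application-wide store for edit predictions. It tracks per-project edit
/// history and registered buffers, requests predictions from the configured
/// backend (Zeta 1/2, Sweep, or Mercury), and reports accepted and rejected
/// predictions back to the server.
///
/// A minimal usage sketch (illustrative only; `client`, `user_store`,
/// `project`, `buffer`, and `position` are assumed to exist, and `cx` is a
/// `&mut App`):
///
/// ```ignore
/// let store = EditPredictionStore::global(&client, &user_store, cx);
/// store.update(cx, |store, cx| {
///     store.register_buffer(&buffer, &project, cx);
///     store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
/// });
/// ```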
149pub struct EditPredictionStore {
150 client: Arc<Client>,
151 user_store: Entity<UserStore>,
152 llm_token: LlmApiToken,
153 _llm_token_subscription: Subscription,
154 projects: HashMap<EntityId, ProjectState>,
155 use_context: bool,
156 options: ZetaOptions,
157 update_required: bool,
158 #[cfg(feature = "cli-support")]
159 eval_cache: Option<Arc<dyn EvalCache>>,
160 edit_prediction_model: EditPredictionModel,
161 pub sweep_ai: SweepAi,
162 pub mercury: Mercury,
163 data_collection_choice: DataCollectionChoice,
164 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
165 shown_predictions: VecDeque<EditPrediction>,
166 rated_predictions: HashSet<EditPredictionId>,
167 custom_predict_edits_url: Option<Arc<Url>>,
168}
169
170#[derive(Copy, Clone, Default, PartialEq, Eq)]
171pub enum EditPredictionModel {
172 #[default]
173 Zeta1,
174 Zeta2,
175 Sweep,
176 Mercury,
177}
178
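/// The input shared by all prediction backends: the buffer and cursor position
/// to predict at, recent edit events, related files gathered as context, and
/// the range within which to look for diagnostics.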
179pub struct EditPredictionModelInput {
180 project: Entity<Project>,
181 buffer: Entity<Buffer>,
182 snapshot: BufferSnapshot,
183 position: Anchor,
184 events: Vec<Arc<zeta_prompt::Event>>,
185 related_files: Arc<[RelatedFile]>,
186 recent_paths: VecDeque<ProjectPath>,
187 trigger: PredictEditsRequestTrigger,
188 diagnostic_search_range: Range<Point>,
189 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
190}
191
192#[derive(Debug, Clone, PartialEq)]
193pub struct ZetaOptions {
194 pub context: EditPredictionExcerptOptions,
195 pub prompt_format: predict_edits_v3::PromptFormat,
196}
197
198#[derive(Debug)]
199pub enum DebugEvent {
200 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
201 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
202 EditPredictionStarted(EditPredictionStartedDebugEvent),
203 EditPredictionFinished(EditPredictionFinishedDebugEvent),
204}
205
206#[derive(Debug)]
207pub struct ContextRetrievalStartedDebugEvent {
208 pub project_entity_id: EntityId,
209 pub timestamp: Instant,
210 pub search_prompt: String,
211}
212
213#[derive(Debug)]
214pub struct ContextRetrievalFinishedDebugEvent {
215 pub project_entity_id: EntityId,
216 pub timestamp: Instant,
217 pub metadata: Vec<(&'static str, SharedString)>,
218}
219
220#[derive(Debug)]
221pub struct EditPredictionStartedDebugEvent {
222 pub buffer: WeakEntity<Buffer>,
223 pub position: Anchor,
224 pub prompt: Option<String>,
225}
226
227#[derive(Debug)]
228pub struct EditPredictionFinishedDebugEvent {
229 pub buffer: WeakEntity<Buffer>,
230 pub position: Anchor,
231 pub model_output: Option<String>,
232}
233
234pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
235
236/// An event with associated metadata for reconstructing buffer state.
237#[derive(Clone)]
238pub struct StoredEvent {
239 pub event: Arc<zeta_prompt::Event>,
240 pub old_snapshot: TextBufferSnapshot,
241}
242
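/// Per-project state: the coalesced edit-event history, registered buffers,
/// the current prediction (if any), and up to two pending prediction requests.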
243struct ProjectState {
244 events: VecDeque<StoredEvent>,
245 last_event: Option<LastEvent>,
246 recent_paths: VecDeque<ProjectPath>,
247 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
248 current_prediction: Option<CurrentEditPrediction>,
249 next_pending_prediction_id: usize,
250 pending_predictions: ArrayVec<PendingPrediction, 2>,
251 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
252 last_prediction_refresh: Option<(EntityId, Instant)>,
253 cancelled_predictions: HashSet<usize>,
254 context: Entity<RelatedExcerptStore>,
255 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
256 _subscription: gpui::Subscription,
257}
258
259impl ProjectState {
260 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
261 self.events
262 .iter()
263 .cloned()
264 .chain(
265 self.last_event
266 .as_ref()
267 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
268 )
269 .collect()
270 }
271
272 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
273 self.events
274 .iter()
275 .cloned()
276 .chain(self.last_event.as_ref().iter().flat_map(|event| {
277 let (one, two) = event.split_by_pause();
278 let one = one.finalize(&self.license_detection_watchers, cx);
279 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
280 one.into_iter().chain(two)
281 }))
282 .collect()
283 }
284
285 fn cancel_pending_prediction(
286 &mut self,
287 pending_prediction: PendingPrediction,
288 cx: &mut Context<EditPredictionStore>,
289 ) {
290 self.cancelled_predictions.insert(pending_prediction.id);
291
292 cx.spawn(async move |this, cx| {
293 let Some(prediction_id) = pending_prediction.task.await else {
294 return;
295 };
296
297 this.update(cx, |this, _cx| {
298 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
299 })
300 .ok();
301 })
302 .detach()
303 }
304
305 fn active_buffer(
306 &self,
307 project: &Entity<Project>,
308 cx: &App,
309 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
310 let project = project.read(cx);
311 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
312 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
313 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
314 Some((active_buffer, registered_buffer.last_position))
315 }
316}
317
318#[derive(Debug, Clone)]
319struct CurrentEditPrediction {
320 pub requested_by: PredictionRequestedBy,
321 pub prediction: EditPrediction,
322 pub was_shown: bool,
323}
324
325impl CurrentEditPrediction {
326 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
327 let Some(new_edits) = self
328 .prediction
329 .interpolate(&self.prediction.buffer.read(cx))
330 else {
331 return false;
332 };
333
334 if self.prediction.buffer != old_prediction.prediction.buffer {
335 return true;
336 }
337
338 let Some(old_edits) = old_prediction
339 .prediction
340 .interpolate(&old_prediction.prediction.buffer.read(cx))
341 else {
342 return true;
343 };
344
345 let requested_by_buffer_id = self.requested_by.buffer_id();
346
        // TODO: This is fairly arbitrary; we should have a more general
        // heuristic that handles multiple edits.
        //
        // This heuristic reduces UI thrash from replacing an edit that the user
        // is already seeing.
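        // Concretely: when both predictions are a single edit at the same range
        // (in the buffer the request came from) and the new text merely extends
        // the old (`foo` -> `foobar`), the new prediction replaces the old one
        // seamlessly; a single new edit anywhere else is dropped in favor of the
        // prediction already on screen.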
350 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
351 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
352 && old_edits.len() == 1
353 && new_edits.len() == 1
354 {
355 let (old_range, old_text) = &old_edits[0];
356 let (new_range, new_text) = &new_edits[0];
357 new_range == old_range && new_text.starts_with(old_text.as_ref())
358 } else {
359 true
360 }
361 }
362}
363
364#[derive(Debug, Clone)]
365enum PredictionRequestedBy {
366 DiagnosticsUpdate,
367 Buffer(EntityId),
368}
369
370impl PredictionRequestedBy {
371 pub fn buffer_id(&self) -> Option<EntityId> {
372 match self {
373 PredictionRequestedBy::DiagnosticsUpdate => None,
374 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
375 }
376 }
377}
378
379#[derive(Debug)]
380struct PendingPrediction {
381 id: usize,
382 task: Task<Option<EditPredictionId>>,
383}
384
385/// A prediction from the perspective of a buffer.
386#[derive(Debug)]
387enum BufferEditPrediction<'a> {
388 Local { prediction: &'a EditPrediction },
389 Jump { prediction: &'a EditPrediction },
390}
391
392#[cfg(test)]
393impl std::ops::Deref for BufferEditPrediction<'_> {
394 type Target = EditPrediction;
395
396 fn deref(&self) -> &Self::Target {
397 match self {
398 BufferEditPrediction::Local { prediction } => prediction,
399 BufferEditPrediction::Jump { prediction } => prediction,
400 }
401 }
402}
403
404struct RegisteredBuffer {
405 file: Option<Arc<dyn File>>,
406 snapshot: TextBufferSnapshot,
407 last_position: Option<Anchor>,
408 _subscriptions: [gpui::Subscription; 2],
409}
410
411#[derive(Clone)]
412struct LastEvent {
413 old_snapshot: TextBufferSnapshot,
414 new_snapshot: TextBufferSnapshot,
415 old_file: Option<Arc<dyn File>>,
416 new_file: Option<Arc<dyn File>>,
417 end_edit_anchor: Option<Anchor>,
418 snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
419 last_edit_time: Option<Instant>,
420}
421
422impl LastEvent {
423 pub fn finalize(
424 &self,
425 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
426 cx: &App,
427 ) -> Option<StoredEvent> {
428 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
429 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
430
431 let in_open_source_repo =
432 [self.new_file.as_ref(), self.old_file.as_ref()]
433 .iter()
434 .all(|file| {
435 file.is_some_and(|file| {
436 license_detection_watchers
437 .get(&file.worktree_id(cx))
438 .is_some_and(|watcher| watcher.is_project_open_source())
439 })
440 });
441
442 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
443
444 if path == old_path && diff.is_empty() {
445 None
446 } else {
447 Some(StoredEvent {
448 event: Arc::new(zeta_prompt::Event::BufferChange {
449 old_path,
450 path,
451 diff,
452 in_open_source_repo,
453 // TODO: Actually detect if this edit was predicted or not
454 predicted: false,
455 }),
456 old_snapshot: self.old_snapshot.clone(),
457 })
458 }
459 }
460
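    /// Splits this event at the snapshot recorded after the last editing pause.
    /// Returns `(self, None)` when no pause was recorded; otherwise returns the
    /// portion before the pause and the portion after it.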
461 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
462 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
463 return (self.clone(), None);
464 };
465
466 let before = LastEvent {
467 old_snapshot: self.old_snapshot.clone(),
468 new_snapshot: boundary_snapshot.clone(),
469 old_file: self.old_file.clone(),
470 new_file: self.new_file.clone(),
471 end_edit_anchor: self.end_edit_anchor,
472 snapshot_after_last_editing_pause: None,
473 last_edit_time: self.last_edit_time,
474 };
475
476 let after = LastEvent {
477 old_snapshot: boundary_snapshot.clone(),
478 new_snapshot: self.new_snapshot.clone(),
479 old_file: self.old_file.clone(),
480 new_file: self.new_file.clone(),
481 end_edit_anchor: self.end_edit_anchor,
482 snapshot_after_last_editing_pause: None,
483 last_edit_time: self.last_edit_time,
484 };
485
486 (before, Some(after))
487 }
488}
489
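/// Computes a unified diff between two snapshots of the same buffer, covering
/// only the edited region plus three lines of context on either side. Returns
/// `None` when there are no edits between the snapshots.
///
/// A hypothetical call (the snapshots are assumed to come from the same buffer):
///
/// ```ignore
/// if let Some(diff) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot) {
///     println!("{diff}");
/// }
/// ```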
490pub(crate) fn compute_diff_between_snapshots(
491 old_snapshot: &TextBufferSnapshot,
492 new_snapshot: &TextBufferSnapshot,
493) -> Option<String> {
494 let edits: Vec<Edit<usize>> = new_snapshot
495 .edits_since::<usize>(&old_snapshot.version)
496 .collect();
497
498 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
499
500 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
501 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
502 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
503 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
504
505 const CONTEXT_LINES: u32 = 3;
506
507 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
508 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
509 let old_context_end_row =
510 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
511 let new_context_end_row =
512 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
513
514 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
515 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
516 let old_end_line_offset = old_snapshot
517 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
518 let new_end_line_offset = new_snapshot
519 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
520 let old_edit_range = old_start_line_offset..old_end_line_offset;
521 let new_edit_range = new_start_line_offset..new_end_line_offset;
522
523 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
524 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
525
526 let diff = language::unified_diff_with_offsets(
527 &old_region_text,
528 &new_region_text,
529 old_context_start_row,
530 new_context_start_row,
531 );
532
533 Some(diff)
534}
535
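/// Returns the buffer's full path, or an `untitled-{remote_id}` placeholder for
/// buffers that have no file on disk.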
536fn buffer_path_with_id_fallback(
537 file: Option<&Arc<dyn File>>,
538 snapshot: &TextBufferSnapshot,
539 cx: &App,
540) -> Arc<Path> {
541 if let Some(file) = file {
542 file.full_path(cx).into()
543 } else {
544 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
545 }
546}
547
548impl EditPredictionStore {
549 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
550 cx.try_global::<EditPredictionStoreGlobal>()
551 .map(|global| global.0.clone())
552 }
553
554 pub fn global(
555 client: &Arc<Client>,
556 user_store: &Entity<UserStore>,
557 cx: &mut App,
558 ) -> Entity<Self> {
559 cx.try_global::<EditPredictionStoreGlobal>()
560 .map(|global| global.0.clone())
561 .unwrap_or_else(|| {
562 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
563 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
564 ep_store
565 })
566 }
567
568 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
569 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
570 let data_collection_choice = Self::load_data_collection_choice();
571
572 let llm_token = LlmApiToken::default();
573
574 let (reject_tx, reject_rx) = mpsc::unbounded();
575 cx.background_spawn({
576 let client = client.clone();
577 let llm_token = llm_token.clone();
578 let app_version = AppVersion::global(cx);
579 let background_executor = cx.background_executor().clone();
580 async move {
581 Self::handle_rejected_predictions(
582 reject_rx,
583 client,
584 llm_token,
585 app_version,
586 background_executor,
587 )
588 .await
589 }
590 })
591 .detach();
592
593 let mut this = Self {
594 projects: HashMap::default(),
595 client,
596 user_store,
597 options: DEFAULT_OPTIONS,
598 use_context: false,
599 llm_token,
600 _llm_token_subscription: cx.subscribe(
601 &refresh_llm_token_listener,
602 |this, _listener, _event, cx| {
603 let client = this.client.clone();
604 let llm_token = this.llm_token.clone();
605 cx.spawn(async move |_this, _cx| {
606 llm_token.refresh(&client).await?;
607 anyhow::Ok(())
608 })
609 .detach_and_log_err(cx);
610 },
611 ),
612 update_required: false,
613 #[cfg(feature = "cli-support")]
614 eval_cache: None,
615 edit_prediction_model: EditPredictionModel::Zeta2,
616 sweep_ai: SweepAi::new(cx),
617 mercury: Mercury::new(cx),
618 data_collection_choice,
619 reject_predictions_tx: reject_tx,
620 rated_predictions: Default::default(),
621 shown_predictions: Default::default(),
622 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
623 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
624 Err(_) => {
625 if *USE_OLLAMA {
626 Some(
627 Url::parse("http://localhost:11434/v1/chat/completions")
628 .unwrap()
629 .into(),
630 )
631 } else {
632 None
633 }
634 }
635 },
636 };
637
638 this.configure_context_retrieval(cx);
639 let weak_this = cx.weak_entity();
640 cx.on_flags_ready(move |_, cx| {
641 weak_this
642 .update(cx, |this, cx| this.configure_context_retrieval(cx))
643 .ok();
644 })
645 .detach();
646 cx.observe_global::<SettingsStore>(|this, cx| {
647 this.configure_context_retrieval(cx);
648 })
649 .detach();
650
651 this
652 }
653
654 #[cfg(test)]
655 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
656 self.custom_predict_edits_url = Some(url.into());
657 }
658
659 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
660 self.edit_prediction_model = model;
661 }
662
663 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
664 self.sweep_ai.api_token.read(cx).has_key()
665 }
666
667 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
668 self.mercury.api_token.read(cx).has_key()
669 }
670
671 #[cfg(feature = "cli-support")]
672 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
673 self.eval_cache = Some(cache);
674 }
675
676 pub fn options(&self) -> &ZetaOptions {
677 &self.options
678 }
679
680 pub fn set_options(&mut self, options: ZetaOptions) {
681 self.options = options;
682 }
683
684 pub fn set_use_context(&mut self, use_context: bool) {
685 self.use_context = use_context;
686 }
687
688 pub fn clear_history(&mut self) {
689 for project_state in self.projects.values_mut() {
690 project_state.events.clear();
691 }
692 }
693
694 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
695 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
696 project_state.events.clear();
697 }
698 }
699
700 pub fn edit_history_for_project(
701 &self,
702 project: &Entity<Project>,
703 cx: &App,
704 ) -> Vec<StoredEvent> {
705 self.projects
706 .get(&project.entity_id())
707 .map(|project_state| project_state.events(cx))
708 .unwrap_or_default()
709 }
710
711 pub fn edit_history_for_project_with_pause_split_last_event(
712 &self,
713 project: &Entity<Project>,
714 cx: &App,
715 ) -> Vec<StoredEvent> {
716 self.projects
717 .get(&project.entity_id())
718 .map(|project_state| project_state.events_split_by_pause(cx))
719 .unwrap_or_default()
720 }
721
722 pub fn context_for_project<'a>(
723 &'a self,
724 project: &Entity<Project>,
725 cx: &'a App,
726 ) -> Arc<[RelatedFile]> {
727 self.projects
728 .get(&project.entity_id())
729 .map(|project| project.context.read(cx).related_files())
730 .unwrap_or_else(|| vec![].into())
731 }
732
733 pub fn context_for_project_with_buffers<'a>(
734 &'a self,
735 project: &Entity<Project>,
736 cx: &'a App,
737 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
738 self.projects
739 .get(&project.entity_id())
740 .map(|project| project.context.read(cx).related_files_with_buffers())
741 }
742
743 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
744 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
745 self.user_store.read(cx).edit_prediction_usage()
746 } else {
747 None
748 }
749 }
750
751 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
752 self.get_or_init_project(project, cx);
753 }
754
755 pub fn register_buffer(
756 &mut self,
757 buffer: &Entity<Buffer>,
758 project: &Entity<Project>,
759 cx: &mut Context<Self>,
760 ) {
761 let project_state = self.get_or_init_project(project, cx);
762 Self::register_buffer_impl(project_state, buffer, project, cx);
763 }
764
765 fn get_or_init_project(
766 &mut self,
767 project: &Entity<Project>,
768 cx: &mut Context<Self>,
769 ) -> &mut ProjectState {
770 let entity_id = project.entity_id();
771 self.projects
772 .entry(entity_id)
773 .or_insert_with(|| ProjectState {
774 context: {
775 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
776 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
777 this.handle_excerpt_store_event(entity_id, event);
778 })
779 .detach();
780 related_excerpt_store
781 },
782 events: VecDeque::new(),
783 last_event: None,
784 recent_paths: VecDeque::new(),
785 debug_tx: None,
786 registered_buffers: HashMap::default(),
787 current_prediction: None,
788 cancelled_predictions: HashSet::default(),
789 pending_predictions: ArrayVec::new(),
790 next_pending_prediction_id: 0,
791 last_prediction_refresh: None,
792 license_detection_watchers: HashMap::default(),
793 _subscription: cx.subscribe(&project, Self::handle_project_event),
794 })
795 }
796
797 pub fn remove_project(&mut self, project: &Entity<Project>) {
798 self.projects.remove(&project.entity_id());
799 }
800
801 fn handle_excerpt_store_event(
802 &mut self,
803 project_entity_id: EntityId,
804 event: &RelatedExcerptStoreEvent,
805 ) {
806 if let Some(project_state) = self.projects.get(&project_entity_id) {
807 if let Some(debug_tx) = project_state.debug_tx.clone() {
808 match event {
809 RelatedExcerptStoreEvent::StartedRefresh => {
810 debug_tx
811 .unbounded_send(DebugEvent::ContextRetrievalStarted(
812 ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
814 timestamp: Instant::now(),
815 search_prompt: String::new(),
816 },
817 ))
818 .ok();
819 }
820 RelatedExcerptStoreEvent::FinishedRefresh {
821 cache_hit_count,
822 cache_miss_count,
823 mean_definition_latency,
824 max_definition_latency,
825 } => {
826 debug_tx
827 .unbounded_send(DebugEvent::ContextRetrievalFinished(
828 ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
830 timestamp: Instant::now(),
831 metadata: vec![
832 (
833 "Cache Hits",
834 format!(
835 "{}/{}",
836 cache_hit_count,
837 cache_hit_count + cache_miss_count
838 )
839 .into(),
840 ),
841 (
842 "Max LSP Time",
843 format!("{} ms", max_definition_latency.as_millis())
844 .into(),
845 ),
846 (
847 "Mean LSP Time",
848 format!("{} ms", mean_definition_latency.as_millis())
849 .into(),
850 ),
851 ],
852 },
853 ))
854 .ok();
855 }
856 }
857 }
858 }
859 }
860
861 pub fn debug_info(
862 &mut self,
863 project: &Entity<Project>,
864 cx: &mut Context<Self>,
865 ) -> mpsc::UnboundedReceiver<DebugEvent> {
866 let project_state = self.get_or_init_project(project, cx);
867 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
868 project_state.debug_tx = Some(debug_watch_tx);
869 debug_watch_rx
870 }
871
872 fn handle_project_event(
873 &mut self,
874 project: Entity<Project>,
875 event: &project::Event,
876 cx: &mut Context<Self>,
877 ) {
878 // TODO [zeta2] init with recent paths
879 match event {
880 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
881 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
882 return;
883 };
884 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
885 if let Some(path) = path {
886 if let Some(ix) = project_state
887 .recent_paths
888 .iter()
889 .position(|probe| probe == &path)
890 {
891 project_state.recent_paths.remove(ix);
892 }
893 project_state.recent_paths.push_front(path);
894 }
895 }
896 project::Event::DiagnosticsUpdated { .. } => {
897 if cx.has_flag::<Zeta2FeatureFlag>() {
898 self.refresh_prediction_from_diagnostics(project, cx);
899 }
900 }
901 _ => (),
902 }
903 }
904
905 fn register_buffer_impl<'a>(
906 project_state: &'a mut ProjectState,
907 buffer: &Entity<Buffer>,
908 project: &Entity<Project>,
909 cx: &mut Context<Self>,
910 ) -> &'a mut RegisteredBuffer {
911 let buffer_id = buffer.entity_id();
912
913 if let Some(file) = buffer.read(cx).file() {
914 let worktree_id = file.worktree_id(cx);
915 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
916 project_state
917 .license_detection_watchers
918 .entry(worktree_id)
919 .or_insert_with(|| {
920 let project_entity_id = project.entity_id();
921 cx.observe_release(&worktree, move |this, _worktree, _cx| {
922 let Some(project_state) = this.projects.get_mut(&project_entity_id)
923 else {
924 return;
925 };
926 project_state
927 .license_detection_watchers
928 .remove(&worktree_id);
929 })
930 .detach();
931 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
932 });
933 }
934 }
935
936 match project_state.registered_buffers.entry(buffer_id) {
937 hash_map::Entry::Occupied(entry) => entry.into_mut(),
938 hash_map::Entry::Vacant(entry) => {
939 let buf = buffer.read(cx);
940 let snapshot = buf.text_snapshot();
941 let file = buf.file().cloned();
942 let project_entity_id = project.entity_id();
943 entry.insert(RegisteredBuffer {
944 snapshot,
945 file,
946 last_position: None,
947 _subscriptions: [
948 cx.subscribe(buffer, {
949 let project = project.downgrade();
950 move |this, buffer, event, cx| {
951 if let language::BufferEvent::Edited = event
952 && let Some(project) = project.upgrade()
953 {
954 this.report_changes_for_buffer(&buffer, &project, cx);
955 }
956 }
957 }),
958 cx.observe_release(buffer, move |this, _buffer, _cx| {
959 let Some(project_state) = this.projects.get_mut(&project_entity_id)
960 else {
961 return;
962 };
963 project_state.registered_buffers.remove(&buffer_id);
964 }),
965 ],
966 })
967 }
968 }
969 }
970
971 fn report_changes_for_buffer(
972 &mut self,
973 buffer: &Entity<Buffer>,
974 project: &Entity<Project>,
975 cx: &mut Context<Self>,
976 ) {
977 let project_state = self.get_or_init_project(project, cx);
978 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
979
980 let buf = buffer.read(cx);
981 let new_file = buf.file().cloned();
982 let new_snapshot = buf.text_snapshot();
983 if new_snapshot.version == registered_buffer.snapshot.version {
984 return;
985 }
986
987 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
988 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
989 let end_edit_anchor = new_snapshot
990 .anchored_edits_since::<Point>(&old_snapshot.version)
991 .last()
992 .map(|(_, range)| range.end);
993 let events = &mut project_state.events;
994
995 let now = cx.background_executor().now();
996 if let Some(last_event) = project_state.last_event.as_mut() {
997 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
998 == last_event.new_snapshot.remote_id()
999 && old_snapshot.version == last_event.new_snapshot.version;
1000
1001 let should_coalesce = is_next_snapshot_of_same_buffer
1002 && end_edit_anchor
1003 .as_ref()
1004 .zip(last_event.end_edit_anchor.as_ref())
1005 .is_some_and(|(a, b)| {
1006 let a = a.to_point(&new_snapshot);
1007 let b = b.to_point(&new_snapshot);
1008 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
1009 });
1010
1011 if should_coalesce {
1012 let pause_elapsed = last_event
1013 .last_edit_time
1014 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1015 .unwrap_or(false);
1016 if pause_elapsed {
1017 last_event.snapshot_after_last_editing_pause =
1018 Some(last_event.new_snapshot.clone());
1019 }
1020
1021 last_event.end_edit_anchor = end_edit_anchor;
1022 last_event.new_snapshot = new_snapshot;
1023 last_event.last_edit_time = Some(now);
1024 return;
1025 }
1026 }
1027
1028 if events.len() + 1 >= EVENT_COUNT_MAX {
1029 events.pop_front();
1030 }
1031
1032 if let Some(event) = project_state.last_event.take() {
1033 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
1034 }
1035
1036 project_state.last_event = Some(LastEvent {
1037 old_file,
1038 new_file,
1039 old_snapshot,
1040 new_snapshot,
1041 end_edit_anchor,
1042 snapshot_after_last_editing_pause: None,
1043 last_edit_time: Some(now),
1044 });
1045 }
1046
1047 fn prediction_at(
1048 &mut self,
1049 buffer: &Entity<Buffer>,
1050 position: Option<language::Anchor>,
1051 project: &Entity<Project>,
1052 cx: &App,
1053 ) -> Option<BufferEditPrediction<'_>> {
1054 let project_state = self.projects.get_mut(&project.entity_id())?;
1055 if let Some(position) = position
1056 && let Some(buffer) = project_state
1057 .registered_buffers
1058 .get_mut(&buffer.entity_id())
1059 {
1060 buffer.last_position = Some(position);
1061 }
1062
1063 let CurrentEditPrediction {
1064 requested_by,
1065 prediction,
1066 ..
1067 } = project_state.current_prediction.as_ref()?;
1068
1069 if prediction.targets_buffer(buffer.read(cx)) {
1070 Some(BufferEditPrediction::Local { prediction })
1071 } else {
1072 let show_jump = match requested_by {
1073 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1074 requested_by_buffer_id == &buffer.entity_id()
1075 }
1076 PredictionRequestedBy::DiagnosticsUpdate => true,
1077 };
1078
1079 if show_jump {
1080 Some(BufferEditPrediction::Jump { prediction })
1081 } else {
1082 None
1083 }
1084 }
1085 }
1086
1087 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1088 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1089 match self.edit_prediction_model {
1090 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1091 if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1092 return;
1093 }
1094 }
1095 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1096 }
1097
1098 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1099 return;
1100 };
1101
1102 let Some(prediction) = project_state.current_prediction.take() else {
1103 return;
1104 };
1105 let request_id = prediction.prediction.id.to_string();
1106 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1107 project_state.cancel_pending_prediction(pending_prediction, cx);
1108 }
1109
1110 let client = self.client.clone();
1111 let llm_token = self.llm_token.clone();
1112 let app_version = AppVersion::global(cx);
1113 cx.spawn(async move |this, cx| {
1114 let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1115 (http_client::Url::parse(&accept_edits_url)?, false)
1116 } else {
1117 (
1118 client
1119 .http_client()
1120 .build_zed_llm_url("/predict_edits/accept", &[])?,
1121 true,
1122 )
1123 };
1124
1125 let response = cx
1126 .background_spawn(Self::send_api_request::<()>(
1127 move |builder| {
1128 let req = builder.uri(url.as_ref()).body(
1129 serde_json::to_string(&AcceptEditPredictionBody {
1130 request_id: request_id.clone(),
1131 })?
1132 .into(),
1133 );
1134 Ok(req?)
1135 },
1136 client,
1137 llm_token,
1138 app_version,
1139 require_auth,
1140 ))
1141 .await;
1142
1143 Self::handle_api_response(&this, response, cx)?;
1144 anyhow::Ok(())
1145 })
1146 .detach_and_log_err(cx);
1147 }
1148
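    /// Background loop that batches rejection reports and flushes them to the
    /// `/predict_edits/reject` endpoint. A partially filled batch waits up to
    /// `REJECT_REQUEST_DEBOUNCE` for more items, each request is capped at
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections, and items are
    /// only dropped from the batch once a flush succeeds.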
1149 async fn handle_rejected_predictions(
1150 rx: UnboundedReceiver<EditPredictionRejection>,
1151 client: Arc<Client>,
1152 llm_token: LlmApiToken,
1153 app_version: Version,
1154 background_executor: BackgroundExecutor,
1155 ) {
1156 let mut rx = std::pin::pin!(rx.peekable());
1157 let mut batched = Vec::new();
1158
1159 while let Some(rejection) = rx.next().await {
1160 batched.push(rejection);
1161
1162 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1163 select_biased! {
1164 next = rx.as_mut().peek().fuse() => {
1165 if next.is_some() {
1166 continue;
1167 }
1168 }
1169 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1170 }
1171 }
1172
1173 let url = client
1174 .http_client()
1175 .build_zed_llm_url("/predict_edits/reject", &[])
1176 .unwrap();
1177
1178 let flush_count = batched
1179 .len()
                // cap the batch in case items accumulated after a failed flush
1181 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1182 let start = batched.len() - flush_count;
1183
1184 let body = RejectEditPredictionsBodyRef {
1185 rejections: &batched[start..],
1186 };
1187
1188 let result = Self::send_api_request::<()>(
1189 |builder| {
1190 let req = builder
1191 .uri(url.as_ref())
1192 .body(serde_json::to_string(&body)?.into());
1193 anyhow::Ok(req?)
1194 },
1195 client.clone(),
1196 llm_token.clone(),
1197 app_version.clone(),
1198 true,
1199 )
1200 .await;
1201
1202 if result.log_err().is_some() {
1203 batched.drain(start..);
1204 }
1205 }
1206 }
1207
1208 fn reject_current_prediction(
1209 &mut self,
1210 reason: EditPredictionRejectReason,
1211 project: &Entity<Project>,
1212 ) {
1213 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1214 project_state.pending_predictions.clear();
1215 if let Some(prediction) = project_state.current_prediction.take() {
1216 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1217 }
1218 };
1219 }
1220
1221 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1222 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1223 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1224 if !current_prediction.was_shown {
1225 current_prediction.was_shown = true;
1226 self.shown_predictions
1227 .push_front(current_prediction.prediction.clone());
1228 if self.shown_predictions.len() > 50 {
                        let prediction = self.shown_predictions.pop_back().unwrap();
                        self.rated_predictions.remove(&prediction.id);
1231 }
1232 }
1233 }
1234 }
1235 }
1236
1237 fn reject_prediction(
1238 &mut self,
1239 prediction_id: EditPredictionId,
1240 reason: EditPredictionRejectReason,
1241 was_shown: bool,
1242 ) {
1243 match self.edit_prediction_model {
1244 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1245 if self.custom_predict_edits_url.is_some() {
1246 return;
1247 }
1248 }
1249 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1250 }
1251
1252 self.reject_predictions_tx
1253 .unbounded_send(EditPredictionRejection {
1254 request_id: prediction_id.to_string(),
1255 reason,
1256 was_shown,
1257 })
1258 .log_err();
1259 }
1260
1261 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1262 self.projects
1263 .get(&project.entity_id())
1264 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1265 }
1266
1267 pub fn refresh_prediction_from_buffer(
1268 &mut self,
1269 project: Entity<Project>,
1270 buffer: Entity<Buffer>,
1271 position: language::Anchor,
1272 cx: &mut Context<Self>,
1273 ) {
1274 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1275 let Some(request_task) = this
1276 .update(cx, |this, cx| {
1277 this.request_prediction(
1278 &project,
1279 &buffer,
1280 position,
1281 PredictEditsRequestTrigger::Other,
1282 cx,
1283 )
1284 })
1285 .log_err()
1286 else {
1287 return Task::ready(anyhow::Ok(None));
1288 };
1289
1290 cx.spawn(async move |_cx| {
1291 request_task.await.map(|prediction_result| {
1292 prediction_result.map(|prediction_result| {
1293 (
1294 prediction_result,
1295 PredictionRequestedBy::Buffer(buffer.entity_id()),
1296 )
1297 })
1298 })
1299 })
1300 })
1301 }
1302
1303 pub fn refresh_prediction_from_diagnostics(
1304 &mut self,
1305 project: Entity<Project>,
1306 cx: &mut Context<Self>,
1307 ) {
1308 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1309 return;
1310 };
1311
        // If there is already a current prediction (typically requested from a
        // buffer), keep it rather than refreshing from diagnostics.
1313 if project_state.current_prediction.is_some() {
1314 return;
1315 };
1316
1317 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1318 let Some((active_buffer, snapshot, cursor_point)) = this
1319 .read_with(cx, |this, cx| {
1320 let project_state = this.projects.get(&project.entity_id())?;
1321 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1322 let snapshot = buffer.read(cx).snapshot();
1323
1324 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1325 return None;
1326 }
1327
1328 let cursor_point = position
1329 .map(|pos| pos.to_point(&snapshot))
1330 .unwrap_or_default();
1331
1332 Some((buffer, snapshot, cursor_point))
1333 })
1334 .log_err()
1335 .flatten()
1336 else {
1337 return Task::ready(anyhow::Ok(None));
1338 };
1339
1340 cx.spawn(async move |cx| {
1341 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1342 active_buffer,
1343 &snapshot,
1344 Default::default(),
1345 cursor_point,
1346 &project,
1347 cx,
1348 )
1349 .await?
1350 else {
1351 return anyhow::Ok(None);
1352 };
1353
1354 let Some(prediction_result) = this
1355 .update(cx, |this, cx| {
1356 this.request_prediction(
1357 &project,
1358 &jump_buffer,
1359 jump_position,
1360 PredictEditsRequestTrigger::Diagnostics,
1361 cx,
1362 )
1363 })?
1364 .await?
1365 else {
1366 return anyhow::Ok(None);
1367 };
1368
1369 this.update(cx, |this, cx| {
1370 Some((
1371 if this
1372 .get_or_init_project(&project, cx)
1373 .current_prediction
1374 .is_none()
1375 {
1376 prediction_result
1377 } else {
1378 EditPredictionResult {
1379 id: prediction_result.id,
1380 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1381 }
1382 },
1383 PredictionRequestedBy::DiagnosticsUpdate,
1384 ))
1385 })
1386 })
1387 });
1388 }
1389
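    /// Returns whether edit predictions should be shown for this buffer and,
    /// when a cursor position is provided, whether the language scope at that
    /// position has predictions disabled via `edit_predictions_disabled_in`.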
1390 fn predictions_enabled_at(
1391 snapshot: &BufferSnapshot,
1392 position: Option<language::Anchor>,
1393 cx: &App,
1394 ) -> bool {
1395 let file = snapshot.file();
1396 let all_settings = all_language_settings(file, cx);
1397 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1398 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1399 {
1400 return false;
1401 }
1402
1403 if let Some(last_position) = position {
1404 let settings = snapshot.settings_at(last_position, cx);
1405
1406 if !settings.edit_predictions_disabled_in.is_empty()
1407 && let Some(scope) = snapshot.language_scope_at(last_position)
1408 && let Some(scope_name) = scope.override_name()
1409 && settings
1410 .edit_predictions_disabled_in
1411 .iter()
1412 .any(|s| s == scope_name)
1413 {
1414 return false;
1415 }
1416 }
1417
1418 true
1419 }
1420
1421 #[cfg(not(test))]
1422 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1423 #[cfg(test)]
1424 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1425
1426 fn queue_prediction_refresh(
1427 &mut self,
1428 project: Entity<Project>,
1429 throttle_entity: EntityId,
1430 cx: &mut Context<Self>,
1431 do_refresh: impl FnOnce(
1432 WeakEntity<Self>,
1433 &mut AsyncApp,
1434 )
1435 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1436 + 'static,
1437 ) {
1438 let project_state = self.get_or_init_project(&project, cx);
1439 let pending_prediction_id = project_state.next_pending_prediction_id;
1440 project_state.next_pending_prediction_id += 1;
1441 let last_request = project_state.last_prediction_refresh;
1442
1443 let task = cx.spawn(async move |this, cx| {
1444 if let Some((last_entity, last_timestamp)) = last_request
1445 && throttle_entity == last_entity
1446 && let Some(timeout) =
1447 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1448 {
1449 cx.background_executor().timer(timeout).await;
1450 }
1451
1452 // If this task was cancelled before the throttle timeout expired,
1453 // do not perform a request.
1454 let mut is_cancelled = true;
1455 this.update(cx, |this, cx| {
1456 let project_state = this.get_or_init_project(&project, cx);
1457 if !project_state
1458 .cancelled_predictions
1459 .remove(&pending_prediction_id)
1460 {
1461 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1462 is_cancelled = false;
1463 }
1464 })
1465 .ok();
1466 if is_cancelled {
1467 return None;
1468 }
1469
1470 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1471 let new_prediction_id = new_prediction_result
1472 .as_ref()
1473 .map(|(prediction, _)| prediction.id.clone());
1474
1475 // When a prediction completes, remove it from the pending list, and cancel
1476 // any pending predictions that were enqueued before it.
1477 this.update(cx, |this, cx| {
1478 let project_state = this.get_or_init_project(&project, cx);
1479
1480 let is_cancelled = project_state
1481 .cancelled_predictions
1482 .remove(&pending_prediction_id);
1483
1484 let new_current_prediction = if !is_cancelled
1485 && let Some((prediction_result, requested_by)) = new_prediction_result
1486 {
1487 match prediction_result.prediction {
1488 Ok(prediction) => {
1489 let new_prediction = CurrentEditPrediction {
1490 requested_by,
1491 prediction,
1492 was_shown: false,
1493 };
1494
1495 if let Some(current_prediction) =
1496 project_state.current_prediction.as_ref()
1497 {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1499 {
1500 this.reject_current_prediction(
1501 EditPredictionRejectReason::Replaced,
1502 &project,
1503 );
1504
1505 Some(new_prediction)
1506 } else {
1507 this.reject_prediction(
1508 new_prediction.prediction.id,
1509 EditPredictionRejectReason::CurrentPreferred,
1510 false,
1511 );
1512 None
1513 }
1514 } else {
1515 Some(new_prediction)
1516 }
1517 }
1518 Err(reject_reason) => {
1519 this.reject_prediction(prediction_result.id, reject_reason, false);
1520 None
1521 }
1522 }
1523 } else {
1524 None
1525 };
1526
1527 let project_state = this.get_or_init_project(&project, cx);
1528
1529 if let Some(new_prediction) = new_current_prediction {
1530 project_state.current_prediction = Some(new_prediction);
1531 }
1532
1533 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1534 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1535 if pending_prediction.id == pending_prediction_id {
1536 pending_predictions.remove(ix);
1537 for pending_prediction in pending_predictions.drain(0..ix) {
1538 project_state.cancel_pending_prediction(pending_prediction, cx)
1539 }
1540 break;
1541 }
1542 }
1543 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1544 cx.notify();
1545 })
1546 .ok();
1547
1548 new_prediction_id
1549 });
1550
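        // Keep at most two pending predictions: if a third request arrives,
        // replace the most recently queued one and cancel it.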
1551 if project_state.pending_predictions.len() <= 1 {
1552 project_state.pending_predictions.push(PendingPrediction {
1553 id: pending_prediction_id,
1554 task,
1555 });
1556 } else if project_state.pending_predictions.len() == 2 {
1557 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1558 project_state.pending_predictions.push(PendingPrediction {
1559 id: pending_prediction_id,
1560 task,
1561 });
1562 project_state.cancel_pending_prediction(pending_prediction, cx);
1563 }
1564 }
1565
1566 pub fn request_prediction(
1567 &mut self,
1568 project: &Entity<Project>,
1569 active_buffer: &Entity<Buffer>,
1570 position: language::Anchor,
1571 trigger: PredictEditsRequestTrigger,
1572 cx: &mut Context<Self>,
1573 ) -> Task<Result<Option<EditPredictionResult>>> {
1574 self.request_prediction_internal(
1575 project.clone(),
1576 active_buffer.clone(),
1577 position,
1578 trigger,
1579 cx.has_flag::<Zeta2FeatureFlag>(),
1580 cx,
1581 )
1582 }
1583
1584 fn request_prediction_internal(
1585 &mut self,
1586 project: Entity<Project>,
1587 active_buffer: Entity<Buffer>,
1588 position: language::Anchor,
1589 trigger: PredictEditsRequestTrigger,
1590 allow_jump: bool,
1591 cx: &mut Context<Self>,
1592 ) -> Task<Result<Option<EditPredictionResult>>> {
1593 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1594
1595 self.get_or_init_project(&project, cx);
1596 let project_state = self.projects.get(&project.entity_id()).unwrap();
1597 let stored_events = project_state.events(cx);
1598 let has_events = !stored_events.is_empty();
1599 let events: Vec<Arc<zeta_prompt::Event>> =
1600 stored_events.into_iter().map(|e| e.event).collect();
1601 let debug_tx = project_state.debug_tx.clone();
1602
1603 let snapshot = active_buffer.read(cx).snapshot();
1604 let cursor_point = position.to_point(&snapshot);
1605 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1606 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1607 let diagnostic_search_range =
1608 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1609
1610 let related_files = if self.use_context {
1611 self.context_for_project(&project, cx)
1612 } else {
1613 Vec::new().into()
1614 };
1615
1616 let inputs = EditPredictionModelInput {
1617 project: project.clone(),
1618 buffer: active_buffer.clone(),
1619 snapshot: snapshot.clone(),
1620 position,
1621 events,
1622 related_files,
1623 recent_paths: project_state.recent_paths.clone(),
1624 trigger,
1625 diagnostic_search_range: diagnostic_search_range.clone(),
1626 debug_tx,
1627 };
1628
1629 let task = match self.edit_prediction_model {
1630 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1631 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1632 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1633 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1634 };
1635
1636 cx.spawn(async move |this, cx| {
1637 let prediction = task.await?;
1638
1639 if prediction.is_none() && allow_jump {
1640 let cursor_point = position.to_point(&snapshot);
1641 if has_events
1642 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1643 active_buffer.clone(),
1644 &snapshot,
1645 diagnostic_search_range,
1646 cursor_point,
1647 &project,
1648 cx,
1649 )
1650 .await?
1651 {
1652 return this
1653 .update(cx, |this, cx| {
1654 this.request_prediction_internal(
1655 project,
1656 jump_buffer,
1657 jump_position,
1658 trigger,
1659 false,
1660 cx,
1661 )
1662 })?
1663 .await;
1664 }
1665
1666 return anyhow::Ok(None);
1667 }
1668
1669 Ok(prediction)
1670 })
1671 }
1672
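    /// Finds the next diagnostic to jump to: first the diagnostic in the active
    /// buffer closest to the cursor but outside the range already covered by the
    /// last request, and otherwise the most severe diagnostic in the buffer whose
    /// path shares the most leading components with the active buffer's path.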
1673 async fn next_diagnostic_location(
1674 active_buffer: Entity<Buffer>,
1675 active_buffer_snapshot: &BufferSnapshot,
1676 active_buffer_diagnostic_search_range: Range<Point>,
1677 active_buffer_cursor_point: Point,
1678 project: &Entity<Project>,
1679 cx: &mut AsyncApp,
1680 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1681 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1682 let mut jump_location = active_buffer_snapshot
1683 .diagnostic_groups(None)
1684 .into_iter()
1685 .filter_map(|(_, group)| {
1686 let range = &group.entries[group.primary_ix]
1687 .range
1688 .to_point(&active_buffer_snapshot);
1689 if range.overlaps(&active_buffer_diagnostic_search_range) {
1690 None
1691 } else {
1692 Some(range.start)
1693 }
1694 })
1695 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1696 .map(|position| {
1697 (
1698 active_buffer.clone(),
1699 active_buffer_snapshot.anchor_before(position),
1700 )
1701 });
1702
1703 if jump_location.is_none() {
1704 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1705 let file = buffer.file()?;
1706
1707 Some(ProjectPath {
1708 worktree_id: file.worktree_id(cx),
1709 path: file.path().clone(),
1710 })
1711 })?;
1712
1713 let buffer_task = project.update(cx, |project, cx| {
1714 let (path, _, _) = project
1715 .diagnostic_summaries(false, cx)
1716 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1717 .max_by_key(|(path, _, _)| {
                        // Find the buffer with diagnostics whose path shares the
                        // most parent directories with the active buffer's path.
1719 path.path
1720 .components()
1721 .zip(
1722 active_buffer_path
1723 .as_ref()
1724 .map(|p| p.path.components())
1725 .unwrap_or_default(),
1726 )
1727 .take_while(|(a, b)| a == b)
1728 .count()
1729 })?;
1730
1731 Some(project.open_buffer(path, cx))
1732 })?;
1733
1734 if let Some(buffer_task) = buffer_task {
1735 let closest_buffer = buffer_task.await?;
1736
1737 jump_location = closest_buffer
1738 .read_with(cx, |buffer, _cx| {
1739 buffer
1740 .buffer_diagnostics(None)
1741 .into_iter()
1742 .min_by_key(|entry| entry.diagnostic.severity)
1743 .map(|entry| entry.range.start)
1744 })?
1745 .map(|position| (closest_buffer, position));
1746 }
1747 }
1748
1749 anyhow::Ok(jump_location)
1750 }
1751
1752 async fn send_raw_llm_request(
1753 request: open_ai::Request,
1754 client: Arc<Client>,
1755 llm_token: LlmApiToken,
1756 app_version: Version,
1757 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1758 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1759 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1760 let url = client
1761 .http_client()
1762 .build_zed_llm_url("/predict_edits/raw", &[])?;
1763
1764 #[cfg(feature = "cli-support")]
1765 let cache_key = if let Some(cache) = eval_cache {
1766 use collections::FxHasher;
1767 use std::hash::{Hash, Hasher};
1768
1769 let mut hasher = FxHasher::default();
1770 url.hash(&mut hasher);
1771 let request_str = serde_json::to_string_pretty(&request)?;
1772 request_str.hash(&mut hasher);
1773 let hash = hasher.finish();
1774
1775 let key = (eval_cache_kind, hash);
1776 if let Some(response_str) = cache.read(key) {
1777 return Ok((serde_json::from_str(&response_str)?, None));
1778 }
1779
1780 Some((cache, request_str, key))
1781 } else {
1782 None
1783 };
1784
1785 let (response, usage) = Self::send_api_request(
1786 |builder| {
1787 let req = builder
1788 .uri(url.as_ref())
1789 .body(serde_json::to_string(&request)?.into());
1790 Ok(req?)
1791 },
1792 client,
1793 llm_token,
1794 app_version,
1795 true,
1796 )
1797 .await?;
1798
1799 #[cfg(feature = "cli-support")]
1800 if let Some((cache, request, key)) = cache_key {
1801 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1802 }
1803
1804 Ok((response, usage))
1805 }
1806
1807 fn handle_api_response<T>(
1808 this: &WeakEntity<Self>,
1809 response: Result<(T, Option<EditPredictionUsage>)>,
1810 cx: &mut gpui::AsyncApp,
1811 ) -> Result<T> {
1812 match response {
1813 Ok((data, usage)) => {
1814 if let Some(usage) = usage {
1815 this.update(cx, |this, cx| {
1816 this.user_store.update(cx, |user_store, cx| {
1817 user_store.update_edit_prediction_usage(usage, cx);
1818 });
1819 })
1820 .ok();
1821 }
1822 Ok(data)
1823 }
1824 Err(err) => {
1825 if err.is::<ZedUpdateRequiredError>() {
1826 cx.update(|cx| {
1827 this.update(cx, |this, _cx| {
1828 this.update_required = true;
1829 })
1830 .ok();
1831
1832 let error_message: SharedString = err.to_string().into();
1833 show_app_notification(
1834 NotificationId::unique::<ZedUpdateRequiredError>(),
1835 cx,
1836 move |cx| {
1837 cx.new(|cx| {
1838 ErrorMessagePrompt::new(error_message.clone(), cx)
1839 .with_link_button("Update Zed", "https://zed.dev/releases")
1840 })
1841 },
1842 );
1843 })
1844 .ok();
1845 }
1846 Err(err)
1847 }
1848 }
1849 }
1850
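    /// Sends a POST request built by `build`, attaching the Zed version header
    /// and a bearer token when one is available (or required), and deserializes
    /// the JSON response into `Res`. Retries once after refreshing the LLM token
    /// if the server reports it as expired, and fails with
    /// `ZedUpdateRequiredError` when the server requires a newer Zed version.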
1851 async fn send_api_request<Res>(
1852 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1853 client: Arc<Client>,
1854 llm_token: LlmApiToken,
1855 app_version: Version,
1856 require_auth: bool,
1857 ) -> Result<(Res, Option<EditPredictionUsage>)>
1858 where
1859 Res: DeserializeOwned,
1860 {
1861 let http_client = client.http_client();
1862
1863 let mut token = if require_auth {
1864 Some(llm_token.acquire(&client).await?)
1865 } else {
1866 llm_token.acquire(&client).await.ok()
1867 };
1868 let mut did_retry = false;
1869
1870 loop {
1871 let request_builder = http_client::Request::builder().method(Method::POST);
1872
1873 let mut request_builder = request_builder
1874 .header("Content-Type", "application/json")
1875 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1876
1877 // Only add Authorization header if we have a token
1878 if let Some(ref token_value) = token {
1879 request_builder =
1880 request_builder.header("Authorization", format!("Bearer {}", token_value));
1881 }
1882
1883 let request = build(request_builder)?;
1884
1885 let mut response = http_client.send(request).await?;
1886
1887 if let Some(minimum_required_version) = response
1888 .headers()
1889 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1890 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1891 {
1892 anyhow::ensure!(
1893 app_version >= minimum_required_version,
1894 ZedUpdateRequiredError {
1895 minimum_version: minimum_required_version
1896 }
1897 );
1898 }
1899
1900 if response.status().is_success() {
1901 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1902
1903 let mut body = Vec::new();
1904 response.body_mut().read_to_end(&mut body).await?;
1905 return Ok((serde_json::from_slice(&body)?, usage));
1906 } else if !did_retry
1907 && token.is_some()
1908 && response
1909 .headers()
1910 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1911 .is_some()
1912 {
1913 did_retry = true;
1914 token = Some(llm_token.refresh(&client).await?);
1915 } else {
1916 let mut body = String::new();
1917 response.body_mut().read_to_string(&mut body).await?;
1918 anyhow::bail!(
1919 "Request failed with status: {:?}\nBody: {}",
1920 response.status(),
1921 body
1922 );
1923 }
1924 }
1925 }
1926
1927 pub fn refresh_context(
1928 &mut self,
1929 project: &Entity<Project>,
1930 buffer: &Entity<language::Buffer>,
1931 cursor_position: language::Anchor,
1932 cx: &mut Context<Self>,
1933 ) {
1934 if self.use_context {
1935 self.get_or_init_project(project, cx)
1936 .context
1937 .update(cx, |store, cx| {
1938 store.refresh(buffer.clone(), cursor_position, cx);
1939 });
1940 }
1941 }
1942
1943 #[cfg(feature = "cli-support")]
1944 pub fn set_context_for_buffer(
1945 &mut self,
1946 project: &Entity<Project>,
1947 related_files: Vec<RelatedFile>,
1948 cx: &mut Context<Self>,
1949 ) {
1950 self.get_or_init_project(project, cx)
1951 .context
1952 .update(cx, |store, _| {
1953 store.set_related_files(related_files);
1954 });
1955 }
1956
1957 fn is_file_open_source(
1958 &self,
1959 project: &Entity<Project>,
1960 file: &Arc<dyn File>,
1961 cx: &App,
1962 ) -> bool {
1963 if !file.is_local() || file.is_private() {
1964 return false;
1965 }
1966 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1967 return false;
1968 };
1969 project_state
1970 .license_detection_watchers
1971 .get(&file.worktree_id(cx))
1972 .as_ref()
1973 .is_some_and(|watcher| watcher.is_project_open_source())
1974 }
1975
1976 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1977 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1978 }
1979
1980 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1981 if !self.data_collection_choice.is_enabled() {
1982 return false;
1983 }
1984 events.iter().all(|event| {
1985 matches!(
1986 event.as_ref(),
1987 zeta_prompt::Event::BufferChange {
1988 in_open_source_repo: true,
1989 ..
1990 }
1991 )
1992 })
1993 }
1994
1995 fn load_data_collection_choice() -> DataCollectionChoice {
1996 let choice = KEY_VALUE_STORE
1997 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1998 .log_err()
1999 .flatten();
2000
2001 match choice.as_deref() {
2002 Some("true") => DataCollectionChoice::Enabled,
2003 Some("false") => DataCollectionChoice::Disabled,
2004 Some(_) => {
2005 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2006 DataCollectionChoice::NotAnswered
2007 }
2008 None => DataCollectionChoice::NotAnswered,
2009 }
2010 }
2011
2012 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2013 self.data_collection_choice = self.data_collection_choice.toggle();
2014 let new_choice = self.data_collection_choice;
2015 db::write_and_log(cx, move || {
2016 KEY_VALUE_STORE.write_kvp(
2017 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2018 new_choice.is_enabled().to_string(),
2019 )
2020 });
2021 }
2022
2023 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2024 self.shown_predictions.iter()
2025 }
2026
2027 pub fn shown_completions_len(&self) -> usize {
2028 self.shown_predictions.len()
2029 }
2030
2031 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2032 self.rated_predictions.contains(id)
2033 }
2034
2035 pub fn rate_prediction(
2036 &mut self,
2037 prediction: &EditPrediction,
2038 rating: EditPredictionRating,
2039 feedback: String,
2040 cx: &mut Context<Self>,
2041 ) {
2042 self.rated_predictions.insert(prediction.id.clone());
2043 telemetry::event!(
2044 "Edit Prediction Rated",
2045 rating,
2046 inputs = prediction.inputs,
2047 output = prediction.edit_preview.as_unified_diff(&prediction.edits),
2048 feedback
2049 );
2050 self.client.telemetry().flush_events().detach();
2051 cx.notify();
2052 }
2053
2054 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2055 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2056 && all_language_settings(None, cx).edit_predictions.use_context;
2057 }
2058}
2059
2060#[derive(Error, Debug)]
2061#[error(
2062 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2063)]
2064pub struct ZedUpdateRequiredError {
2065 minimum_version: Version,
2066}
2067
2068#[cfg(feature = "cli-support")]
2069pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2070
2071#[cfg(feature = "cli-support")]
2072#[derive(Debug, Clone, Copy, PartialEq)]
2073pub enum EvalCacheEntryKind {
2074 Context,
2075 Search,
2076 Prediction,
2077}
2078
2079#[cfg(feature = "cli-support")]
2080impl std::fmt::Display for EvalCacheEntryKind {
2081 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2082 match self {
2083 EvalCacheEntryKind::Search => write!(f, "search"),
2084 EvalCacheEntryKind::Context => write!(f, "context"),
2085 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
2086 }
2087 }
2088}
2089
2090#[cfg(feature = "cli-support")]
2091pub trait EvalCache: Send + Sync {
2092 fn read(&self, key: EvalCacheKey) -> Option<String>;
2093 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
2094}
2095
2096#[derive(Debug, Clone, Copy)]
2097pub enum DataCollectionChoice {
2098 NotAnswered,
2099 Enabled,
2100 Disabled,
2101}
2102
2103impl DataCollectionChoice {
2104 pub fn is_enabled(self) -> bool {
2105 match self {
2106 Self::Enabled => true,
2107 Self::NotAnswered | Self::Disabled => false,
2108 }
2109 }
2110
2111 pub fn is_answered(self) -> bool {
2112 match self {
2113 Self::Enabled | Self::Disabled => true,
2114 Self::NotAnswered => false,
2115 }
2116 }
2117
2118 #[must_use]
2119 pub fn toggle(&self) -> DataCollectionChoice {
2120 match self {
2121 Self::Enabled => Self::Disabled,
2122 Self::Disabled => Self::Enabled,
2123 Self::NotAnswered => Self::Enabled,
2124 }
2125 }
2126}
2127
2128impl From<bool> for DataCollectionChoice {
2129 fn from(value: bool) -> Self {
2130 match value {
2131 true => DataCollectionChoice::Enabled,
2132 false => DataCollectionChoice::Disabled,
2133 }
2134 }
2135}
2136
2137struct ZedPredictUpsell;
2138
2139impl Dismissable for ZedPredictUpsell {
2140 const KEY: &'static str = "dismissed-edit-predict-upsell";
2141
2142 fn dismissed() -> bool {
        // To stay backwards compatible with older versions of Zed, treat the
        // upsell as dismissed if the user already went through the previous
        // Edit Prediction onboarding, which wrote the data collection choice
        // to the database when they clicked "Accept and Enable".
2147 if KEY_VALUE_STORE
2148 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2149 .log_err()
2150 .is_some_and(|s| s.is_some())
2151 {
2152 return true;
2153 }
2154
2155 KEY_VALUE_STORE
2156 .read_kvp(Self::KEY)
2157 .log_err()
2158 .is_some_and(|s| s.is_some())
2159 }
2160}
2161
2162pub fn should_show_upsell_modal() -> bool {
2163 !ZedPredictUpsell::dismissed()
2164}
2165
2166pub fn init(cx: &mut App) {
2167 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2168 workspace.register_action(
2169 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2170 ZedPredictModal::toggle(
2171 workspace,
2172 workspace.user_store().clone(),
2173 workspace.client().clone(),
2174 window,
2175 cx,
2176 )
2177 },
2178 );
2179
2180 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2181 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2182 settings
2183 .project
2184 .all_languages
2185 .features
2186 .get_or_insert_default()
2187 .edit_prediction_provider = Some(EditPredictionProvider::None)
2188 });
2189 });
2190 })
2191 .detach();
2192}