1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::http_client::Url;
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::language_settings::all_language_settings;
29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
30use language::{BufferSnapshot, OffsetRangeExt};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use project::{Project, ProjectPath, WorktreeId};
33use release_channel::AppVersion;
34use semver::Version;
35use serde::de::DeserializeOwned;
36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
37use std::collections::{VecDeque, hash_map};
38use text::Edit;
39use workspace::Workspace;
40
41use std::ops::Range;
42use std::path::Path;
43use std::rc::Rc;
44use std::str::FromStr as _;
45use std::sync::{Arc, LazyLock};
46use std::time::{Duration, Instant};
47use std::{env, mem};
48use thiserror::Error;
49use util::{RangeExt as _, ResultExt as _};
50use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
51
52pub mod cursor_excerpt;
53pub mod example_spec;
54mod license_detection;
55pub mod mercury;
56mod onboarding_modal;
57pub mod open_ai_response;
58mod prediction;
59pub mod sweep_ai;
60
61pub mod udiff;
62
63mod capture_example;
64mod zed_edit_prediction_delegate;
65pub mod zeta1;
66pub mod zeta2;
67
68#[cfg(test)]
69mod edit_prediction_tests;
70
71use crate::license_detection::LicenseDetectionWatcher;
72use crate::mercury::Mercury;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76use crate::prediction::EditPredictionResult;
77pub use crate::sweep_ai::SweepAi;
78pub use capture_example::capture_example;
79pub use language_model::ApiKeyState;
80pub use telemetry_events::EditPredictionRating;
81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
82
83actions!(
84 edit_prediction,
85 [
86 /// Resets the edit prediction onboarding state.
87 ResetOnboarding,
88 /// Clears the edit prediction history.
89 ClearHistory,
90 ]
91);
92
93/// Maximum number of events to track.
94const EVENT_COUNT_MAX: usize = 6;
/// Group consecutive edits into a single event when they land within this many
/// lines of the previous edit in the same buffer.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// An editing pause at least this long marks a split point inside the event
/// currently being coalesced (see `LastEvent::split_by_pause`).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
97const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
98const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
99
100pub struct SweepFeatureFlag;
101
102impl FeatureFlag for SweepFeatureFlag {
103 const NAME: &str = "sweep-ai";
104}
105
106pub struct MercuryFeatureFlag;
107
108impl FeatureFlag for MercuryFeatureFlag {
109 const NAME: &str = "mercury";
110}
111
112pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
113 context: EditPredictionExcerptOptions {
114 max_bytes: 512,
115 min_bytes: 128,
116 target_before_cursor_over_total_bytes: 0.5,
117 },
118 prompt_format: PromptFormat::DEFAULT,
119};
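// The context options above bound the cursor excerpt sent with each request: at
// least `min_bytes`, at most `max_bytes`, with roughly
// `target_before_cursor_over_total_bytes` of the budget spent on text before the
// cursor (0.5 meaning an even split). A rough sketch of overriding the defaults,
// e.g. from a test or an eval harness; `ep_store` stands in for a handle to the
// `EditPredictionStore` entity and the numbers are illustrative, not
// recommendations:
//
//     ep_store.update(cx, |store, _cx| {
//         store.set_options(ZetaOptions {
//             context: EditPredictionExcerptOptions {
//                 max_bytes: 2048,
//                 min_bytes: 256,
//                 target_before_cursor_over_total_bytes: 0.5,
//             },
//             prompt_format: PromptFormat::DEFAULT,
//         });
//     });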
120
121static USE_OLLAMA: LazyLock<bool> =
122 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
123
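// Model selection is driven by environment variables: `ZED_ZETA2_MODEL` picks an
// explicit model id ("zeta2-exp" maps to the fine-tuned Baseten deployment),
// otherwise `ZED_ZETA2_OLLAMA` switches to a local `qwen3-coder:30b` served by
// Ollama, and the default is the vanilla qwen3-coder deployment on Baseten.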
124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
125 match env::var("ZED_ZETA2_MODEL").as_deref() {
126 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
127 Ok(model) => model,
128 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
129 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
130 }
131 .to_string()
132});
133
134pub struct Zeta2FeatureFlag;
135
136impl FeatureFlag for Zeta2FeatureFlag {
137 const NAME: &'static str = "zeta2";
138
139 fn enabled_for_staff() -> bool {
140 true
141 }
142}
143
144#[derive(Clone)]
145struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
146
147impl Global for EditPredictionStoreGlobal {}
148
149pub struct EditPredictionStore {
150 client: Arc<Client>,
151 user_store: Entity<UserStore>,
152 llm_token: LlmApiToken,
153 _llm_token_subscription: Subscription,
154 projects: HashMap<EntityId, ProjectState>,
155 use_context: bool,
156 options: ZetaOptions,
157 update_required: bool,
158 #[cfg(feature = "cli-support")]
159 eval_cache: Option<Arc<dyn EvalCache>>,
160 edit_prediction_model: EditPredictionModel,
161 pub sweep_ai: SweepAi,
162 pub mercury: Mercury,
163 data_collection_choice: DataCollectionChoice,
164 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
165 shown_predictions: VecDeque<EditPrediction>,
166 rated_predictions: HashSet<EditPredictionId>,
167 custom_predict_edits_url: Option<Arc<Url>>,
168}
169
170#[derive(Copy, Clone, Default, PartialEq, Eq)]
171pub enum EditPredictionModel {
172 #[default]
173 Zeta1,
174 Zeta2,
175 Sweep,
176 Mercury,
177}
178
179pub struct EditPredictionModelInput {
180 project: Entity<Project>,
181 buffer: Entity<Buffer>,
182 snapshot: BufferSnapshot,
183 position: Anchor,
184 events: Vec<Arc<zeta_prompt::Event>>,
185 related_files: Arc<[RelatedFile]>,
186 recent_paths: VecDeque<ProjectPath>,
187 trigger: PredictEditsRequestTrigger,
188 diagnostic_search_range: Range<Point>,
189 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
190}
191
192#[derive(Debug, Clone, PartialEq)]
193pub struct ZetaOptions {
194 pub context: EditPredictionExcerptOptions,
195 pub prompt_format: predict_edits_v3::PromptFormat,
196}
197
198#[derive(Debug)]
199pub enum DebugEvent {
200 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
201 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
202 EditPredictionStarted(EditPredictionStartedDebugEvent),
203 EditPredictionFinished(EditPredictionFinishedDebugEvent),
204}
205
206#[derive(Debug)]
207pub struct ContextRetrievalStartedDebugEvent {
208 pub project_entity_id: EntityId,
209 pub timestamp: Instant,
210 pub search_prompt: String,
211}
212
213#[derive(Debug)]
214pub struct ContextRetrievalFinishedDebugEvent {
215 pub project_entity_id: EntityId,
216 pub timestamp: Instant,
217 pub metadata: Vec<(&'static str, SharedString)>,
218}
219
220#[derive(Debug)]
221pub struct EditPredictionStartedDebugEvent {
222 pub buffer: WeakEntity<Buffer>,
223 pub position: Anchor,
224 pub prompt: Option<String>,
225}
226
227#[derive(Debug)]
228pub struct EditPredictionFinishedDebugEvent {
229 pub buffer: WeakEntity<Buffer>,
230 pub position: Anchor,
231 pub model_output: Option<String>,
232}
233
234pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
235
236/// An event with associated metadata for reconstructing buffer state.
237#[derive(Clone)]
238pub struct StoredEvent {
239 pub event: Arc<zeta_prompt::Event>,
240 pub old_snapshot: TextBufferSnapshot,
241}
242
243struct ProjectState {
244 events: VecDeque<StoredEvent>,
245 last_event: Option<LastEvent>,
246 recent_paths: VecDeque<ProjectPath>,
247 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
248 current_prediction: Option<CurrentEditPrediction>,
249 next_pending_prediction_id: usize,
250 pending_predictions: ArrayVec<PendingPrediction, 2>,
251 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
252 last_prediction_refresh: Option<(EntityId, Instant)>,
253 cancelled_predictions: HashSet<usize>,
254 context: Entity<RelatedExcerptStore>,
255 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
256 _subscription: gpui::Subscription,
257}
258
259impl ProjectState {
260 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
261 self.events
262 .iter()
263 .cloned()
264 .chain(
265 self.last_event
266 .as_ref()
267 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
268 )
269 .collect()
270 }
271
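    /// Like [`Self::events`], but if the in-progress [`LastEvent`] recorded a
    /// snapshot after an editing pause (see [`LastEvent::split_by_pause`]), the
    /// most recent burst of edits is emitted as up to two events: one covering
    /// the changes before the last pause and one covering the changes after it.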
272 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
273 self.events
274 .iter()
275 .cloned()
276 .chain(self.last_event.as_ref().iter().flat_map(|event| {
277 let (one, two) = event.split_by_pause();
278 let one = one.finalize(&self.license_detection_watchers, cx);
279 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
280 one.into_iter().chain(two)
281 }))
282 .collect()
283 }
284
285 fn cancel_pending_prediction(
286 &mut self,
287 pending_prediction: PendingPrediction,
288 cx: &mut Context<EditPredictionStore>,
289 ) {
290 self.cancelled_predictions.insert(pending_prediction.id);
291
292 cx.spawn(async move |this, cx| {
293 let Some(prediction_id) = pending_prediction.task.await else {
294 return;
295 };
296
297 this.update(cx, |this, _cx| {
298 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
299 })
300 .ok();
301 })
302 .detach()
303 }
304
305 fn active_buffer(
306 &self,
307 project: &Entity<Project>,
308 cx: &App,
309 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
310 let project = project.read(cx);
311 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
312 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
313 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
314 Some((active_buffer, registered_buffer.last_position))
315 }
316}
317
318#[derive(Debug, Clone)]
319struct CurrentEditPrediction {
320 pub requested_by: PredictionRequestedBy,
321 pub prediction: EditPrediction,
322 pub was_shown: bool,
323}
324
325impl CurrentEditPrediction {
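    /// Decides whether this freshly-computed prediction should replace
    /// `old_prediction` as the project's current prediction.
    ///
    /// If the new prediction no longer interpolates against its buffer, it never
    /// replaces the current one. It always wins when it targets a different
    /// buffer or when the old prediction no longer interpolates. When both are
    /// single edits requested from the same buffer, the old prediction is kept
    /// unless the new edit covers the same range and merely extends the old
    /// text, which avoids visually thrashing the displayed prediction; in all
    /// other cases the new prediction replaces the old one.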
326 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
327 let Some(new_edits) = self
328 .prediction
329 .interpolate(&self.prediction.buffer.read(cx))
330 else {
331 return false;
332 };
333
334 if self.prediction.buffer != old_prediction.prediction.buffer {
335 return true;
336 }
337
338 let Some(old_edits) = old_prediction
339 .prediction
340 .interpolate(&old_prediction.prediction.buffer.read(cx))
341 else {
342 return true;
343 };
344
345 let requested_by_buffer_id = self.requested_by.buffer_id();
346
347 // This reduces the occurrence of UI thrash from replacing edits
348 //
349 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
350 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
351 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
352 && old_edits.len() == 1
353 && new_edits.len() == 1
354 {
355 let (old_range, old_text) = &old_edits[0];
356 let (new_range, new_text) = &new_edits[0];
357 new_range == old_range && new_text.starts_with(old_text.as_ref())
358 } else {
359 true
360 }
361 }
362}
363
364#[derive(Debug, Clone)]
365enum PredictionRequestedBy {
366 DiagnosticsUpdate,
367 Buffer(EntityId),
368}
369
370impl PredictionRequestedBy {
371 pub fn buffer_id(&self) -> Option<EntityId> {
372 match self {
373 PredictionRequestedBy::DiagnosticsUpdate => None,
374 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
375 }
376 }
377}
378
379#[derive(Debug)]
380struct PendingPrediction {
381 id: usize,
382 task: Task<Option<EditPredictionId>>,
383}
384
385/// A prediction from the perspective of a buffer.
386#[derive(Debug)]
387enum BufferEditPrediction<'a> {
388 Local { prediction: &'a EditPrediction },
389 Jump { prediction: &'a EditPrediction },
390}
391
392#[cfg(test)]
393impl std::ops::Deref for BufferEditPrediction<'_> {
394 type Target = EditPrediction;
395
396 fn deref(&self) -> &Self::Target {
397 match self {
398 BufferEditPrediction::Local { prediction } => prediction,
399 BufferEditPrediction::Jump { prediction } => prediction,
400 }
401 }
402}
403
404struct RegisteredBuffer {
405 file: Option<Arc<dyn File>>,
406 snapshot: TextBufferSnapshot,
407 last_position: Option<Anchor>,
408 _subscriptions: [gpui::Subscription; 2],
409}
410
411#[derive(Clone)]
412struct LastEvent {
413 old_snapshot: TextBufferSnapshot,
414 new_snapshot: TextBufferSnapshot,
415 old_file: Option<Arc<dyn File>>,
416 new_file: Option<Arc<dyn File>>,
417 end_edit_anchor: Option<Anchor>,
418 snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
419 last_edit_time: Option<Instant>,
420}
421
422impl LastEvent {
423 pub fn finalize(
424 &self,
425 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
426 cx: &App,
427 ) -> Option<StoredEvent> {
428 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
429 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
430
431 let in_open_source_repo =
432 [self.new_file.as_ref(), self.old_file.as_ref()]
433 .iter()
434 .all(|file| {
435 file.is_some_and(|file| {
436 license_detection_watchers
437 .get(&file.worktree_id(cx))
438 .is_some_and(|watcher| watcher.is_project_open_source())
439 })
440 });
441
442 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
443
444 if path == old_path && diff.is_empty() {
445 None
446 } else {
447 Some(StoredEvent {
448 event: Arc::new(zeta_prompt::Event::BufferChange {
449 old_path,
450 path,
451 diff,
452 in_open_source_repo,
453 // TODO: Actually detect if this edit was predicted or not
454 predicted: false,
455 }),
456 old_snapshot: self.old_snapshot.clone(),
457 })
458 }
459 }
460
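    /// Splits this coalesced event at the snapshot captured after the last
    /// editing pause of at least [`LAST_CHANGE_GROUPING_TIME`], returning the
    /// changes made before the pause and, if a pause was recorded, a second
    /// event for the changes made after it. With no recorded pause, returns
    /// `(self, None)`.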
461 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
462 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
463 return (self.clone(), None);
464 };
465
466 let before = LastEvent {
467 old_snapshot: self.old_snapshot.clone(),
468 new_snapshot: boundary_snapshot.clone(),
469 old_file: self.old_file.clone(),
470 new_file: self.new_file.clone(),
471 end_edit_anchor: self.end_edit_anchor,
472 snapshot_after_last_editing_pause: None,
473 last_edit_time: self.last_edit_time,
474 };
475
476 let after = LastEvent {
477 old_snapshot: boundary_snapshot.clone(),
478 new_snapshot: self.new_snapshot.clone(),
479 old_file: self.old_file.clone(),
480 new_file: self.new_file.clone(),
481 end_edit_anchor: self.end_edit_anchor,
482 snapshot_after_last_editing_pause: None,
483 last_edit_time: self.last_edit_time,
484 };
485
486 (before, Some(after))
487 }
488}
489
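/// Computes a unified diff describing the edits made between `old_snapshot` and
/// `new_snapshot`, or `None` if no edits happened.
///
/// Rather than diffing the whole buffer, the edited region is expanded to full
/// lines plus `CONTEXT_LINES` lines of context on each side, and only that
/// window is passed to [`language::unified_diff_with_offsets`], together with
/// the starting row of each window so hunk headers still reflect buffer
/// coordinates.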
490pub(crate) fn compute_diff_between_snapshots(
491 old_snapshot: &TextBufferSnapshot,
492 new_snapshot: &TextBufferSnapshot,
493) -> Option<String> {
494 let edits: Vec<Edit<usize>> = new_snapshot
495 .edits_since::<usize>(&old_snapshot.version)
496 .collect();
497
498 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
499
500 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
501 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
502 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
503 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
504
505 const CONTEXT_LINES: u32 = 3;
506
507 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
508 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
509 let old_context_end_row =
510 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
511 let new_context_end_row =
512 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
513
514 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
515 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
516 let old_end_line_offset = old_snapshot
517 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
518 let new_end_line_offset = new_snapshot
519 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
520 let old_edit_range = old_start_line_offset..old_end_line_offset;
521 let new_edit_range = new_start_line_offset..new_end_line_offset;
522
523 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
524 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
525
526 let diff = language::unified_diff_with_offsets(
527 &old_region_text,
528 &new_region_text,
529 old_context_start_row,
530 new_context_start_row,
531 );
532
533 Some(diff)
534}
535
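/// Returns the buffer's full path, falling back to a synthetic
/// `untitled-{remote_id}` path for buffers that aren't backed by a file.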
536fn buffer_path_with_id_fallback(
537 file: Option<&Arc<dyn File>>,
538 snapshot: &TextBufferSnapshot,
539 cx: &App,
540) -> Arc<Path> {
541 if let Some(file) = file {
542 file.full_path(cx).into()
543 } else {
544 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
545 }
546}
547
548impl EditPredictionStore {
549 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
550 cx.try_global::<EditPredictionStoreGlobal>()
551 .map(|global| global.0.clone())
552 }
553
554 pub fn global(
555 client: &Arc<Client>,
556 user_store: &Entity<UserStore>,
557 cx: &mut App,
558 ) -> Entity<Self> {
559 cx.try_global::<EditPredictionStoreGlobal>()
560 .map(|global| global.0.clone())
561 .unwrap_or_else(|| {
562 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
563 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
564 ep_store
565 })
566 }
567
568 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
569 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
570 let data_collection_choice = Self::load_data_collection_choice();
571
572 let llm_token = LlmApiToken::default();
573
574 let (reject_tx, reject_rx) = mpsc::unbounded();
575 cx.background_spawn({
576 let client = client.clone();
577 let llm_token = llm_token.clone();
578 let app_version = AppVersion::global(cx);
579 let background_executor = cx.background_executor().clone();
580 async move {
581 Self::handle_rejected_predictions(
582 reject_rx,
583 client,
584 llm_token,
585 app_version,
586 background_executor,
587 )
588 .await
589 }
590 })
591 .detach();
592
593 let mut this = Self {
594 projects: HashMap::default(),
595 client,
596 user_store,
597 options: DEFAULT_OPTIONS,
598 use_context: false,
599 llm_token,
600 _llm_token_subscription: cx.subscribe(
601 &refresh_llm_token_listener,
602 |this, _listener, _event, cx| {
603 let client = this.client.clone();
604 let llm_token = this.llm_token.clone();
605 cx.spawn(async move |_this, _cx| {
606 llm_token.refresh(&client).await?;
607 anyhow::Ok(())
608 })
609 .detach_and_log_err(cx);
610 },
611 ),
612 update_required: false,
613 #[cfg(feature = "cli-support")]
614 eval_cache: None,
615 edit_prediction_model: EditPredictionModel::Zeta2,
616 sweep_ai: SweepAi::new(cx),
617 mercury: Mercury::new(cx),
618 data_collection_choice,
619 reject_predictions_tx: reject_tx,
620 rated_predictions: Default::default(),
621 shown_predictions: Default::default(),
622 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
623 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
624 Err(_) => {
625 if *USE_OLLAMA {
626 Some(
627 Url::parse("http://localhost:11434/v1/chat/completions")
628 .unwrap()
629 .into(),
630 )
631 } else {
632 None
633 }
634 }
635 },
636 };
637
638 this.configure_context_retrieval(cx);
639 let weak_this = cx.weak_entity();
640 cx.on_flags_ready(move |_, cx| {
641 weak_this
642 .update(cx, |this, cx| this.configure_context_retrieval(cx))
643 .ok();
644 })
645 .detach();
646 cx.observe_global::<SettingsStore>(|this, cx| {
647 this.configure_context_retrieval(cx);
648 })
649 .detach();
650
651 this
652 }
653
654 #[cfg(test)]
655 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
656 self.custom_predict_edits_url = Some(url.into());
657 }
658
659 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
660 self.edit_prediction_model = model;
661 }
662
663 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
664 self.sweep_ai.api_token.read(cx).has_key()
665 }
666
667 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
668 self.mercury.api_token.read(cx).has_key()
669 }
670
671 #[cfg(feature = "cli-support")]
672 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
673 self.eval_cache = Some(cache);
674 }
675
676 pub fn options(&self) -> &ZetaOptions {
677 &self.options
678 }
679
680 pub fn set_options(&mut self, options: ZetaOptions) {
681 self.options = options;
682 }
683
684 pub fn set_use_context(&mut self, use_context: bool) {
685 self.use_context = use_context;
686 }
687
688 pub fn clear_history(&mut self) {
689 for project_state in self.projects.values_mut() {
690 project_state.events.clear();
691 project_state.last_event.take();
692 }
693 }
694
695 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
696 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
697 project_state.events.clear();
698 project_state.last_event.take();
699 }
700 }
701
702 pub fn edit_history_for_project(
703 &self,
704 project: &Entity<Project>,
705 cx: &App,
706 ) -> Vec<StoredEvent> {
707 self.projects
708 .get(&project.entity_id())
709 .map(|project_state| project_state.events(cx))
710 .unwrap_or_default()
711 }
712
713 pub fn edit_history_for_project_with_pause_split_last_event(
714 &self,
715 project: &Entity<Project>,
716 cx: &App,
717 ) -> Vec<StoredEvent> {
718 self.projects
719 .get(&project.entity_id())
720 .map(|project_state| project_state.events_split_by_pause(cx))
721 .unwrap_or_default()
722 }
723
724 pub fn context_for_project<'a>(
725 &'a self,
726 project: &Entity<Project>,
727 cx: &'a App,
728 ) -> Arc<[RelatedFile]> {
729 self.projects
730 .get(&project.entity_id())
731 .map(|project| project.context.read(cx).related_files())
732 .unwrap_or_else(|| vec![].into())
733 }
734
735 pub fn context_for_project_with_buffers<'a>(
736 &'a self,
737 project: &Entity<Project>,
738 cx: &'a App,
739 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
740 self.projects
741 .get(&project.entity_id())
742 .map(|project| project.context.read(cx).related_files_with_buffers())
743 }
744
745 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
746 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
747 self.user_store.read(cx).edit_prediction_usage()
748 } else {
749 None
750 }
751 }
752
753 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
754 self.get_or_init_project(project, cx);
755 }
756
757 pub fn register_buffer(
758 &mut self,
759 buffer: &Entity<Buffer>,
760 project: &Entity<Project>,
761 cx: &mut Context<Self>,
762 ) {
763 let project_state = self.get_or_init_project(project, cx);
764 Self::register_buffer_impl(project_state, buffer, project, cx);
765 }
766
767 fn get_or_init_project(
768 &mut self,
769 project: &Entity<Project>,
770 cx: &mut Context<Self>,
771 ) -> &mut ProjectState {
772 let entity_id = project.entity_id();
773 self.projects
774 .entry(entity_id)
775 .or_insert_with(|| ProjectState {
776 context: {
777 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
778 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
779 this.handle_excerpt_store_event(entity_id, event);
780 })
781 .detach();
782 related_excerpt_store
783 },
784 events: VecDeque::new(),
785 last_event: None,
786 recent_paths: VecDeque::new(),
787 debug_tx: None,
788 registered_buffers: HashMap::default(),
789 current_prediction: None,
790 cancelled_predictions: HashSet::default(),
791 pending_predictions: ArrayVec::new(),
792 next_pending_prediction_id: 0,
793 last_prediction_refresh: None,
794 license_detection_watchers: HashMap::default(),
795 _subscription: cx.subscribe(&project, Self::handle_project_event),
796 })
797 }
798
799 pub fn remove_project(&mut self, project: &Entity<Project>) {
800 self.projects.remove(&project.entity_id());
801 }
802
803 fn handle_excerpt_store_event(
804 &mut self,
805 project_entity_id: EntityId,
806 event: &RelatedExcerptStoreEvent,
807 ) {
808 if let Some(project_state) = self.projects.get(&project_entity_id) {
809 if let Some(debug_tx) = project_state.debug_tx.clone() {
810 match event {
811 RelatedExcerptStoreEvent::StartedRefresh => {
812 debug_tx
813 .unbounded_send(DebugEvent::ContextRetrievalStarted(
814 ContextRetrievalStartedDebugEvent {
 project_entity_id,

816 timestamp: Instant::now(),
817 search_prompt: String::new(),
818 },
819 ))
820 .ok();
821 }
822 RelatedExcerptStoreEvent::FinishedRefresh {
823 cache_hit_count,
824 cache_miss_count,
825 mean_definition_latency,
826 max_definition_latency,
827 } => {
828 debug_tx
829 .unbounded_send(DebugEvent::ContextRetrievalFinished(
830 ContextRetrievalFinishedDebugEvent {
 project_entity_id,
832 timestamp: Instant::now(),
833 metadata: vec![
834 (
835 "Cache Hits",
836 format!(
837 "{}/{}",
838 cache_hit_count,
839 cache_hit_count + cache_miss_count
840 )
841 .into(),
842 ),
843 (
844 "Max LSP Time",
845 format!("{} ms", max_definition_latency.as_millis())
846 .into(),
847 ),
848 (
849 "Mean LSP Time",
850 format!("{} ms", mean_definition_latency.as_millis())
851 .into(),
852 ),
853 ],
854 },
855 ))
856 .ok();
857 }
858 }
859 }
860 }
861 }
862
863 pub fn debug_info(
864 &mut self,
865 project: &Entity<Project>,
866 cx: &mut Context<Self>,
867 ) -> mpsc::UnboundedReceiver<DebugEvent> {
868 let project_state = self.get_or_init_project(project, cx);
869 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
870 project_state.debug_tx = Some(debug_watch_tx);
871 debug_watch_rx
872 }
873
874 fn handle_project_event(
875 &mut self,
876 project: Entity<Project>,
877 event: &project::Event,
878 cx: &mut Context<Self>,
879 ) {
880 // TODO [zeta2] init with recent paths
881 match event {
882 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
883 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
884 return;
885 };
886 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
887 if let Some(path) = path {
888 if let Some(ix) = project_state
889 .recent_paths
890 .iter()
891 .position(|probe| probe == &path)
892 {
893 project_state.recent_paths.remove(ix);
894 }
895 project_state.recent_paths.push_front(path);
896 }
897 }
898 project::Event::DiagnosticsUpdated { .. } => {
899 if cx.has_flag::<Zeta2FeatureFlag>() {
900 self.refresh_prediction_from_diagnostics(project, cx);
901 }
902 }
903 _ => (),
904 }
905 }
906
907 fn register_buffer_impl<'a>(
908 project_state: &'a mut ProjectState,
909 buffer: &Entity<Buffer>,
910 project: &Entity<Project>,
911 cx: &mut Context<Self>,
912 ) -> &'a mut RegisteredBuffer {
913 let buffer_id = buffer.entity_id();
914
915 if let Some(file) = buffer.read(cx).file() {
916 let worktree_id = file.worktree_id(cx);
917 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
918 project_state
919 .license_detection_watchers
920 .entry(worktree_id)
921 .or_insert_with(|| {
922 let project_entity_id = project.entity_id();
923 cx.observe_release(&worktree, move |this, _worktree, _cx| {
924 let Some(project_state) = this.projects.get_mut(&project_entity_id)
925 else {
926 return;
927 };
928 project_state
929 .license_detection_watchers
930 .remove(&worktree_id);
931 })
932 .detach();
933 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
934 });
935 }
936 }
937
938 match project_state.registered_buffers.entry(buffer_id) {
939 hash_map::Entry::Occupied(entry) => entry.into_mut(),
940 hash_map::Entry::Vacant(entry) => {
941 let buf = buffer.read(cx);
942 let snapshot = buf.text_snapshot();
943 let file = buf.file().cloned();
944 let project_entity_id = project.entity_id();
945 entry.insert(RegisteredBuffer {
946 snapshot,
947 file,
948 last_position: None,
949 _subscriptions: [
950 cx.subscribe(buffer, {
951 let project = project.downgrade();
952 move |this, buffer, event, cx| {
953 if let language::BufferEvent::Edited = event
954 && let Some(project) = project.upgrade()
955 {
956 this.report_changes_for_buffer(&buffer, &project, cx);
957 }
958 }
959 }),
960 cx.observe_release(buffer, move |this, _buffer, _cx| {
961 let Some(project_state) = this.projects.get_mut(&project_entity_id)
962 else {
963 return;
964 };
965 project_state.registered_buffers.remove(&buffer_id);
966 }),
967 ],
968 })
969 }
970 }
971 }
972
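    /// Records an edit to `buffer` in the project's event history.
    ///
    /// Edits are coalesced into the in-progress [`LastEvent`] as long as they
    /// stay in the same buffer and land within [`CHANGE_GROUPING_LINE_SPAN`]
    /// lines of the previous edit; if at least [`LAST_CHANGE_GROUPING_TIME`]
    /// passed since the previous edit, the pre-edit snapshot is remembered as a
    /// potential split point (see [`LastEvent::split_by_pause`]). Otherwise the
    /// previous event is finalized and appended to the history, which, together
    /// with the in-progress event, is capped at [`EVENT_COUNT_MAX`] entries.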
973 fn report_changes_for_buffer(
974 &mut self,
975 buffer: &Entity<Buffer>,
976 project: &Entity<Project>,
977 cx: &mut Context<Self>,
978 ) {
979 let project_state = self.get_or_init_project(project, cx);
980 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
981
982 let buf = buffer.read(cx);
983 let new_file = buf.file().cloned();
984 let new_snapshot = buf.text_snapshot();
985 if new_snapshot.version == registered_buffer.snapshot.version {
986 return;
987 }
988
989 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
990 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
991 let end_edit_anchor = new_snapshot
992 .anchored_edits_since::<Point>(&old_snapshot.version)
993 .last()
994 .map(|(_, range)| range.end);
995 let events = &mut project_state.events;
996
997 let now = cx.background_executor().now();
998 if let Some(last_event) = project_state.last_event.as_mut() {
999 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1000 == last_event.new_snapshot.remote_id()
1001 && old_snapshot.version == last_event.new_snapshot.version;
1002
1003 let should_coalesce = is_next_snapshot_of_same_buffer
1004 && end_edit_anchor
1005 .as_ref()
1006 .zip(last_event.end_edit_anchor.as_ref())
1007 .is_some_and(|(a, b)| {
1008 let a = a.to_point(&new_snapshot);
1009 let b = b.to_point(&new_snapshot);
1010 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
1011 });
1012
1013 if should_coalesce {
1014 let pause_elapsed = last_event
1015 .last_edit_time
1016 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1017 .unwrap_or(false);
1018 if pause_elapsed {
1019 last_event.snapshot_after_last_editing_pause =
1020 Some(last_event.new_snapshot.clone());
1021 }
1022
1023 last_event.end_edit_anchor = end_edit_anchor;
1024 last_event.new_snapshot = new_snapshot;
1025 last_event.last_edit_time = Some(now);
1026 return;
1027 }
1028 }
1029
1030 if events.len() + 1 >= EVENT_COUNT_MAX {
1031 events.pop_front();
1032 }
1033
1034 if let Some(event) = project_state.last_event.take() {
1035 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
1036 }
1037
1038 project_state.last_event = Some(LastEvent {
1039 old_file,
1040 new_file,
1041 old_snapshot,
1042 new_snapshot,
1043 end_edit_anchor,
1044 snapshot_after_last_editing_pause: None,
1045 last_edit_time: Some(now),
1046 });
1047 }
1048
1049 fn prediction_at(
1050 &mut self,
1051 buffer: &Entity<Buffer>,
1052 position: Option<language::Anchor>,
1053 project: &Entity<Project>,
1054 cx: &App,
1055 ) -> Option<BufferEditPrediction<'_>> {
1056 let project_state = self.projects.get_mut(&project.entity_id())?;
1057 if let Some(position) = position
1058 && let Some(buffer) = project_state
1059 .registered_buffers
1060 .get_mut(&buffer.entity_id())
1061 {
1062 buffer.last_position = Some(position);
1063 }
1064
1065 let CurrentEditPrediction {
1066 requested_by,
1067 prediction,
1068 ..
1069 } = project_state.current_prediction.as_ref()?;
1070
1071 if prediction.targets_buffer(buffer.read(cx)) {
1072 Some(BufferEditPrediction::Local { prediction })
1073 } else {
1074 let show_jump = match requested_by {
1075 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1076 requested_by_buffer_id == &buffer.entity_id()
1077 }
1078 PredictionRequestedBy::DiagnosticsUpdate => true,
1079 };
1080
1081 if show_jump {
1082 Some(BufferEditPrediction::Jump { prediction })
1083 } else {
1084 None
1085 }
1086 }
1087 }
1088
1089 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1090 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1091 match self.edit_prediction_model {
1092 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1093 if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1094 return;
1095 }
1096 }
1097 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1098 }
1099
1100 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1101 return;
1102 };
1103
1104 let Some(prediction) = project_state.current_prediction.take() else {
1105 return;
1106 };
1107 let request_id = prediction.prediction.id.to_string();
1108 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1109 project_state.cancel_pending_prediction(pending_prediction, cx);
1110 }
1111
1112 let client = self.client.clone();
1113 let llm_token = self.llm_token.clone();
1114 let app_version = AppVersion::global(cx);
1115 cx.spawn(async move |this, cx| {
1116 let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1117 (http_client::Url::parse(&accept_edits_url)?, false)
1118 } else {
1119 (
1120 client
1121 .http_client()
1122 .build_zed_llm_url("/predict_edits/accept", &[])?,
1123 true,
1124 )
1125 };
1126
1127 let response = cx
1128 .background_spawn(Self::send_api_request::<()>(
1129 move |builder| {
1130 let req = builder.uri(url.as_ref()).body(
1131 serde_json::to_string(&AcceptEditPredictionBody {
1132 request_id: request_id.clone(),
1133 })?
1134 .into(),
1135 );
1136 Ok(req?)
1137 },
1138 client,
1139 llm_token,
1140 app_version,
1141 require_auth,
1142 ))
1143 .await;
1144
1145 Self::handle_api_response(&this, response, cx)?;
1146 anyhow::Ok(())
1147 })
1148 .detach_and_log_err(cx);
1149 }
1150
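    /// Background loop that reports rejected predictions to the server.
    ///
    /// Rejections arrive over the channel fed by [`Self::reject_prediction`] and
    /// are batched: a flush happens once the batch reaches half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, when the channel stays quiet
    /// for [`REJECT_REQUEST_DEBOUNCE`], or when the sender side closes. Each
    /// flush POSTs at most the per-request cap of the most recent entries to
    /// `/predict_edits/reject`; entries are only dropped from the local batch on
    /// success, so a failed flush is retried with the next one.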
1151 async fn handle_rejected_predictions(
1152 rx: UnboundedReceiver<EditPredictionRejection>,
1153 client: Arc<Client>,
1154 llm_token: LlmApiToken,
1155 app_version: Version,
1156 background_executor: BackgroundExecutor,
1157 ) {
1158 let mut rx = std::pin::pin!(rx.peekable());
1159 let mut batched = Vec::new();
1160
1161 while let Some(rejection) = rx.next().await {
1162 batched.push(rejection);
1163
1164 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1165 select_biased! {
1166 next = rx.as_mut().peek().fuse() => {
1167 if next.is_some() {
1168 continue;
1169 }
1170 }
1171 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1172 }
1173 }
1174
1175 let url = client
1176 .http_client()
1177 .build_zed_llm_url("/predict_edits/reject", &[])
1178 .unwrap();
1179
1180 let flush_count = batched
1181 .len()
1182 // in case items have accumulated after failure
1183 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1184 let start = batched.len() - flush_count;
1185
1186 let body = RejectEditPredictionsBodyRef {
1187 rejections: &batched[start..],
1188 };
1189
1190 let result = Self::send_api_request::<()>(
1191 |builder| {
1192 let req = builder
1193 .uri(url.as_ref())
1194 .body(serde_json::to_string(&body)?.into());
1195 anyhow::Ok(req?)
1196 },
1197 client.clone(),
1198 llm_token.clone(),
1199 app_version.clone(),
1200 true,
1201 )
1202 .await;
1203
1204 if result.log_err().is_some() {
1205 batched.drain(start..);
1206 }
1207 }
1208 }
1209
1210 fn reject_current_prediction(
1211 &mut self,
1212 reason: EditPredictionRejectReason,
1213 project: &Entity<Project>,
1214 ) {
1215 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1216 project_state.pending_predictions.clear();
1217 if let Some(prediction) = project_state.current_prediction.take() {
1218 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1219 }
1220 };
1221 }
1222
1223 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1224 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1225 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1226 if !current_prediction.was_shown {
1227 current_prediction.was_shown = true;
1228 self.shown_predictions
1229 .push_front(current_prediction.prediction.clone());
1230 if self.shown_predictions.len() > 50 {
1231 let completion = self.shown_predictions.pop_back().unwrap();
1232 self.rated_predictions.remove(&completion.id);
1233 }
1234 }
1235 }
1236 }
1237 }
1238
1239 fn reject_prediction(
1240 &mut self,
1241 prediction_id: EditPredictionId,
1242 reason: EditPredictionRejectReason,
1243 was_shown: bool,
1244 ) {
1245 match self.edit_prediction_model {
1246 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1247 if self.custom_predict_edits_url.is_some() {
1248 return;
1249 }
1250 }
1251 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1252 }
1253
1254 self.reject_predictions_tx
1255 .unbounded_send(EditPredictionRejection {
1256 request_id: prediction_id.to_string(),
1257 reason,
1258 was_shown,
1259 })
1260 .log_err();
1261 }
1262
1263 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1264 self.projects
1265 .get(&project.entity_id())
1266 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1267 }
1268
1269 pub fn refresh_prediction_from_buffer(
1270 &mut self,
1271 project: Entity<Project>,
1272 buffer: Entity<Buffer>,
1273 position: language::Anchor,
1274 cx: &mut Context<Self>,
1275 ) {
1276 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1277 let Some(request_task) = this
1278 .update(cx, |this, cx| {
1279 this.request_prediction(
1280 &project,
1281 &buffer,
1282 position,
1283 PredictEditsRequestTrigger::Other,
1284 cx,
1285 )
1286 })
1287 .log_err()
1288 else {
1289 return Task::ready(anyhow::Ok(None));
1290 };
1291
1292 cx.spawn(async move |_cx| {
1293 request_task.await.map(|prediction_result| {
1294 prediction_result.map(|prediction_result| {
1295 (
1296 prediction_result,
1297 PredictionRequestedBy::Buffer(buffer.entity_id()),
1298 )
1299 })
1300 })
1301 })
1302 })
1303 }
1304
1305 pub fn refresh_prediction_from_diagnostics(
1306 &mut self,
1307 project: Entity<Project>,
1308 cx: &mut Context<Self>,
1309 ) {
1310 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1311 return;
1312 };
1313
1314 // Prefer predictions from buffer
1315 if project_state.current_prediction.is_some() {
1316 return;
1317 };
1318
1319 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1320 let Some((active_buffer, snapshot, cursor_point)) = this
1321 .read_with(cx, |this, cx| {
1322 let project_state = this.projects.get(&project.entity_id())?;
1323 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1324 let snapshot = buffer.read(cx).snapshot();
1325
1326 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1327 return None;
1328 }
1329
1330 let cursor_point = position
1331 .map(|pos| pos.to_point(&snapshot))
1332 .unwrap_or_default();
1333
1334 Some((buffer, snapshot, cursor_point))
1335 })
1336 .log_err()
1337 .flatten()
1338 else {
1339 return Task::ready(anyhow::Ok(None));
1340 };
1341
1342 cx.spawn(async move |cx| {
1343 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1344 active_buffer,
1345 &snapshot,
1346 Default::default(),
1347 cursor_point,
1348 &project,
1349 cx,
1350 )
1351 .await?
1352 else {
1353 return anyhow::Ok(None);
1354 };
1355
1356 let Some(prediction_result) = this
1357 .update(cx, |this, cx| {
1358 this.request_prediction(
1359 &project,
1360 &jump_buffer,
1361 jump_position,
1362 PredictEditsRequestTrigger::Diagnostics,
1363 cx,
1364 )
1365 })?
1366 .await?
1367 else {
1368 return anyhow::Ok(None);
1369 };
1370
1371 this.update(cx, |this, cx| {
1372 Some((
1373 if this
1374 .get_or_init_project(&project, cx)
1375 .current_prediction
1376 .is_none()
1377 {
1378 prediction_result
1379 } else {
1380 EditPredictionResult {
1381 id: prediction_result.id,
1382 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1383 }
1384 },
1385 PredictionRequestedBy::DiagnosticsUpdate,
1386 ))
1387 })
1388 })
1389 });
1390 }
1391
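    /// Checks whether edit predictions are allowed at `position` in this buffer:
    /// predictions must be enabled for the buffer's language and file by the
    /// language settings, and the language scope at the cursor (for example a
    /// string or comment scope) must not be listed in
    /// `edit_predictions_disabled_in`.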
1392 fn predictions_enabled_at(
1393 snapshot: &BufferSnapshot,
1394 position: Option<language::Anchor>,
1395 cx: &App,
1396 ) -> bool {
1397 let file = snapshot.file();
1398 let all_settings = all_language_settings(file, cx);
1399 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1400 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1401 {
1402 return false;
1403 }
1404
1405 if let Some(last_position) = position {
1406 let settings = snapshot.settings_at(last_position, cx);
1407
1408 if !settings.edit_predictions_disabled_in.is_empty()
1409 && let Some(scope) = snapshot.language_scope_at(last_position)
1410 && let Some(scope_name) = scope.override_name()
1411 && settings
1412 .edit_predictions_disabled_in
1413 .iter()
1414 .any(|s| s == scope_name)
1415 {
1416 return false;
1417 }
1418 }
1419
1420 true
1421 }
1422
1423 #[cfg(not(test))]
1424 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1425 #[cfg(test)]
1426 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1427
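    /// Enqueues a prediction refresh for `project`.
    ///
    /// Consecutive refreshes for the same `throttle_entity` are spaced out by
    /// [`Self::THROTTLE_TIMEOUT`]. At most two refreshes can be pending at once:
    /// when a third is queued, the newest pending one is cancelled and replaced.
    /// When a refresh finishes, it is removed from the pending list, any refresh
    /// queued before it is cancelled, and its result either becomes the current
    /// prediction or is reported as rejected, depending on
    /// [`CurrentEditPrediction::should_replace_prediction`].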
1428 fn queue_prediction_refresh(
1429 &mut self,
1430 project: Entity<Project>,
1431 throttle_entity: EntityId,
1432 cx: &mut Context<Self>,
1433 do_refresh: impl FnOnce(
1434 WeakEntity<Self>,
1435 &mut AsyncApp,
1436 )
1437 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1438 + 'static,
1439 ) {
1440 let project_state = self.get_or_init_project(&project, cx);
1441 let pending_prediction_id = project_state.next_pending_prediction_id;
1442 project_state.next_pending_prediction_id += 1;
1443 let last_request = project_state.last_prediction_refresh;
1444
1445 let task = cx.spawn(async move |this, cx| {
1446 if let Some((last_entity, last_timestamp)) = last_request
1447 && throttle_entity == last_entity
1448 && let Some(timeout) =
1449 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1450 {
1451 cx.background_executor().timer(timeout).await;
1452 }
1453
1454 // If this task was cancelled before the throttle timeout expired,
1455 // do not perform a request.
1456 let mut is_cancelled = true;
1457 this.update(cx, |this, cx| {
1458 let project_state = this.get_or_init_project(&project, cx);
1459 if !project_state
1460 .cancelled_predictions
1461 .remove(&pending_prediction_id)
1462 {
1463 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1464 is_cancelled = false;
1465 }
1466 })
1467 .ok();
1468 if is_cancelled {
1469 return None;
1470 }
1471
1472 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1473 let new_prediction_id = new_prediction_result
1474 .as_ref()
1475 .map(|(prediction, _)| prediction.id.clone());
1476
1477 // When a prediction completes, remove it from the pending list, and cancel
1478 // any pending predictions that were enqueued before it.
1479 this.update(cx, |this, cx| {
1480 let project_state = this.get_or_init_project(&project, cx);
1481
1482 let is_cancelled = project_state
1483 .cancelled_predictions
1484 .remove(&pending_prediction_id);
1485
1486 let new_current_prediction = if !is_cancelled
1487 && let Some((prediction_result, requested_by)) = new_prediction_result
1488 {
1489 match prediction_result.prediction {
1490 Ok(prediction) => {
1491 let new_prediction = CurrentEditPrediction {
1492 requested_by,
1493 prediction,
1494 was_shown: false,
1495 };
1496
1497 if let Some(current_prediction) =
1498 project_state.current_prediction.as_ref()
1499 {
 if new_prediction.should_replace_prediction(&current_prediction, cx)
1501 {
1502 this.reject_current_prediction(
1503 EditPredictionRejectReason::Replaced,
1504 &project,
1505 );
1506
1507 Some(new_prediction)
1508 } else {
1509 this.reject_prediction(
1510 new_prediction.prediction.id,
1511 EditPredictionRejectReason::CurrentPreferred,
1512 false,
1513 );
1514 None
1515 }
1516 } else {
1517 Some(new_prediction)
1518 }
1519 }
1520 Err(reject_reason) => {
1521 this.reject_prediction(prediction_result.id, reject_reason, false);
1522 None
1523 }
1524 }
1525 } else {
1526 None
1527 };
1528
1529 let project_state = this.get_or_init_project(&project, cx);
1530
1531 if let Some(new_prediction) = new_current_prediction {
1532 project_state.current_prediction = Some(new_prediction);
1533 }
1534
1535 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1536 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1537 if pending_prediction.id == pending_prediction_id {
1538 pending_predictions.remove(ix);
1539 for pending_prediction in pending_predictions.drain(0..ix) {
1540 project_state.cancel_pending_prediction(pending_prediction, cx)
1541 }
1542 break;
1543 }
1544 }
1545 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1546 cx.notify();
1547 })
1548 .ok();
1549
1550 new_prediction_id
1551 });
1552
1553 if project_state.pending_predictions.len() <= 1 {
1554 project_state.pending_predictions.push(PendingPrediction {
1555 id: pending_prediction_id,
1556 task,
1557 });
1558 } else if project_state.pending_predictions.len() == 2 {
1559 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1560 project_state.pending_predictions.push(PendingPrediction {
1561 id: pending_prediction_id,
1562 task,
1563 });
1564 project_state.cancel_pending_prediction(pending_prediction, cx);
1565 }
1566 }
1567
1568 pub fn request_prediction(
1569 &mut self,
1570 project: &Entity<Project>,
1571 active_buffer: &Entity<Buffer>,
1572 position: language::Anchor,
1573 trigger: PredictEditsRequestTrigger,
1574 cx: &mut Context<Self>,
1575 ) -> Task<Result<Option<EditPredictionResult>>> {
1576 self.request_prediction_internal(
1577 project.clone(),
1578 active_buffer.clone(),
1579 position,
1580 trigger,
1581 cx.has_flag::<Zeta2FeatureFlag>(),
1582 cx,
1583 )
1584 }
1585
1586 fn request_prediction_internal(
1587 &mut self,
1588 project: Entity<Project>,
1589 active_buffer: Entity<Buffer>,
1590 position: language::Anchor,
1591 trigger: PredictEditsRequestTrigger,
1592 allow_jump: bool,
1593 cx: &mut Context<Self>,
1594 ) -> Task<Result<Option<EditPredictionResult>>> {
1595 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1596
1597 self.get_or_init_project(&project, cx);
1598 let project_state = self.projects.get(&project.entity_id()).unwrap();
1599 let stored_events = project_state.events(cx);
1600 let has_events = !stored_events.is_empty();
1601 let events: Vec<Arc<zeta_prompt::Event>> =
1602 stored_events.into_iter().map(|e| e.event).collect();
1603 let debug_tx = project_state.debug_tx.clone();
1604
1605 let snapshot = active_buffer.read(cx).snapshot();
1606 let cursor_point = position.to_point(&snapshot);
1607 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1608 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1609 let diagnostic_search_range =
1610 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1611
1612 let related_files = if self.use_context {
1613 self.context_for_project(&project, cx)
1614 } else {
1615 Vec::new().into()
1616 };
1617
1618 let inputs = EditPredictionModelInput {
1619 project: project.clone(),
1620 buffer: active_buffer.clone(),
1621 snapshot: snapshot.clone(),
1622 position,
1623 events,
1624 related_files,
1625 recent_paths: project_state.recent_paths.clone(),
1626 trigger,
1627 diagnostic_search_range: diagnostic_search_range.clone(),
1628 debug_tx,
1629 };
1630
1631 let task = match self.edit_prediction_model {
1632 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1633 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1634 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1635 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1636 };
1637
1638 cx.spawn(async move |this, cx| {
1639 let prediction = task.await?;
1640
1641 if prediction.is_none() && allow_jump {
1642 let cursor_point = position.to_point(&snapshot);
1643 if has_events
1644 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1645 active_buffer.clone(),
1646 &snapshot,
1647 diagnostic_search_range,
1648 cursor_point,
1649 &project,
1650 cx,
1651 )
1652 .await?
1653 {
1654 return this
1655 .update(cx, |this, cx| {
1656 this.request_prediction_internal(
1657 project,
1658 jump_buffer,
1659 jump_position,
1660 trigger,
1661 false,
1662 cx,
1663 )
1664 })?
1665 .await;
1666 }
1667
1668 return anyhow::Ok(None);
1669 }
1670
1671 Ok(prediction)
1672 })
1673 }
1674
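    /// Finds a diagnostic for the prediction to "jump" to when nothing was
    /// predicted near the cursor.
    ///
    /// First looks for the diagnostic group in the active buffer that is closest
    /// to the cursor but outside `active_buffer_diagnostic_search_range` (the
    /// area already covered by the previous request). If there is none, it picks
    /// the other project buffer with diagnostics whose path shares the longest
    /// prefix of components with the active buffer's path, opens it, and targets
    /// its most severe diagnostic.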
1675 async fn next_diagnostic_location(
1676 active_buffer: Entity<Buffer>,
1677 active_buffer_snapshot: &BufferSnapshot,
1678 active_buffer_diagnostic_search_range: Range<Point>,
1679 active_buffer_cursor_point: Point,
1680 project: &Entity<Project>,
1681 cx: &mut AsyncApp,
1682 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1683 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1684 let mut jump_location = active_buffer_snapshot
1685 .diagnostic_groups(None)
1686 .into_iter()
1687 .filter_map(|(_, group)| {
1688 let range = &group.entries[group.primary_ix]
1689 .range
1690 .to_point(&active_buffer_snapshot);
1691 if range.overlaps(&active_buffer_diagnostic_search_range) {
1692 None
1693 } else {
1694 Some(range.start)
1695 }
1696 })
1697 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1698 .map(|position| {
1699 (
1700 active_buffer.clone(),
1701 active_buffer_snapshot.anchor_before(position),
1702 )
1703 });
1704
1705 if jump_location.is_none() {
1706 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1707 let file = buffer.file()?;
1708
1709 Some(ProjectPath {
1710 worktree_id: file.worktree_id(cx),
1711 path: file.path().clone(),
1712 })
1713 })?;
1714
1715 let buffer_task = project.update(cx, |project, cx| {
1716 let (path, _, _) = project
1717 .diagnostic_summaries(false, cx)
1718 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1719 .max_by_key(|(path, _, _)| {
1720 // find the buffer with errors that shares most parent directories
1721 path.path
1722 .components()
1723 .zip(
1724 active_buffer_path
1725 .as_ref()
1726 .map(|p| p.path.components())
1727 .unwrap_or_default(),
1728 )
1729 .take_while(|(a, b)| a == b)
1730 .count()
1731 })?;
1732
1733 Some(project.open_buffer(path, cx))
1734 })?;
1735
1736 if let Some(buffer_task) = buffer_task {
1737 let closest_buffer = buffer_task.await?;
1738
1739 jump_location = closest_buffer
1740 .read_with(cx, |buffer, _cx| {
1741 buffer
1742 .buffer_diagnostics(None)
1743 .into_iter()
1744 .min_by_key(|entry| entry.diagnostic.severity)
1745 .map(|entry| entry.range.start)
1746 })?
1747 .map(|position| (closest_buffer, position));
1748 }
1749 }
1750
1751 anyhow::Ok(jump_location)
1752 }
1753
1754 async fn send_raw_llm_request(
1755 request: open_ai::Request,
1756 client: Arc<Client>,
1757 llm_token: LlmApiToken,
1758 app_version: Version,
1759 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1760 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1761 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1762 let url = client
1763 .http_client()
1764 .build_zed_llm_url("/predict_edits/raw", &[])?;
1765
1766 #[cfg(feature = "cli-support")]
1767 let cache_key = if let Some(cache) = eval_cache {
1768 use collections::FxHasher;
1769 use std::hash::{Hash, Hasher};
1770
1771 let mut hasher = FxHasher::default();
1772 url.hash(&mut hasher);
1773 let request_str = serde_json::to_string_pretty(&request)?;
1774 request_str.hash(&mut hasher);
1775 let hash = hasher.finish();
1776
1777 let key = (eval_cache_kind, hash);
1778 if let Some(response_str) = cache.read(key) {
1779 return Ok((serde_json::from_str(&response_str)?, None));
1780 }
1781
1782 Some((cache, request_str, key))
1783 } else {
1784 None
1785 };
1786
1787 let (response, usage) = Self::send_api_request(
1788 |builder| {
1789 let req = builder
1790 .uri(url.as_ref())
1791 .body(serde_json::to_string(&request)?.into());
1792 Ok(req?)
1793 },
1794 client,
1795 llm_token,
1796 app_version,
1797 true,
1798 )
1799 .await?;
1800
1801 #[cfg(feature = "cli-support")]
1802 if let Some((cache, request, key)) = cache_key {
1803 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1804 }
1805
1806 Ok((response, usage))
1807 }
1808
1809 fn handle_api_response<T>(
1810 this: &WeakEntity<Self>,
1811 response: Result<(T, Option<EditPredictionUsage>)>,
1812 cx: &mut gpui::AsyncApp,
1813 ) -> Result<T> {
1814 match response {
1815 Ok((data, usage)) => {
1816 if let Some(usage) = usage {
1817 this.update(cx, |this, cx| {
1818 this.user_store.update(cx, |user_store, cx| {
1819 user_store.update_edit_prediction_usage(usage, cx);
1820 });
1821 })
1822 .ok();
1823 }
1824 Ok(data)
1825 }
1826 Err(err) => {
1827 if err.is::<ZedUpdateRequiredError>() {
1828 cx.update(|cx| {
1829 this.update(cx, |this, _cx| {
1830 this.update_required = true;
1831 })
1832 .ok();
1833
1834 let error_message: SharedString = err.to_string().into();
1835 show_app_notification(
1836 NotificationId::unique::<ZedUpdateRequiredError>(),
1837 cx,
1838 move |cx| {
1839 cx.new(|cx| {
1840 ErrorMessagePrompt::new(error_message.clone(), cx)
1841 .with_link_button("Update Zed", "https://zed.dev/releases")
1842 })
1843 },
1844 );
1845 })
1846 .ok();
1847 }
1848 Err(err)
1849 }
1850 }
1851 }
1852
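    /// Sends a POST request to the LLM backend and deserializes the JSON
    /// response.
    ///
    /// The request carries the Zed version header and, when available, a bearer
    /// LLM token (mandatory if `require_auth` is set). If the server reports a
    /// [`MINIMUM_REQUIRED_VERSION_HEADER_NAME`] above the running version, the
    /// call fails with [`ZedUpdateRequiredError`]. A response flagged with
    /// [`EXPIRED_LLM_TOKEN_HEADER_NAME`] triggers a single token refresh and
    /// retry. On success, usage information is read from the response headers.
    ///
    /// A rough usage sketch; `MyResponse` is a stand-in for any
    /// `DeserializeOwned` type expected from the endpoint, and `url`,
    /// `body_json`, `client`, `llm_token`, and `app_version` are placeholders
    /// for values the caller already has:
    ///
    /// ```ignore
    /// let (parsed, usage) = Self::send_api_request::<MyResponse>(
    ///     |builder| Ok(builder.uri(url.as_ref()).body(body_json.into())?),
    ///     client.clone(),
    ///     llm_token.clone(),
    ///     app_version.clone(),
    ///     /* require_auth */ true,
    /// )
    /// .await?;
    /// ```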
1853 async fn send_api_request<Res>(
1854 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1855 client: Arc<Client>,
1856 llm_token: LlmApiToken,
1857 app_version: Version,
1858 require_auth: bool,
1859 ) -> Result<(Res, Option<EditPredictionUsage>)>
1860 where
1861 Res: DeserializeOwned,
1862 {
1863 let http_client = client.http_client();
1864
1865 let mut token = if require_auth {
1866 Some(llm_token.acquire(&client).await?)
1867 } else {
1868 llm_token.acquire(&client).await.ok()
1869 };
1870 let mut did_retry = false;
1871
1872 loop {
1873 let request_builder = http_client::Request::builder().method(Method::POST);
1874
1875 let mut request_builder = request_builder
1876 .header("Content-Type", "application/json")
1877 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1878
1879 // Only add Authorization header if we have a token
1880 if let Some(ref token_value) = token {
1881 request_builder =
1882 request_builder.header("Authorization", format!("Bearer {}", token_value));
1883 }
1884
1885 let request = build(request_builder)?;
1886
1887 let mut response = http_client.send(request).await?;
1888
1889 if let Some(minimum_required_version) = response
1890 .headers()
1891 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1892 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1893 {
1894 anyhow::ensure!(
1895 app_version >= minimum_required_version,
1896 ZedUpdateRequiredError {
1897 minimum_version: minimum_required_version
1898 }
1899 );
1900 }
1901
1902 if response.status().is_success() {
1903 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1904
1905 let mut body = Vec::new();
1906 response.body_mut().read_to_end(&mut body).await?;
1907 return Ok((serde_json::from_slice(&body)?, usage));
1908 } else if !did_retry
1909 && token.is_some()
1910 && response
1911 .headers()
1912 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1913 .is_some()
1914 {
1915 did_retry = true;
1916 token = Some(llm_token.refresh(&client).await?);
1917 } else {
1918 let mut body = String::new();
1919 response.body_mut().read_to_string(&mut body).await?;
1920 anyhow::bail!(
1921 "Request failed with status: {:?}\nBody: {}",
1922 response.status(),
1923 body
1924 );
1925 }
1926 }
1927 }
1928
1929 pub fn refresh_context(
1930 &mut self,
1931 project: &Entity<Project>,
1932 buffer: &Entity<language::Buffer>,
1933 cursor_position: language::Anchor,
1934 cx: &mut Context<Self>,
1935 ) {
1936 if self.use_context {
1937 self.get_or_init_project(project, cx)
1938 .context
1939 .update(cx, |store, cx| {
1940 store.refresh(buffer.clone(), cursor_position, cx);
1941 });
1942 }
1943 }
1944
1945 #[cfg(feature = "cli-support")]
1946 pub fn set_context_for_buffer(
1947 &mut self,
1948 project: &Entity<Project>,
1949 related_files: Vec<RelatedFile>,
1950 cx: &mut Context<Self>,
1951 ) {
1952 self.get_or_init_project(project, cx)
1953 .context
1954 .update(cx, |store, _| {
1955 store.set_related_files(related_files);
1956 });
1957 }
1958
1959 fn is_file_open_source(
1960 &self,
1961 project: &Entity<Project>,
1962 file: &Arc<dyn File>,
1963 cx: &App,
1964 ) -> bool {
1965 if !file.is_local() || file.is_private() {
1966 return false;
1967 }
1968 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1969 return false;
1970 };
1971 project_state
1972 .license_detection_watchers
1973 .get(&file.worktree_id(cx))
1974 .as_ref()
1975 .is_some_and(|watcher| watcher.is_project_open_source())
1976 }
1977
1978 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1979 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1980 }
1981
1982 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1983 if !self.data_collection_choice.is_enabled() {
1984 return false;
1985 }
1986 events.iter().all(|event| {
1987 matches!(
1988 event.as_ref(),
1989 zeta_prompt::Event::BufferChange {
1990 in_open_source_repo: true,
1991 ..
1992 }
1993 )
1994 })
1995 }
1996
1997 fn load_data_collection_choice() -> DataCollectionChoice {
1998 let choice = KEY_VALUE_STORE
1999 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2000 .log_err()
2001 .flatten();
2002
2003 match choice.as_deref() {
2004 Some("true") => DataCollectionChoice::Enabled,
2005 Some("false") => DataCollectionChoice::Disabled,
2006 Some(_) => {
2007 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2008 DataCollectionChoice::NotAnswered
2009 }
2010 None => DataCollectionChoice::NotAnswered,
2011 }
2012 }
2013
2014 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2015 self.data_collection_choice = self.data_collection_choice.toggle();
2016 let new_choice = self.data_collection_choice;
2017 db::write_and_log(cx, move || {
2018 KEY_VALUE_STORE.write_kvp(
2019 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2020 new_choice.is_enabled().to_string(),
2021 )
2022 });
2023 }
2024
2025 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2026 self.shown_predictions.iter()
2027 }
2028
2029 pub fn shown_completions_len(&self) -> usize {
2030 self.shown_predictions.len()
2031 }
2032
2033 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2034 self.rated_predictions.contains(id)
2035 }
2036
2037 pub fn rate_prediction(
2038 &mut self,
2039 prediction: &EditPrediction,
2040 rating: EditPredictionRating,
2041 feedback: String,
2042 cx: &mut Context<Self>,
2043 ) {
2044 self.rated_predictions.insert(prediction.id.clone());
2045 telemetry::event!(
2046 "Edit Prediction Rated",
2047 rating,
2048 inputs = prediction.inputs,
2049 output = prediction.edit_preview.as_unified_diff(&prediction.edits),
2050 feedback
2051 );
2052 self.client.telemetry().flush_events().detach();
2053 cx.notify();
2054 }
2055
2056 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2057 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2058 && all_language_settings(None, cx).edit_predictions.use_context;
2059 }
2060}
2061
2062#[derive(Error, Debug)]
2063#[error(
2064 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2065)]
2066pub struct ZedUpdateRequiredError {
2067 minimum_version: Version,
2068}
2069
2070#[cfg(feature = "cli-support")]
2071pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2072
2073#[cfg(feature = "cli-support")]
2074#[derive(Debug, Clone, Copy, PartialEq)]
2075pub enum EvalCacheEntryKind {
2076 Context,
2077 Search,
2078 Prediction,
2079}
2080
2081#[cfg(feature = "cli-support")]
2082impl std::fmt::Display for EvalCacheEntryKind {
2083 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2084 match self {
2085 EvalCacheEntryKind::Search => write!(f, "search"),
2086 EvalCacheEntryKind::Context => write!(f, "context"),
2087 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
2088 }
2089 }
2090}
2091
2092#[cfg(feature = "cli-support")]
2093pub trait EvalCache: Send + Sync {
2094 fn read(&self, key: EvalCacheKey) -> Option<String>;
2095 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
2096}
2097
2098#[derive(Debug, Clone, Copy)]
2099pub enum DataCollectionChoice {
2100 NotAnswered,
2101 Enabled,
2102 Disabled,
2103}
2104
2105impl DataCollectionChoice {
2106 pub fn is_enabled(self) -> bool {
2107 match self {
2108 Self::Enabled => true,
2109 Self::NotAnswered | Self::Disabled => false,
2110 }
2111 }
2112
2113 pub fn is_answered(self) -> bool {
2114 match self {
2115 Self::Enabled | Self::Disabled => true,
2116 Self::NotAnswered => false,
2117 }
2118 }
2119
2120 #[must_use]
2121 pub fn toggle(&self) -> DataCollectionChoice {
2122 match self {
2123 Self::Enabled => Self::Disabled,
2124 Self::Disabled => Self::Enabled,
2125 Self::NotAnswered => Self::Enabled,
2126 }
2127 }
2128}
2129
2130impl From<bool> for DataCollectionChoice {
2131 fn from(value: bool) -> Self {
2132 match value {
2133 true => DataCollectionChoice::Enabled,
2134 false => DataCollectionChoice::Disabled,
2135 }
2136 }
2137}
2138
2139struct ZedPredictUpsell;
2140
2141impl Dismissable for ZedPredictUpsell {
2142 const KEY: &'static str = "dismissed-edit-predict-upsell";
2143
2144 fn dismissed() -> bool {
 // For backwards compatibility with older versions of Zed, treat the upsell
 // as dismissed if the user already went through the previous Edit Prediction
 // onboarding, which wrote their data collection choice to the database when
 // they clicked "Accept and Enable".
2149 if KEY_VALUE_STORE
2150 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2151 .log_err()
2152 .is_some_and(|s| s.is_some())
2153 {
2154 return true;
2155 }
2156
2157 KEY_VALUE_STORE
2158 .read_kvp(Self::KEY)
2159 .log_err()
2160 .is_some_and(|s| s.is_some())
2161 }
2162}
2163
2164pub fn should_show_upsell_modal() -> bool {
2165 !ZedPredictUpsell::dismissed()
2166}
2167
2168pub fn init(cx: &mut App) {
2169 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2170 workspace.register_action(
2171 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2172 ZedPredictModal::toggle(
2173 workspace,
2174 workspace.user_store().clone(),
2175 workspace.client().clone(),
2176 window,
2177 cx,
2178 )
2179 },
2180 );
2181
2182 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2183 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2184 settings
2185 .project
2186 .all_languages
2187 .features
2188 .get_or_insert_default()
2189 .edit_prediction_provider = Some(EditPredictionProvider::None)
2190 });
2191 });
2192 })
2193 .detach();
2194}