1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::http_client::Url;
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::language_settings::all_language_settings;
29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
30use language::{BufferSnapshot, OffsetRangeExt};
31use language_model::{LlmApiToken, RefreshLlmTokenListener};
32use project::{Project, ProjectPath, WorktreeId};
33use release_channel::AppVersion;
34use semver::Version;
35use serde::de::DeserializeOwned;
36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
37use std::collections::{VecDeque, hash_map};
38use text::Edit;
39use workspace::Workspace;
40
41use std::ops::Range;
42use std::path::Path;
43use std::rc::Rc;
44use std::str::FromStr as _;
45use std::sync::{Arc, LazyLock};
46use std::time::{Duration, Instant};
47use std::{env, mem};
48use thiserror::Error;
49use util::{RangeExt as _, ResultExt as _};
50use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
51
52pub mod cursor_excerpt;
53pub mod example_spec;
54mod license_detection;
55pub mod mercury;
56mod onboarding_modal;
57pub mod open_ai_response;
58mod prediction;
59pub mod sweep_ai;
60
61pub mod udiff;
62
63mod capture_example;
64mod zed_edit_prediction_delegate;
65pub mod zeta1;
66pub mod zeta2;
67
68#[cfg(test)]
69mod edit_prediction_tests;
70
71use crate::capture_example::should_sample_edit_prediction_example_capture;
72use crate::license_detection::LicenseDetectionWatcher;
73use crate::mercury::Mercury;
74use crate::onboarding_modal::ZedPredictModal;
75pub use crate::prediction::EditPrediction;
76pub use crate::prediction::EditPredictionId;
77use crate::prediction::EditPredictionResult;
78pub use crate::sweep_ai::SweepAi;
79pub use capture_example::capture_example;
80pub use language_model::ApiKeyState;
81pub use telemetry_events::EditPredictionRating;
82pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
83
84actions!(
85 edit_prediction,
86 [
87 /// Resets the edit prediction onboarding state.
88 ResetOnboarding,
89 /// Clears the edit prediction history.
90 ClearHistory,
91 ]
92);
93
/// Maximum number of edit events to retain per project.
95const EVENT_COUNT_MAX: usize = 6;
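/// Edits within this many lines of the previous edit are coalesced into the same event.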
96const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
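/// A gap of at least this long between edits counts as an editing pause.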
97const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
98const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
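/// How long to wait for additional rejections before flushing a batched reject request.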
99const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
100
101pub struct SweepFeatureFlag;
102
103impl FeatureFlag for SweepFeatureFlag {
104 const NAME: &str = "sweep-ai";
105}
106
107pub struct MercuryFeatureFlag;
108
109impl FeatureFlag for MercuryFeatureFlag {
110 const NAME: &str = "mercury";
111}
112
113pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
114 context: EditPredictionExcerptOptions {
115 max_bytes: 512,
116 min_bytes: 128,
117 target_before_cursor_over_total_bytes: 0.5,
118 },
119 prompt_format: PromptFormat::DEFAULT,
120};
121
122static USE_OLLAMA: LazyLock<bool> =
123 LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
124
125static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
126 match env::var("ZED_ZETA2_MODEL").as_deref() {
127 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
128 Ok(model) => model,
129 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
130 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
131 }
132 .to_string()
133});
134
135pub struct Zeta2FeatureFlag;
136
137impl FeatureFlag for Zeta2FeatureFlag {
138 const NAME: &'static str = "zeta2";
139
140 fn enabled_for_staff() -> bool {
141 true
142 }
143}
144
145pub struct EditPredictionExampleCaptureFeatureFlag;
146
147impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
148 const NAME: &'static str = "edit-prediction-example-capture";
149
150 fn enabled_for_staff() -> bool {
151 true
152 }
153}
154
155#[derive(Clone)]
156struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
157
158impl Global for EditPredictionStoreGlobal {}
159
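/// Global store for edit predictions: tracks per-project edit history and retrieved context,
/// dispatches prediction requests to the configured model, and reports accepted and rejected
/// predictions back to the server.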
160pub struct EditPredictionStore {
161 client: Arc<Client>,
162 user_store: Entity<UserStore>,
163 llm_token: LlmApiToken,
164 _llm_token_subscription: Subscription,
165 projects: HashMap<EntityId, ProjectState>,
166 use_context: bool,
167 options: ZetaOptions,
168 update_required: bool,
169 #[cfg(feature = "cli-support")]
170 eval_cache: Option<Arc<dyn EvalCache>>,
171 edit_prediction_model: EditPredictionModel,
172 pub sweep_ai: SweepAi,
173 pub mercury: Mercury,
174 data_collection_choice: DataCollectionChoice,
175 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
176 shown_predictions: VecDeque<EditPrediction>,
177 rated_predictions: HashSet<EditPredictionId>,
178 custom_predict_edits_url: Option<Arc<Url>>,
179}
180
181#[derive(Copy, Clone, Default, PartialEq, Eq)]
182pub enum EditPredictionModel {
183 #[default]
184 Zeta1,
185 Zeta2,
186 Sweep,
187 Mercury,
188}
189
190pub struct EditPredictionModelInput {
191 project: Entity<Project>,
192 buffer: Entity<Buffer>,
193 snapshot: BufferSnapshot,
194 position: Anchor,
195 events: Vec<Arc<zeta_prompt::Event>>,
196 related_files: Arc<[RelatedFile]>,
197 recent_paths: VecDeque<ProjectPath>,
198 trigger: PredictEditsRequestTrigger,
199 diagnostic_search_range: Range<Point>,
200 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
201}
202
203#[derive(Debug, Clone, PartialEq)]
204pub struct ZetaOptions {
205 pub context: EditPredictionExcerptOptions,
206 pub prompt_format: predict_edits_v3::PromptFormat,
207}
208
209#[derive(Debug)]
210pub enum DebugEvent {
211 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
212 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
213 EditPredictionStarted(EditPredictionStartedDebugEvent),
214 EditPredictionFinished(EditPredictionFinishedDebugEvent),
215}
216
217#[derive(Debug)]
218pub struct ContextRetrievalStartedDebugEvent {
219 pub project_entity_id: EntityId,
220 pub timestamp: Instant,
221 pub search_prompt: String,
222}
223
224#[derive(Debug)]
225pub struct ContextRetrievalFinishedDebugEvent {
226 pub project_entity_id: EntityId,
227 pub timestamp: Instant,
228 pub metadata: Vec<(&'static str, SharedString)>,
229}
230
231#[derive(Debug)]
232pub struct EditPredictionStartedDebugEvent {
233 pub buffer: WeakEntity<Buffer>,
234 pub position: Anchor,
235 pub prompt: Option<String>,
236}
237
238#[derive(Debug)]
239pub struct EditPredictionFinishedDebugEvent {
240 pub buffer: WeakEntity<Buffer>,
241 pub position: Anchor,
242 pub model_output: Option<String>,
243}
244
245pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
246
247/// An event with associated metadata for reconstructing buffer state.
248#[derive(Clone)]
249pub struct StoredEvent {
250 pub event: Arc<zeta_prompt::Event>,
251 pub old_snapshot: TextBufferSnapshot,
252}
253
254struct ProjectState {
255 events: VecDeque<StoredEvent>,
256 last_event: Option<LastEvent>,
257 recent_paths: VecDeque<ProjectPath>,
258 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
259 current_prediction: Option<CurrentEditPrediction>,
260 next_pending_prediction_id: usize,
261 pending_predictions: ArrayVec<PendingPrediction, 2>,
262 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
263 last_prediction_refresh: Option<(EntityId, Instant)>,
264 cancelled_predictions: HashSet<usize>,
265 context: Entity<RelatedExcerptStore>,
266 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
267 _subscription: gpui::Subscription,
268}
269
270impl ProjectState {
271 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
272 self.events
273 .iter()
274 .cloned()
275 .chain(
276 self.last_event
277 .as_ref()
278 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
279 )
280 .collect()
281 }
282
283 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
284 self.events
285 .iter()
286 .cloned()
287 .chain(self.last_event.as_ref().iter().flat_map(|event| {
288 let (one, two) = event.split_by_pause();
289 let one = one.finalize(&self.license_detection_watchers, cx);
290 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
291 one.into_iter().chain(two)
292 }))
293 .collect()
294 }
295
296 fn cancel_pending_prediction(
297 &mut self,
298 pending_prediction: PendingPrediction,
299 cx: &mut Context<EditPredictionStore>,
300 ) {
301 self.cancelled_predictions.insert(pending_prediction.id);
302
303 cx.spawn(async move |this, cx| {
304 let Some(prediction_id) = pending_prediction.task.await else {
305 return;
306 };
307
308 this.update(cx, |this, _cx| {
309 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
310 })
311 .ok();
312 })
313 .detach()
314 }
315
316 fn active_buffer(
317 &self,
318 project: &Entity<Project>,
319 cx: &App,
320 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
321 let project = project.read(cx);
322 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
323 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
324 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
325 Some((active_buffer, registered_buffer.last_position))
326 }
327}
328
329#[derive(Debug, Clone)]
330struct CurrentEditPrediction {
331 pub requested_by: PredictionRequestedBy,
332 pub prediction: EditPrediction,
333 pub was_shown: bool,
334}
335
336impl CurrentEditPrediction {
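    /// Whether this newer prediction should replace `old_prediction`. When both are single edits
    /// in the buffer that requested them, the old prediction is kept unless the new edit targets
    /// the same range and extends its text, which avoids UI thrash.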
337 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
338 let Some(new_edits) = self
339 .prediction
340 .interpolate(&self.prediction.buffer.read(cx))
341 else {
342 return false;
343 };
344
345 if self.prediction.buffer != old_prediction.prediction.buffer {
346 return true;
347 }
348
349 let Some(old_edits) = old_prediction
350 .prediction
351 .interpolate(&old_prediction.prediction.buffer.read(cx))
352 else {
353 return true;
354 };
355
356 let requested_by_buffer_id = self.requested_by.buffer_id();
357
358 // This reduces the occurrence of UI thrash from replacing edits
359 //
360 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
361 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
362 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
363 && old_edits.len() == 1
364 && new_edits.len() == 1
365 {
366 let (old_range, old_text) = &old_edits[0];
367 let (new_range, new_text) = &new_edits[0];
368 new_range == old_range && new_text.starts_with(old_text.as_ref())
369 } else {
370 true
371 }
372 }
373}
374
375#[derive(Debug, Clone)]
376enum PredictionRequestedBy {
377 DiagnosticsUpdate,
378 Buffer(EntityId),
379}
380
381impl PredictionRequestedBy {
382 pub fn buffer_id(&self) -> Option<EntityId> {
383 match self {
384 PredictionRequestedBy::DiagnosticsUpdate => None,
385 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
386 }
387 }
388}
389
390#[derive(Debug)]
391struct PendingPrediction {
392 id: usize,
393 task: Task<Option<EditPredictionId>>,
394}
395
396/// A prediction from the perspective of a buffer.
397#[derive(Debug)]
398enum BufferEditPrediction<'a> {
399 Local { prediction: &'a EditPrediction },
400 Jump { prediction: &'a EditPrediction },
401}
402
403#[cfg(test)]
404impl std::ops::Deref for BufferEditPrediction<'_> {
405 type Target = EditPrediction;
406
407 fn deref(&self) -> &Self::Target {
408 match self {
409 BufferEditPrediction::Local { prediction } => prediction,
410 BufferEditPrediction::Jump { prediction } => prediction,
411 }
412 }
413}
414
415struct RegisteredBuffer {
416 file: Option<Arc<dyn File>>,
417 snapshot: TextBufferSnapshot,
418 last_position: Option<Anchor>,
419 _subscriptions: [gpui::Subscription; 2],
420}
421
422#[derive(Clone)]
423struct LastEvent {
424 old_snapshot: TextBufferSnapshot,
425 new_snapshot: TextBufferSnapshot,
426 old_file: Option<Arc<dyn File>>,
427 new_file: Option<Arc<dyn File>>,
428 end_edit_anchor: Option<Anchor>,
429 snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
430 last_edit_time: Option<Instant>,
431}
432
433impl LastEvent {
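    /// Converts this coalesced event into a `StoredEvent` containing a unified diff of the
    /// change, or `None` if neither the path nor the contents changed.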
434 pub fn finalize(
435 &self,
436 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
437 cx: &App,
438 ) -> Option<StoredEvent> {
439 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
440 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
441
442 let in_open_source_repo =
443 [self.new_file.as_ref(), self.old_file.as_ref()]
444 .iter()
445 .all(|file| {
446 file.is_some_and(|file| {
447 license_detection_watchers
448 .get(&file.worktree_id(cx))
449 .is_some_and(|watcher| watcher.is_project_open_source())
450 })
451 });
452
453 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
454
455 if path == old_path && diff.is_empty() {
456 None
457 } else {
458 Some(StoredEvent {
459 event: Arc::new(zeta_prompt::Event::BufferChange {
460 old_path,
461 path,
462 diff,
463 in_open_source_repo,
464 // TODO: Actually detect if this edit was predicted or not
465 predicted: false,
466 }),
467 old_snapshot: self.old_snapshot.clone(),
468 })
469 }
470 }
471
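    /// Splits the event at the snapshot captured after the last editing pause, returning the
    /// portion before the pause and, if a pause was recorded, the portion after it.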
472 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
473 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
474 return (self.clone(), None);
475 };
476
477 let before = LastEvent {
478 old_snapshot: self.old_snapshot.clone(),
479 new_snapshot: boundary_snapshot.clone(),
480 old_file: self.old_file.clone(),
481 new_file: self.new_file.clone(),
482 end_edit_anchor: self.end_edit_anchor,
483 snapshot_after_last_editing_pause: None,
484 last_edit_time: self.last_edit_time,
485 };
486
487 let after = LastEvent {
488 old_snapshot: boundary_snapshot.clone(),
489 new_snapshot: self.new_snapshot.clone(),
490 old_file: self.old_file.clone(),
491 new_file: self.new_file.clone(),
492 end_edit_anchor: self.end_edit_anchor,
493 snapshot_after_last_editing_pause: None,
494 last_edit_time: self.last_edit_time,
495 };
496
497 (before, Some(after))
498 }
499}
500
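/// Computes a unified diff between two snapshots of the same buffer, covering the edited region
/// plus three lines of context on each side. Returns `None` if there are no edits.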
501pub(crate) fn compute_diff_between_snapshots(
502 old_snapshot: &TextBufferSnapshot,
503 new_snapshot: &TextBufferSnapshot,
504) -> Option<String> {
505 let edits: Vec<Edit<usize>> = new_snapshot
506 .edits_since::<usize>(&old_snapshot.version)
507 .collect();
508
509 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
510
511 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
512 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
513 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
514 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
515
516 const CONTEXT_LINES: u32 = 3;
517
518 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
519 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
520 let old_context_end_row =
521 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
522 let new_context_end_row =
523 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
524
525 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
526 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
527 let old_end_line_offset = old_snapshot
528 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
529 let new_end_line_offset = new_snapshot
530 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
531 let old_edit_range = old_start_line_offset..old_end_line_offset;
532 let new_edit_range = new_start_line_offset..new_end_line_offset;
533
534 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
535 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
536
537 let diff = language::unified_diff_with_offsets(
538 &old_region_text,
539 &new_region_text,
540 old_context_start_row,
541 new_context_start_row,
542 );
543
544 Some(diff)
545}
546
547fn buffer_path_with_id_fallback(
548 file: Option<&Arc<dyn File>>,
549 snapshot: &TextBufferSnapshot,
550 cx: &App,
551) -> Arc<Path> {
552 if let Some(file) = file {
553 file.full_path(cx).into()
554 } else {
555 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
556 }
557}
558
559impl EditPredictionStore {
560 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
561 cx.try_global::<EditPredictionStoreGlobal>()
562 .map(|global| global.0.clone())
563 }
564
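    /// Returns the shared [`EditPredictionStore`], creating it on first use.
    ///
    /// A minimal usage sketch (assuming a `Client`, `UserStore`, `Project`, and a `&mut App`
    /// context are already in scope):
    ///
    /// ```ignore
    /// let store = EditPredictionStore::global(&client, &user_store, cx);
    /// store.update(cx, |store, cx| store.register_project(&project, cx));
    /// ```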
565 pub fn global(
566 client: &Arc<Client>,
567 user_store: &Entity<UserStore>,
568 cx: &mut App,
569 ) -> Entity<Self> {
570 cx.try_global::<EditPredictionStoreGlobal>()
571 .map(|global| global.0.clone())
572 .unwrap_or_else(|| {
573 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
574 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
575 ep_store
576 })
577 }
578
579 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
580 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
581 let data_collection_choice = Self::load_data_collection_choice();
582
583 let llm_token = LlmApiToken::default();
584
585 let (reject_tx, reject_rx) = mpsc::unbounded();
586 cx.background_spawn({
587 let client = client.clone();
588 let llm_token = llm_token.clone();
589 let app_version = AppVersion::global(cx);
590 let background_executor = cx.background_executor().clone();
591 async move {
592 Self::handle_rejected_predictions(
593 reject_rx,
594 client,
595 llm_token,
596 app_version,
597 background_executor,
598 )
599 .await
600 }
601 })
602 .detach();
603
604 let mut this = Self {
605 projects: HashMap::default(),
606 client,
607 user_store,
608 options: DEFAULT_OPTIONS,
609 use_context: false,
610 llm_token,
611 _llm_token_subscription: cx.subscribe(
612 &refresh_llm_token_listener,
613 |this, _listener, _event, cx| {
614 let client = this.client.clone();
615 let llm_token = this.llm_token.clone();
616 cx.spawn(async move |_this, _cx| {
617 llm_token.refresh(&client).await?;
618 anyhow::Ok(())
619 })
620 .detach_and_log_err(cx);
621 },
622 ),
623 update_required: false,
624 #[cfg(feature = "cli-support")]
625 eval_cache: None,
626 edit_prediction_model: EditPredictionModel::Zeta2,
627 sweep_ai: SweepAi::new(cx),
628 mercury: Mercury::new(cx),
629 data_collection_choice,
630 reject_predictions_tx: reject_tx,
631 rated_predictions: Default::default(),
632 shown_predictions: Default::default(),
633 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
634 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
635 Err(_) => {
636 if *USE_OLLAMA {
637 Some(
638 Url::parse("http://localhost:11434/v1/chat/completions")
639 .unwrap()
640 .into(),
641 )
642 } else {
643 None
644 }
645 }
646 },
647 };
648
649 this.configure_context_retrieval(cx);
650 let weak_this = cx.weak_entity();
651 cx.on_flags_ready(move |_, cx| {
652 weak_this
653 .update(cx, |this, cx| this.configure_context_retrieval(cx))
654 .ok();
655 })
656 .detach();
657 cx.observe_global::<SettingsStore>(|this, cx| {
658 this.configure_context_retrieval(cx);
659 })
660 .detach();
661
662 this
663 }
664
665 #[cfg(test)]
666 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
667 self.custom_predict_edits_url = Some(url.into());
668 }
669
670 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
671 self.edit_prediction_model = model;
672 }
673
674 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
675 self.sweep_ai.api_token.read(cx).has_key()
676 }
677
678 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
679 self.mercury.api_token.read(cx).has_key()
680 }
681
682 #[cfg(feature = "cli-support")]
683 pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
684 self.eval_cache = Some(cache);
685 }
686
687 pub fn options(&self) -> &ZetaOptions {
688 &self.options
689 }
690
691 pub fn set_options(&mut self, options: ZetaOptions) {
692 self.options = options;
693 }
694
695 pub fn set_use_context(&mut self, use_context: bool) {
696 self.use_context = use_context;
697 }
698
699 pub fn clear_history(&mut self) {
700 for project_state in self.projects.values_mut() {
701 project_state.events.clear();
702 project_state.last_event.take();
703 }
704 }
705
706 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
707 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
708 project_state.events.clear();
709 project_state.last_event.take();
710 }
711 }
712
713 pub fn edit_history_for_project(
714 &self,
715 project: &Entity<Project>,
716 cx: &App,
717 ) -> Vec<StoredEvent> {
718 self.projects
719 .get(&project.entity_id())
720 .map(|project_state| project_state.events(cx))
721 .unwrap_or_default()
722 }
723
724 pub fn edit_history_for_project_with_pause_split_last_event(
725 &self,
726 project: &Entity<Project>,
727 cx: &App,
728 ) -> Vec<StoredEvent> {
729 self.projects
730 .get(&project.entity_id())
731 .map(|project_state| project_state.events_split_by_pause(cx))
732 .unwrap_or_default()
733 }
734
735 pub fn context_for_project<'a>(
736 &'a self,
737 project: &Entity<Project>,
738 cx: &'a App,
739 ) -> Arc<[RelatedFile]> {
740 self.projects
741 .get(&project.entity_id())
742 .map(|project| project.context.read(cx).related_files())
743 .unwrap_or_else(|| vec![].into())
744 }
745
746 pub fn context_for_project_with_buffers<'a>(
747 &'a self,
748 project: &Entity<Project>,
749 cx: &'a App,
750 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
751 self.projects
752 .get(&project.entity_id())
753 .map(|project| project.context.read(cx).related_files_with_buffers())
754 }
755
756 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
757 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
758 self.user_store.read(cx).edit_prediction_usage()
759 } else {
760 None
761 }
762 }
763
764 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
765 self.get_or_init_project(project, cx);
766 }
767
768 pub fn register_buffer(
769 &mut self,
770 buffer: &Entity<Buffer>,
771 project: &Entity<Project>,
772 cx: &mut Context<Self>,
773 ) {
774 let project_state = self.get_or_init_project(project, cx);
775 Self::register_buffer_impl(project_state, buffer, project, cx);
776 }
777
778 fn get_or_init_project(
779 &mut self,
780 project: &Entity<Project>,
781 cx: &mut Context<Self>,
782 ) -> &mut ProjectState {
783 let entity_id = project.entity_id();
784 self.projects
785 .entry(entity_id)
786 .or_insert_with(|| ProjectState {
787 context: {
788 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
789 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
790 this.handle_excerpt_store_event(entity_id, event);
791 })
792 .detach();
793 related_excerpt_store
794 },
795 events: VecDeque::new(),
796 last_event: None,
797 recent_paths: VecDeque::new(),
798 debug_tx: None,
799 registered_buffers: HashMap::default(),
800 current_prediction: None,
801 cancelled_predictions: HashSet::default(),
802 pending_predictions: ArrayVec::new(),
803 next_pending_prediction_id: 0,
804 last_prediction_refresh: None,
805 license_detection_watchers: HashMap::default(),
806 _subscription: cx.subscribe(&project, Self::handle_project_event),
807 })
808 }
809
810 pub fn remove_project(&mut self, project: &Entity<Project>) {
811 self.projects.remove(&project.entity_id());
812 }
813
814 fn handle_excerpt_store_event(
815 &mut self,
816 project_entity_id: EntityId,
817 event: &RelatedExcerptStoreEvent,
818 ) {
819 if let Some(project_state) = self.projects.get(&project_entity_id) {
820 if let Some(debug_tx) = project_state.debug_tx.clone() {
821 match event {
822 RelatedExcerptStoreEvent::StartedRefresh => {
823 debug_tx
824 .unbounded_send(DebugEvent::ContextRetrievalStarted(
825 ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
827 timestamp: Instant::now(),
828 search_prompt: String::new(),
829 },
830 ))
831 .ok();
832 }
833 RelatedExcerptStoreEvent::FinishedRefresh {
834 cache_hit_count,
835 cache_miss_count,
836 mean_definition_latency,
837 max_definition_latency,
838 } => {
839 debug_tx
840 .unbounded_send(DebugEvent::ContextRetrievalFinished(
841 ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
843 timestamp: Instant::now(),
844 metadata: vec![
845 (
846 "Cache Hits",
847 format!(
848 "{}/{}",
849 cache_hit_count,
850 cache_hit_count + cache_miss_count
851 )
852 .into(),
853 ),
854 (
855 "Max LSP Time",
856 format!("{} ms", max_definition_latency.as_millis())
857 .into(),
858 ),
859 (
860 "Mean LSP Time",
861 format!("{} ms", mean_definition_latency.as_millis())
862 .into(),
863 ),
864 ],
865 },
866 ))
867 .ok();
868 }
869 }
870 }
871 }
872 }
873
874 pub fn debug_info(
875 &mut self,
876 project: &Entity<Project>,
877 cx: &mut Context<Self>,
878 ) -> mpsc::UnboundedReceiver<DebugEvent> {
879 let project_state = self.get_or_init_project(project, cx);
880 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
881 project_state.debug_tx = Some(debug_watch_tx);
882 debug_watch_rx
883 }
884
885 fn handle_project_event(
886 &mut self,
887 project: Entity<Project>,
888 event: &project::Event,
889 cx: &mut Context<Self>,
890 ) {
891 // TODO [zeta2] init with recent paths
892 match event {
893 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
894 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
895 return;
896 };
897 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
898 if let Some(path) = path {
899 if let Some(ix) = project_state
900 .recent_paths
901 .iter()
902 .position(|probe| probe == &path)
903 {
904 project_state.recent_paths.remove(ix);
905 }
906 project_state.recent_paths.push_front(path);
907 }
908 }
909 project::Event::DiagnosticsUpdated { .. } => {
910 if cx.has_flag::<Zeta2FeatureFlag>() {
911 self.refresh_prediction_from_diagnostics(project, cx);
912 }
913 }
914 _ => (),
915 }
916 }
917
918 fn register_buffer_impl<'a>(
919 project_state: &'a mut ProjectState,
920 buffer: &Entity<Buffer>,
921 project: &Entity<Project>,
922 cx: &mut Context<Self>,
923 ) -> &'a mut RegisteredBuffer {
924 let buffer_id = buffer.entity_id();
925
926 if let Some(file) = buffer.read(cx).file() {
927 let worktree_id = file.worktree_id(cx);
928 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
929 project_state
930 .license_detection_watchers
931 .entry(worktree_id)
932 .or_insert_with(|| {
933 let project_entity_id = project.entity_id();
934 cx.observe_release(&worktree, move |this, _worktree, _cx| {
935 let Some(project_state) = this.projects.get_mut(&project_entity_id)
936 else {
937 return;
938 };
939 project_state
940 .license_detection_watchers
941 .remove(&worktree_id);
942 })
943 .detach();
944 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
945 });
946 }
947 }
948
949 match project_state.registered_buffers.entry(buffer_id) {
950 hash_map::Entry::Occupied(entry) => entry.into_mut(),
951 hash_map::Entry::Vacant(entry) => {
952 let buf = buffer.read(cx);
953 let snapshot = buf.text_snapshot();
954 let file = buf.file().cloned();
955 let project_entity_id = project.entity_id();
956 entry.insert(RegisteredBuffer {
957 snapshot,
958 file,
959 last_position: None,
960 _subscriptions: [
961 cx.subscribe(buffer, {
962 let project = project.downgrade();
963 move |this, buffer, event, cx| {
964 if let language::BufferEvent::Edited = event
965 && let Some(project) = project.upgrade()
966 {
967 this.report_changes_for_buffer(&buffer, &project, cx);
968 }
969 }
970 }),
971 cx.observe_release(buffer, move |this, _buffer, _cx| {
972 let Some(project_state) = this.projects.get_mut(&project_entity_id)
973 else {
974 return;
975 };
976 project_state.registered_buffers.remove(&buffer_id);
977 }),
978 ],
979 })
980 }
981 }
982 }
983
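    /// Records a buffer edit, coalescing it into the pending `LastEvent` when it lands within
    /// `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit in the same buffer; otherwise the
    /// pending event is finalized into history (capped at `EVENT_COUNT_MAX`) and a new one begins.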
984 fn report_changes_for_buffer(
985 &mut self,
986 buffer: &Entity<Buffer>,
987 project: &Entity<Project>,
988 cx: &mut Context<Self>,
989 ) {
990 let project_state = self.get_or_init_project(project, cx);
991 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
992
993 let buf = buffer.read(cx);
994 let new_file = buf.file().cloned();
995 let new_snapshot = buf.text_snapshot();
996 if new_snapshot.version == registered_buffer.snapshot.version {
997 return;
998 }
999
1000 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1001 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1002 let end_edit_anchor = new_snapshot
1003 .anchored_edits_since::<Point>(&old_snapshot.version)
1004 .last()
1005 .map(|(_, range)| range.end);
1006 let events = &mut project_state.events;
1007
1008 let now = cx.background_executor().now();
1009 if let Some(last_event) = project_state.last_event.as_mut() {
1010 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1011 == last_event.new_snapshot.remote_id()
1012 && old_snapshot.version == last_event.new_snapshot.version;
1013
1014 let should_coalesce = is_next_snapshot_of_same_buffer
1015 && end_edit_anchor
1016 .as_ref()
1017 .zip(last_event.end_edit_anchor.as_ref())
1018 .is_some_and(|(a, b)| {
1019 let a = a.to_point(&new_snapshot);
1020 let b = b.to_point(&new_snapshot);
1021 a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
1022 });
1023
1024 if should_coalesce {
1025 let pause_elapsed = last_event
1026 .last_edit_time
1027 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1028 .unwrap_or(false);
1029 if pause_elapsed {
1030 last_event.snapshot_after_last_editing_pause =
1031 Some(last_event.new_snapshot.clone());
1032 }
1033
1034 last_event.end_edit_anchor = end_edit_anchor;
1035 last_event.new_snapshot = new_snapshot;
1036 last_event.last_edit_time = Some(now);
1037 return;
1038 }
1039 }
1040
1041 if events.len() + 1 >= EVENT_COUNT_MAX {
1042 events.pop_front();
1043 }
1044
1045 if let Some(event) = project_state.last_event.take() {
1046 events.extend(event.finalize(&project_state.license_detection_watchers, cx));
1047 }
1048
1049 project_state.last_event = Some(LastEvent {
1050 old_file,
1051 new_file,
1052 old_snapshot,
1053 new_snapshot,
1054 end_edit_anchor,
1055 snapshot_after_last_editing_pause: None,
1056 last_edit_time: Some(now),
1057 });
1058 }
1059
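    /// Returns the current prediction from the perspective of `buffer`, recording `position` as
    /// its last cursor position: `Local` when the prediction targets this buffer, `Jump` when the
    /// user should be offered a jump to the buffer the prediction targets.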
1060 fn prediction_at(
1061 &mut self,
1062 buffer: &Entity<Buffer>,
1063 position: Option<language::Anchor>,
1064 project: &Entity<Project>,
1065 cx: &App,
1066 ) -> Option<BufferEditPrediction<'_>> {
1067 let project_state = self.projects.get_mut(&project.entity_id())?;
1068 if let Some(position) = position
1069 && let Some(buffer) = project_state
1070 .registered_buffers
1071 .get_mut(&buffer.entity_id())
1072 {
1073 buffer.last_position = Some(position);
1074 }
1075
1076 let CurrentEditPrediction {
1077 requested_by,
1078 prediction,
1079 ..
1080 } = project_state.current_prediction.as_ref()?;
1081
1082 if prediction.targets_buffer(buffer.read(cx)) {
1083 Some(BufferEditPrediction::Local { prediction })
1084 } else {
1085 let show_jump = match requested_by {
1086 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1087 requested_by_buffer_id == &buffer.entity_id()
1088 }
1089 PredictionRequestedBy::DiagnosticsUpdate => true,
1090 };
1091
1092 if show_jump {
1093 Some(BufferEditPrediction::Jump { prediction })
1094 } else {
1095 None
1096 }
1097 }
1098 }
1099
1100 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1101 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1102 match self.edit_prediction_model {
1103 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1104 if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1105 return;
1106 }
1107 }
1108 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1109 }
1110
1111 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1112 return;
1113 };
1114
1115 let Some(prediction) = project_state.current_prediction.take() else {
1116 return;
1117 };
1118 let request_id = prediction.prediction.id.to_string();
1119 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1120 project_state.cancel_pending_prediction(pending_prediction, cx);
1121 }
1122
1123 let client = self.client.clone();
1124 let llm_token = self.llm_token.clone();
1125 let app_version = AppVersion::global(cx);
1126 cx.spawn(async move |this, cx| {
1127 let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1128 (http_client::Url::parse(&accept_edits_url)?, false)
1129 } else {
1130 (
1131 client
1132 .http_client()
1133 .build_zed_llm_url("/predict_edits/accept", &[])?,
1134 true,
1135 )
1136 };
1137
1138 let response = cx
1139 .background_spawn(Self::send_api_request::<()>(
1140 move |builder| {
1141 let req = builder.uri(url.as_ref()).body(
1142 serde_json::to_string(&AcceptEditPredictionBody {
1143 request_id: request_id.clone(),
1144 })?
1145 .into(),
1146 );
1147 Ok(req?)
1148 },
1149 client,
1150 llm_token,
1151 app_version,
1152 require_auth,
1153 ))
1154 .await;
1155
1156 Self::handle_api_response(&this, response, cx)?;
1157 anyhow::Ok(())
1158 })
1159 .detach_and_log_err(cx);
1160 }
1161
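    /// Batches rejected predictions and reports them to the server, flushing once half the
    /// per-request limit has accumulated or after `REJECT_REQUEST_DEBOUNCE` with no new
    /// rejections. Items from a failed flush are retained and retried with the next batch.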
1162 async fn handle_rejected_predictions(
1163 rx: UnboundedReceiver<EditPredictionRejection>,
1164 client: Arc<Client>,
1165 llm_token: LlmApiToken,
1166 app_version: Version,
1167 background_executor: BackgroundExecutor,
1168 ) {
1169 let mut rx = std::pin::pin!(rx.peekable());
1170 let mut batched = Vec::new();
1171
1172 while let Some(rejection) = rx.next().await {
1173 batched.push(rejection);
1174
1175 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1176 select_biased! {
1177 next = rx.as_mut().peek().fuse() => {
1178 if next.is_some() {
1179 continue;
1180 }
1181 }
1182 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1183 }
1184 }
1185
1186 let url = client
1187 .http_client()
1188 .build_zed_llm_url("/predict_edits/reject", &[])
1189 .unwrap();
1190
1191 let flush_count = batched
1192 .len()
1193 // in case items have accumulated after failure
1194 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1195 let start = batched.len() - flush_count;
1196
1197 let body = RejectEditPredictionsBodyRef {
1198 rejections: &batched[start..],
1199 };
1200
1201 let result = Self::send_api_request::<()>(
1202 |builder| {
1203 let req = builder
1204 .uri(url.as_ref())
1205 .body(serde_json::to_string(&body)?.into());
1206 anyhow::Ok(req?)
1207 },
1208 client.clone(),
1209 llm_token.clone(),
1210 app_version.clone(),
1211 true,
1212 )
1213 .await;
1214
1215 if result.log_err().is_some() {
1216 batched.drain(start..);
1217 }
1218 }
1219 }
1220
1221 fn reject_current_prediction(
1222 &mut self,
1223 reason: EditPredictionRejectReason,
1224 project: &Entity<Project>,
1225 ) {
1226 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1227 project_state.pending_predictions.clear();
1228 if let Some(prediction) = project_state.current_prediction.take() {
1229 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1230 }
1231 };
1232 }
1233
1234 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1235 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1236 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1237 if !current_prediction.was_shown {
1238 current_prediction.was_shown = true;
1239 self.shown_predictions
1240 .push_front(current_prediction.prediction.clone());
1241 if self.shown_predictions.len() > 50 {
1242 let completion = self.shown_predictions.pop_back().unwrap();
1243 self.rated_predictions.remove(&completion.id);
1244 }
1245 }
1246 }
1247 }
1248 }
1249
1250 fn reject_prediction(
1251 &mut self,
1252 prediction_id: EditPredictionId,
1253 reason: EditPredictionRejectReason,
1254 was_shown: bool,
1255 ) {
1256 match self.edit_prediction_model {
1257 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1258 if self.custom_predict_edits_url.is_some() {
1259 return;
1260 }
1261 }
1262 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1263 }
1264
1265 self.reject_predictions_tx
1266 .unbounded_send(EditPredictionRejection {
1267 request_id: prediction_id.to_string(),
1268 reason,
1269 was_shown,
1270 })
1271 .log_err();
1272 }
1273
1274 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1275 self.projects
1276 .get(&project.entity_id())
1277 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1278 }
1279
1280 pub fn refresh_prediction_from_buffer(
1281 &mut self,
1282 project: Entity<Project>,
1283 buffer: Entity<Buffer>,
1284 position: language::Anchor,
1285 cx: &mut Context<Self>,
1286 ) {
1287 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1288 let Some(request_task) = this
1289 .update(cx, |this, cx| {
1290 this.request_prediction(
1291 &project,
1292 &buffer,
1293 position,
1294 PredictEditsRequestTrigger::Other,
1295 cx,
1296 )
1297 })
1298 .log_err()
1299 else {
1300 return Task::ready(anyhow::Ok(None));
1301 };
1302
1303 cx.spawn(async move |_cx| {
1304 request_task.await.map(|prediction_result| {
1305 prediction_result.map(|prediction_result| {
1306 (
1307 prediction_result,
1308 PredictionRequestedBy::Buffer(buffer.entity_id()),
1309 )
1310 })
1311 })
1312 })
1313 })
1314 }
1315
1316 pub fn refresh_prediction_from_diagnostics(
1317 &mut self,
1318 project: Entity<Project>,
1319 cx: &mut Context<Self>,
1320 ) {
1321 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1322 return;
1323 };
1324
        // Prefer an existing prediction requested from a buffer over one triggered by diagnostics.
1326 if project_state.current_prediction.is_some() {
1327 return;
1328 };
1329
1330 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1331 let Some((active_buffer, snapshot, cursor_point)) = this
1332 .read_with(cx, |this, cx| {
1333 let project_state = this.projects.get(&project.entity_id())?;
1334 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1335 let snapshot = buffer.read(cx).snapshot();
1336
1337 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1338 return None;
1339 }
1340
1341 let cursor_point = position
1342 .map(|pos| pos.to_point(&snapshot))
1343 .unwrap_or_default();
1344
1345 Some((buffer, snapshot, cursor_point))
1346 })
1347 .log_err()
1348 .flatten()
1349 else {
1350 return Task::ready(anyhow::Ok(None));
1351 };
1352
1353 cx.spawn(async move |cx| {
1354 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1355 active_buffer,
1356 &snapshot,
1357 Default::default(),
1358 cursor_point,
1359 &project,
1360 cx,
1361 )
1362 .await?
1363 else {
1364 return anyhow::Ok(None);
1365 };
1366
1367 let Some(prediction_result) = this
1368 .update(cx, |this, cx| {
1369 this.request_prediction(
1370 &project,
1371 &jump_buffer,
1372 jump_position,
1373 PredictEditsRequestTrigger::Diagnostics,
1374 cx,
1375 )
1376 })?
1377 .await?
1378 else {
1379 return anyhow::Ok(None);
1380 };
1381
1382 this.update(cx, |this, cx| {
1383 Some((
1384 if this
1385 .get_or_init_project(&project, cx)
1386 .current_prediction
1387 .is_none()
1388 {
1389 prediction_result
1390 } else {
1391 EditPredictionResult {
1392 id: prediction_result.id,
1393 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1394 }
1395 },
1396 PredictionRequestedBy::DiagnosticsUpdate,
1397 ))
1398 })
1399 })
1400 });
1401 }
1402
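    /// Whether edit predictions are enabled for this file and, when a position is given, not
    /// disabled by `edit_predictions_disabled_in` for the language scope at that position.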
1403 fn predictions_enabled_at(
1404 snapshot: &BufferSnapshot,
1405 position: Option<language::Anchor>,
1406 cx: &App,
1407 ) -> bool {
1408 let file = snapshot.file();
1409 let all_settings = all_language_settings(file, cx);
1410 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1411 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1412 {
1413 return false;
1414 }
1415
1416 if let Some(last_position) = position {
1417 let settings = snapshot.settings_at(last_position, cx);
1418
1419 if !settings.edit_predictions_disabled_in.is_empty()
1420 && let Some(scope) = snapshot.language_scope_at(last_position)
1421 && let Some(scope_name) = scope.override_name()
1422 && settings
1423 .edit_predictions_disabled_in
1424 .iter()
1425 .any(|s| s == scope_name)
1426 {
1427 return false;
1428 }
1429 }
1430
1431 true
1432 }
1433
1434 #[cfg(not(test))]
1435 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1436 #[cfg(test)]
1437 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1438
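    /// Schedules a prediction refresh, delaying requests for the same throttle entity by
    /// `THROTTLE_TIMEOUT`. At most two pending predictions are kept per project; queueing a third
    /// cancels the newer of the two. A completed prediction only replaces the current one when
    /// `should_replace_prediction` approves; otherwise it is reported as rejected.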
1439 fn queue_prediction_refresh(
1440 &mut self,
1441 project: Entity<Project>,
1442 throttle_entity: EntityId,
1443 cx: &mut Context<Self>,
1444 do_refresh: impl FnOnce(
1445 WeakEntity<Self>,
1446 &mut AsyncApp,
1447 )
1448 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1449 + 'static,
1450 ) {
1451 let project_state = self.get_or_init_project(&project, cx);
1452 let pending_prediction_id = project_state.next_pending_prediction_id;
1453 project_state.next_pending_prediction_id += 1;
1454 let last_request = project_state.last_prediction_refresh;
1455
1456 let task = cx.spawn(async move |this, cx| {
1457 if let Some((last_entity, last_timestamp)) = last_request
1458 && throttle_entity == last_entity
1459 && let Some(timeout) =
1460 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1461 {
1462 cx.background_executor().timer(timeout).await;
1463 }
1464
1465 // If this task was cancelled before the throttle timeout expired,
1466 // do not perform a request.
1467 let mut is_cancelled = true;
1468 this.update(cx, |this, cx| {
1469 let project_state = this.get_or_init_project(&project, cx);
1470 if !project_state
1471 .cancelled_predictions
1472 .remove(&pending_prediction_id)
1473 {
1474 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1475 is_cancelled = false;
1476 }
1477 })
1478 .ok();
1479 if is_cancelled {
1480 return None;
1481 }
1482
1483 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1484 let new_prediction_id = new_prediction_result
1485 .as_ref()
1486 .map(|(prediction, _)| prediction.id.clone());
1487
1488 // When a prediction completes, remove it from the pending list, and cancel
1489 // any pending predictions that were enqueued before it.
1490 this.update(cx, |this, cx| {
1491 let project_state = this.get_or_init_project(&project, cx);
1492
1493 let is_cancelled = project_state
1494 .cancelled_predictions
1495 .remove(&pending_prediction_id);
1496
1497 let new_current_prediction = if !is_cancelled
1498 && let Some((prediction_result, requested_by)) = new_prediction_result
1499 {
1500 match prediction_result.prediction {
1501 Ok(prediction) => {
1502 let new_prediction = CurrentEditPrediction {
1503 requested_by,
1504 prediction,
1505 was_shown: false,
1506 };
1507
1508 if let Some(current_prediction) =
1509 project_state.current_prediction.as_ref()
1510 {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1512 {
1513 this.reject_current_prediction(
1514 EditPredictionRejectReason::Replaced,
1515 &project,
1516 );
1517
1518 Some(new_prediction)
1519 } else {
1520 this.reject_prediction(
1521 new_prediction.prediction.id,
1522 EditPredictionRejectReason::CurrentPreferred,
1523 false,
1524 );
1525 None
1526 }
1527 } else {
1528 Some(new_prediction)
1529 }
1530 }
1531 Err(reject_reason) => {
1532 this.reject_prediction(prediction_result.id, reject_reason, false);
1533 None
1534 }
1535 }
1536 } else {
1537 None
1538 };
1539
1540 let project_state = this.get_or_init_project(&project, cx);
1541
1542 if let Some(new_prediction) = new_current_prediction {
1543 project_state.current_prediction = Some(new_prediction);
1544 }
1545
1546 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1547 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1548 if pending_prediction.id == pending_prediction_id {
1549 pending_predictions.remove(ix);
1550 for pending_prediction in pending_predictions.drain(0..ix) {
1551 project_state.cancel_pending_prediction(pending_prediction, cx)
1552 }
1553 break;
1554 }
1555 }
1556 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1557 cx.notify();
1558 })
1559 .ok();
1560
1561 new_prediction_id
1562 });
1563
1564 if project_state.pending_predictions.len() <= 1 {
1565 project_state.pending_predictions.push(PendingPrediction {
1566 id: pending_prediction_id,
1567 task,
1568 });
1569 } else if project_state.pending_predictions.len() == 2 {
1570 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1571 project_state.pending_predictions.push(PendingPrediction {
1572 id: pending_prediction_id,
1573 task,
1574 });
1575 project_state.cancel_pending_prediction(pending_prediction, cx);
1576 }
1577 }
1578
1579 pub fn request_prediction(
1580 &mut self,
1581 project: &Entity<Project>,
1582 active_buffer: &Entity<Buffer>,
1583 position: language::Anchor,
1584 trigger: PredictEditsRequestTrigger,
1585 cx: &mut Context<Self>,
1586 ) -> Task<Result<Option<EditPredictionResult>>> {
1587 self.request_prediction_internal(
1588 project.clone(),
1589 active_buffer.clone(),
1590 position,
1591 trigger,
1592 cx.has_flag::<Zeta2FeatureFlag>(),
1593 cx,
1594 )
1595 }
1596
1597 fn request_prediction_internal(
1598 &mut self,
1599 project: Entity<Project>,
1600 active_buffer: Entity<Buffer>,
1601 position: language::Anchor,
1602 trigger: PredictEditsRequestTrigger,
1603 allow_jump: bool,
1604 cx: &mut Context<Self>,
1605 ) -> Task<Result<Option<EditPredictionResult>>> {
1606 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1607
1608 self.get_or_init_project(&project, cx);
1609 let project_state = self.projects.get(&project.entity_id()).unwrap();
1610 let stored_events = project_state.events(cx);
1611 let has_events = !stored_events.is_empty();
1612 let events: Vec<Arc<zeta_prompt::Event>> =
1613 stored_events.into_iter().map(|e| e.event).collect();
1614 let debug_tx = project_state.debug_tx.clone();
1615
1616 let snapshot = active_buffer.read(cx).snapshot();
1617 let cursor_point = position.to_point(&snapshot);
1618 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1619 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1620 let diagnostic_search_range =
1621 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1622
1623 let related_files = if self.use_context {
1624 self.context_for_project(&project, cx)
1625 } else {
1626 Vec::new().into()
1627 };
1628
1629 let inputs = EditPredictionModelInput {
1630 project: project.clone(),
1631 buffer: active_buffer.clone(),
1632 snapshot: snapshot.clone(),
1633 position,
1634 events,
1635 related_files,
1636 recent_paths: project_state.recent_paths.clone(),
1637 trigger,
1638 diagnostic_search_range: diagnostic_search_range.clone(),
1639 debug_tx,
1640 };
1641
1642 let can_collect_example = snapshot
1643 .file()
1644 .is_some_and(|file| self.can_collect_file(&project, file, cx))
1645 && self.can_collect_events(&inputs.events, cx);
1646
1647 if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1648 let events_for_capture =
1649 self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1650 if let Some(example_task) = capture_example::capture_example(
1651 project.clone(),
1652 active_buffer.clone(),
1653 position,
1654 events_for_capture,
1655 cx,
1656 ) {
1657 cx.spawn(async move |_this, _cx| {
1658 let example = example_task.await?;
1659 telemetry::event!("Edit Prediction Example Captured", example = example);
1660 anyhow::Ok(())
1661 })
1662 .detach_and_log_err(cx);
1663 }
1664 }
1665 let task = match self.edit_prediction_model {
1666 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1667 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1668 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1669 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1670 };
1671
1672 cx.spawn(async move |this, cx| {
1673 let prediction = task.await?;
1674
1675 if prediction.is_none() && allow_jump {
1676 let cursor_point = position.to_point(&snapshot);
1677 if has_events
1678 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1679 active_buffer.clone(),
1680 &snapshot,
1681 diagnostic_search_range,
1682 cursor_point,
1683 &project,
1684 cx,
1685 )
1686 .await?
1687 {
1688 return this
1689 .update(cx, |this, cx| {
1690 this.request_prediction_internal(
1691 project,
1692 jump_buffer,
1693 jump_position,
1694 trigger,
1695 false,
1696 cx,
1697 )
1698 })?
1699 .await;
1700 }
1701
1702 return anyhow::Ok(None);
1703 }
1704
1705 Ok(prediction)
1706 })
1707 }
1708
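    /// Finds a location to offer as a jump target: the diagnostic in the active buffer closest to
    /// the cursor but outside the range covered by the last request, or failing that, the most
    /// severe diagnostic in the buffer whose path shares the most leading components with the
    /// active buffer's path.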
1709 async fn next_diagnostic_location(
1710 active_buffer: Entity<Buffer>,
1711 active_buffer_snapshot: &BufferSnapshot,
1712 active_buffer_diagnostic_search_range: Range<Point>,
1713 active_buffer_cursor_point: Point,
1714 project: &Entity<Project>,
1715 cx: &mut AsyncApp,
1716 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1717 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1718 let mut jump_location = active_buffer_snapshot
1719 .diagnostic_groups(None)
1720 .into_iter()
1721 .filter_map(|(_, group)| {
1722 let range = &group.entries[group.primary_ix]
1723 .range
1724 .to_point(&active_buffer_snapshot);
1725 if range.overlaps(&active_buffer_diagnostic_search_range) {
1726 None
1727 } else {
1728 Some(range.start)
1729 }
1730 })
1731 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1732 .map(|position| {
1733 (
1734 active_buffer.clone(),
1735 active_buffer_snapshot.anchor_before(position),
1736 )
1737 });
1738
1739 if jump_location.is_none() {
1740 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1741 let file = buffer.file()?;
1742
1743 Some(ProjectPath {
1744 worktree_id: file.worktree_id(cx),
1745 path: file.path().clone(),
1746 })
1747 })?;
1748
1749 let buffer_task = project.update(cx, |project, cx| {
1750 let (path, _, _) = project
1751 .diagnostic_summaries(false, cx)
1752 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1753 .max_by_key(|(path, _, _)| {
1754 // find the buffer with errors that shares most parent directories
1755 path.path
1756 .components()
1757 .zip(
1758 active_buffer_path
1759 .as_ref()
1760 .map(|p| p.path.components())
1761 .unwrap_or_default(),
1762 )
1763 .take_while(|(a, b)| a == b)
1764 .count()
1765 })?;
1766
1767 Some(project.open_buffer(path, cx))
1768 })?;
1769
1770 if let Some(buffer_task) = buffer_task {
1771 let closest_buffer = buffer_task.await?;
1772
1773 jump_location = closest_buffer
1774 .read_with(cx, |buffer, _cx| {
1775 buffer
1776 .buffer_diagnostics(None)
1777 .into_iter()
1778 .min_by_key(|entry| entry.diagnostic.severity)
1779 .map(|entry| entry.range.start)
1780 })?
1781 .map(|position| (closest_buffer, position));
1782 }
1783 }
1784
1785 anyhow::Ok(jump_location)
1786 }
1787
1788 async fn send_raw_llm_request(
1789 request: open_ai::Request,
1790 client: Arc<Client>,
1791 llm_token: LlmApiToken,
1792 app_version: Version,
1793 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1794 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1795 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1796 let url = client
1797 .http_client()
1798 .build_zed_llm_url("/predict_edits/raw", &[])?;
1799
1800 #[cfg(feature = "cli-support")]
1801 let cache_key = if let Some(cache) = eval_cache {
1802 use collections::FxHasher;
1803 use std::hash::{Hash, Hasher};
1804
1805 let mut hasher = FxHasher::default();
1806 url.hash(&mut hasher);
1807 let request_str = serde_json::to_string_pretty(&request)?;
1808 request_str.hash(&mut hasher);
1809 let hash = hasher.finish();
1810
1811 let key = (eval_cache_kind, hash);
1812 if let Some(response_str) = cache.read(key) {
1813 return Ok((serde_json::from_str(&response_str)?, None));
1814 }
1815
1816 Some((cache, request_str, key))
1817 } else {
1818 None
1819 };
1820
1821 let (response, usage) = Self::send_api_request(
1822 |builder| {
1823 let req = builder
1824 .uri(url.as_ref())
1825 .body(serde_json::to_string(&request)?.into());
1826 Ok(req?)
1827 },
1828 client,
1829 llm_token,
1830 app_version,
1831 true,
1832 )
1833 .await?;
1834
1835 #[cfg(feature = "cli-support")]
1836 if let Some((cache, request, key)) = cache_key {
1837 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1838 }
1839
1840 Ok((response, usage))
1841 }
1842
1843 fn handle_api_response<T>(
1844 this: &WeakEntity<Self>,
1845 response: Result<(T, Option<EditPredictionUsage>)>,
1846 cx: &mut gpui::AsyncApp,
1847 ) -> Result<T> {
1848 match response {
1849 Ok((data, usage)) => {
1850 if let Some(usage) = usage {
1851 this.update(cx, |this, cx| {
1852 this.user_store.update(cx, |user_store, cx| {
1853 user_store.update_edit_prediction_usage(usage, cx);
1854 });
1855 })
1856 .ok();
1857 }
1858 Ok(data)
1859 }
1860 Err(err) => {
1861 if err.is::<ZedUpdateRequiredError>() {
1862 cx.update(|cx| {
1863 this.update(cx, |this, _cx| {
1864 this.update_required = true;
1865 })
1866 .ok();
1867
1868 let error_message: SharedString = err.to_string().into();
1869 show_app_notification(
1870 NotificationId::unique::<ZedUpdateRequiredError>(),
1871 cx,
1872 move |cx| {
1873 cx.new(|cx| {
1874 ErrorMessagePrompt::new(error_message.clone(), cx)
1875 .with_link_button("Update Zed", "https://zed.dev/releases")
1876 })
1877 },
1878 );
1879 })
1880 .ok();
1881 }
1882 Err(err)
1883 }
1884 }
1885 }
1886
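    /// Sends a POST request built by `build`, attaching the LLM token when available (required
    /// when `require_auth` is true). Retries once with a refreshed token if the server reports an
    /// expired token, and fails with `ZedUpdateRequiredError` when the server's minimum required
    /// version exceeds `app_version`.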
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add an Authorization header if we have a token.
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

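    /// Refreshes the related-excerpt context for `buffer` around `cursor_position`,
    /// if context retrieval is enabled.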
    pub fn refresh_context(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if self.use_context {
            self.get_or_init_project(project, cx)
                .context
                .update(cx, |store, cx| {
                    store.refresh(buffer.clone(), cursor_position, cx);
                });
        }
    }

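    /// Overrides the retrieved context with a fixed set of related files. Only
    /// available with the `cli-support` feature.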
    #[cfg(feature = "cli-support")]
    pub fn set_context_for_buffer(
        &mut self,
        project: &Entity<Project>,
        related_files: Vec<RelatedFile>,
        cx: &mut Context<Self>,
    ) {
        self.get_or_init_project(project, cx)
            .context
            .update(cx, |store, _| {
                store.set_related_files(related_files);
            });
    }

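    /// Returns whether `file` is a local, non-private file in a worktree whose
    /// license has been detected as open source.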
    fn is_file_open_source(
        &self,
        project: &Entity<Project>,
        file: &Arc<dyn File>,
        cx: &App,
    ) -> bool {
        if !file.is_local() || file.is_private() {
            return false;
        }
        let Some(project_state) = self.projects.get(&project.entity_id()) else {
            return false;
        };
        project_state
            .license_detection_watchers
            .get(&file.worktree_id(cx))
            .as_ref()
            .is_some_and(|watcher| watcher.is_project_open_source())
    }

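    /// A file may only be collected when data collection is enabled and the file is
    /// part of an open-source project.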
    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
    }

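    /// Events may only be collected when data collection is enabled and every event
    /// is a buffer change made in an open-source repository.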
    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
        if !self.data_collection_choice.is_enabled(cx) {
            return false;
        }
        events.iter().all(|event| {
            matches!(
                event.as_ref(),
                zeta_prompt::Event::BufferChange {
                    in_open_source_repo: true,
                    ..
                }
            )
        })
    }

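    /// Loads the persisted data collection choice from the key-value store, falling
    /// back to `NotAnswered` when unset or unrecognized.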
    fn load_data_collection_choice() -> DataCollectionChoice {
        let choice = KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .flatten();

        match choice.as_deref() {
            Some("true") => DataCollectionChoice::Enabled,
            Some("false") => DataCollectionChoice::Disabled,
            Some(_) => {
                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
                DataCollectionChoice::NotAnswered
            }
            None => DataCollectionChoice::NotAnswered,
        }
    }

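    /// Toggles the data collection choice and persists the new value.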
    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
        self.data_collection_choice = self.data_collection_choice.toggle();
        let new_choice = self.data_collection_choice;
        let is_enabled = new_choice.is_enabled(cx);
        db::write_and_log(cx, move || {
            KEY_VALUE_STORE.write_kvp(
                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                is_enabled.to_string(),
            )
        });
    }

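    /// Iterates over the predictions that have been shown to the user.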
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }

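    /// Records a rating for `prediction` and reports it, along with the prediction's
    /// inputs, output diff, and the user's feedback, via telemetry.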
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }

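    /// Recomputes whether context retrieval is enabled from the feature flag and the
    /// current language settings.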
    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
            && all_language_settings(None, cx).edit_predictions.use_context;
    }
}

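/// Returned when the server requires a newer version of Zed than the one in use.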
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}

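/// Key for a cached eval entry: the kind of request plus a hash of its serialized body.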
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

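/// Request/response cache used when running evals, so identical requests are not
/// re-sent to the API. A minimal in-memory sketch of an implementation (illustrative
/// only; `MemoryCache` is not part of this crate):
///
/// ```ignore
/// struct MemoryCache(std::sync::Mutex<Vec<(EvalCacheKey, String)>>);
///
/// impl EvalCache for MemoryCache {
///     fn read(&self, key: EvalCacheKey) -> Option<String> {
///         let entries = self.0.lock().unwrap();
///         entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
///     }
///
///     fn write(&self, key: EvalCacheKey, _input: &str, value: &str) {
///         self.0.lock().unwrap().push((key, value.to_string()));
///     }
/// }
/// ```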
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

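/// The user's answer to the edit prediction data collection prompt.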
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    pub fn is_enabled(self, cx: &App) -> bool {
        if cx.is_staff() {
            return true;
        }
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}

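/// Marker type tracking whether the edit prediction upsell has been dismissed.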
struct ZedPredictUpsell;

impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    fn dismissed() -> bool {
        // For backwards compatibility with older versions of Zed, treat the
        // upsell as dismissed if the user already completed the previous Edit
        // Prediction onboarding, which wrote a data collection choice to the
        // database when they clicked "Accept and Enable".
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}

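/// Returns true until the user has dismissed the edit prediction upsell.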
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}

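/// Registers workspace actions for opening the edit prediction onboarding modal and
/// resetting onboarding.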
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}