1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
7 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
9};
10use collections::{HashMap, HashSet};
11use db::kvp::{Dismissable, KEY_VALUE_STORE};
12use edit_prediction_context::EditPredictionExcerptOptions;
13use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
15use futures::{
16 AsyncReadExt as _, FutureExt as _, StreamExt as _,
17 channel::mpsc::{self, UnboundedReceiver},
18 select_biased,
19};
20use gpui::BackgroundExecutor;
21use gpui::http_client::Url;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use text::Edit;
38use workspace::Workspace;
39
40use std::ops::Range;
41use std::path::Path;
42use std::rc::Rc;
43use std::str::FromStr as _;
44use std::sync::{Arc, LazyLock};
45use std::time::{Duration, Instant};
46use std::{env, mem};
47use thiserror::Error;
48use util::{RangeExt as _, ResultExt as _};
49use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
50
51pub mod cursor_excerpt;
52pub mod example_spec;
53mod license_detection;
54pub mod mercury;
55mod onboarding_modal;
56pub mod open_ai_response;
57mod prediction;
58pub mod sweep_ai;
59
60pub mod udiff;
61
62mod capture_example;
63mod zed_edit_prediction_delegate;
64pub mod zeta1;
65pub mod zeta2;
66
67#[cfg(test)]
68mod edit_prediction_tests;
69
70use crate::capture_example::should_sample_edit_prediction_example_capture;
71use crate::license_detection::LicenseDetectionWatcher;
72use crate::mercury::Mercury;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76use crate::prediction::EditPredictionResult;
77pub use crate::sweep_ai::SweepAi;
78pub use capture_example::capture_example;
79pub use language_model::ApiKeyState;
80pub use telemetry_events::EditPredictionRating;
81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
82
// Registers the `edit_prediction` action namespace with GPUI.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
92
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits whose end rows are within this many lines of the previous edit are
/// coalesced into the same `LastEvent` (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// A pause of at least this long between edits records a split point inside
/// the in-progress event (`snapshot_after_last_editing_pause`).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value-store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably used to batch rejection reports in
// `handle_rejected_predictions` — confirm in that function's body.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
99
100pub struct SweepFeatureFlag;
101
102impl FeatureFlag for SweepFeatureFlag {
103 const NAME: &str = "sweep-ai";
104}
105
106pub struct MercuryFeatureFlag;
107
108impl FeatureFlag for MercuryFeatureFlag {
109 const NAME: &str = "mercury";
110}
111
/// Default zeta options: excerpt size limits for cursor context and the
/// default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
120
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes zeta2
/// requests to a local Ollama server (see model id and URL defaults below).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
123
/// Model id used for edit-prediction requests. Overridable via
/// `ZED_ZETA2_MODEL`; otherwise depends on whether Ollama is in use.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
133
/// Feature flag gating the zeta2 edit-prediction model; enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
143
/// Feature flag gating example capture for edit predictions; enabled for staff.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}
153
/// Newtype wrapper making the singleton [`EditPredictionStore`] entity
/// available through GPUI's global registry.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
158
/// App-wide store coordinating edit predictions: tracks per-project edit
/// history, the current/pending predictions, and reports rejections upstream.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports the client version
    // is too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Feeds the background task that batches rejection reports to the server.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Endpoint override from `ZED_PREDICT_EDITS_URL`, Ollama mode, or tests.
    custom_predict_edits_url: Option<Arc<Url>>,
}
179
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
188
/// Everything a model backend needs to produce a prediction: the buffer and
/// cursor position, recent edit events, related context, and trigger metadata.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // When present, lifecycle events are emitted for the debug view.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
201
/// Tunable parameters for zeta predictions: excerpt sizing and prompt format.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
207
/// Lifecycle events streamed to debug listeners attached via `debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
215
/// Emitted when the related-excerpt store begins a context refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
222
/// Emitted when a context refresh completes; `metadata` carries display-ready
/// label/value pairs (cache-hit counts, LSP latencies).
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}
229
/// Emitted when a prediction request begins for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
236
/// Emitted when a prediction request finishes; `model_output` is the raw
/// model response, when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
243
244pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
245
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    // Snapshot of the buffer before this event's edits were applied.
    pub old_snapshot: TextBufferSnapshot,
}
252
/// Per-project tracking state owned by [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit events, oldest first, capped by `EVENT_COUNT_MAX`.
    events: VecDeque<StoredEvent>,
    // The still-coalescing most recent event; finalized lazily.
    last_event: Option<LastEvent>,
    // Most-recently-active project paths, most recent at the front.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled before completing.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
268
269impl ProjectState {
270 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
271 self.events
272 .iter()
273 .cloned()
274 .chain(
275 self.last_event
276 .as_ref()
277 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
278 )
279 .collect()
280 }
281
282 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
283 self.events
284 .iter()
285 .cloned()
286 .chain(self.last_event.as_ref().iter().flat_map(|event| {
287 let (one, two) = event.split_by_pause();
288 let one = one.finalize(&self.license_detection_watchers, cx);
289 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
290 one.into_iter().chain(two)
291 }))
292 .collect()
293 }
294
295 fn cancel_pending_prediction(
296 &mut self,
297 pending_prediction: PendingPrediction,
298 cx: &mut Context<EditPredictionStore>,
299 ) {
300 self.cancelled_predictions.insert(pending_prediction.id);
301
302 cx.spawn(async move |this, cx| {
303 let Some(prediction_id) = pending_prediction.task.await else {
304 return;
305 };
306
307 this.update(cx, |this, _cx| {
308 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
309 })
310 .ok();
311 })
312 .detach()
313 }
314
315 fn active_buffer(
316 &self,
317 project: &Entity<Project>,
318 cx: &App,
319 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
320 let project = project.read(cx);
321 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
322 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
323 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
324 Some((active_buffer, registered_buffer.last_position))
325 }
326}
327
/// The prediction currently offered to the user for a project, plus display
/// bookkeeping used for telemetry on accept/reject.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
335
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly arrived prediction) should replace
    /// `old_prediction`. A new prediction that no longer interpolates against
    /// its buffer is never shown; an old one that no longer interpolates is
    /// always replaced. Otherwise, keep the old prediction only when the new
    /// one is a same-range extension of it (to reduce UI churn).
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always win.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
374
/// What triggered a prediction request: a diagnostics update, or typing in a
/// specific buffer.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
380
381impl PredictionRequestedBy {
382 pub fn buffer_id(&self) -> Option<EntityId> {
383 match self {
384 PredictionRequestedBy::DiagnosticsUpdate => None,
385 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
386 }
387 }
388}
389
/// An in-flight prediction request; the task resolves to the prediction's id
/// once the request completes (or `None` if it produced nothing).
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
395
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // The prediction targets this buffer directly.
    Local { prediction: &'a EditPrediction },
    // The prediction targets another buffer; shown as a "jump" affordance.
    Jump { prediction: &'a EditPrediction },
}
402
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants wrap the same reference type, so an or-pattern extracts it.
    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
414
/// Tracking state for a buffer registered with the store: its last observed
/// file/snapshot (used to diff on edit) and last known cursor position.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    // Edit subscription + release observer that unregisters the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
421
/// The most recent, still-coalescing edit event: snapshots before/after the
/// grouped edits, plus bookkeeping used to decide whether the next edit joins
/// this event or starts a new one.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Anchor at the end of the most recent edit; used for proximity grouping.
    end_edit_anchor: Option<Anchor>,
    // Snapshot taken when an editing pause elapsed mid-event; enables
    // splitting this event at the pause (see `split_by_pause`).
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
432
impl LastEvent {
    /// Converts this in-progress event into a [`StoredEvent`] holding a
    /// unified diff between its snapshots. Returns `None` when there is
    /// nothing to record (no rename and an empty diff, or no edits at all).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open-source only if both the old and new file belong to worktrees
        // whose license watcher says the project is open source.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    // TODO: Actually detect if this edit was predicted or not
                    predicted: false,
                }),
                old_snapshot: self.old_snapshot.clone(),
            })
        }
    }

    /// Splits this event at the recorded editing pause, if any: the first
    /// half covers edits up to the pause, the second half covers the rest.
    /// Without a recorded pause, returns a clone of `self` and `None`.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            end_edit_anchor: self.end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            end_edit_anchor: self.end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
500
/// Produces a unified diff of the region edited between two snapshots of the
/// same buffer, with three lines of surrounding context. Returns `None` when
/// no edits occurred between the snapshots.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // `zip` yields `None` when there are no edits (first/last both absent).
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    // Row span covering all edits, in old and new coordinates.
    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Widen by CONTEXT_LINES rows on both sides, clamped to buffer bounds.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Convert the row spans back to offset ranges at line granularity.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Row offsets let the diff report line numbers relative to the buffer.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}
546
547fn buffer_path_with_id_fallback(
548 file: Option<&Arc<dyn File>>,
549 snapshot: &TextBufferSnapshot,
550 cx: &App,
551) -> Arc<Path> {
552 if let Some(file) = file {
553 file.full_path(cx).into()
554 } else {
555 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
556 }
557}
558
559impl EditPredictionStore {
560 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
561 cx.try_global::<EditPredictionStoreGlobal>()
562 .map(|global| global.0.clone())
563 }
564
565 pub fn global(
566 client: &Arc<Client>,
567 user_store: &Entity<UserStore>,
568 cx: &mut App,
569 ) -> Entity<Self> {
570 cx.try_global::<EditPredictionStoreGlobal>()
571 .map(|global| global.0.clone())
572 .unwrap_or_else(|| {
573 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
574 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
575 ep_store
576 })
577 }
578
    /// Creates the store: spawns the background task that batches rejection
    /// reports, subscribes to LLM token refreshes and settings changes, and
    /// resolves any endpoint override from the environment.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are sent through this channel and reported to the server
        // by a long-lived background task.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the LLM token whenever the listener fires.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // Endpoint override: explicit env URL wins; otherwise Ollama mode
            // points at the local server; otherwise use the default endpoint.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        // Re-run context-retrieval configuration now, when feature flags
        // arrive, and whenever settings change.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
664
    /// Test-only override for the predict-edits endpoint.
    #[cfg(test)]
    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
        self.custom_predict_edits_url = Some(url.into());
    }
669
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
673
    /// Whether a Sweep API token is currently configured.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
677
    /// Whether a Mercury API token is currently configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
681
    /// Installs an eval cache (CLI-support builds only).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
686
    /// The current zeta options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
690
    /// Replaces the zeta options wholesale.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
694
    /// Toggles whether related-file context is included in requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
698
699 pub fn clear_history(&mut self) {
700 for project_state in self.projects.values_mut() {
701 project_state.events.clear();
702 project_state.last_event.take();
703 }
704 }
705
706 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
707 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
708 project_state.events.clear();
709 project_state.last_event.take();
710 }
711 }
712
713 pub fn edit_history_for_project(
714 &self,
715 project: &Entity<Project>,
716 cx: &App,
717 ) -> Vec<StoredEvent> {
718 self.projects
719 .get(&project.entity_id())
720 .map(|project_state| project_state.events(cx))
721 .unwrap_or_default()
722 }
723
724 pub fn edit_history_for_project_with_pause_split_last_event(
725 &self,
726 project: &Entity<Project>,
727 cx: &App,
728 ) -> Vec<StoredEvent> {
729 self.projects
730 .get(&project.entity_id())
731 .map(|project_state| project_state.events_split_by_pause(cx))
732 .unwrap_or_default()
733 }
734
735 pub fn context_for_project<'a>(
736 &'a self,
737 project: &Entity<Project>,
738 cx: &'a App,
739 ) -> Arc<[RelatedFile]> {
740 self.projects
741 .get(&project.entity_id())
742 .map(|project| project.context.read(cx).related_files())
743 .unwrap_or_else(|| vec![].into())
744 }
745
    /// The project's related files paired with their live buffers, or `None`
    /// if the project is untracked.
    pub fn context_for_project_with_buffers<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files_with_buffers())
    }
755
756 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
757 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
758 self.user_store.read(cx).edit_prediction_usage()
759 } else {
760 None
761 }
762 }
763
    /// Ensures per-project state exists so this project's edits are tracked.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
767
    /// Registers a buffer for change tracking within its project (initializing
    /// project state if needed).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
777
    /// Returns the project's tracking state, creating it on first use. Setup
    /// wires a [`RelatedExcerptStore`] (forwarding its events to any debug
    /// listener) and subscribes to project events.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
809
    /// Drops all state tracked for the given project.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
813
814 fn handle_excerpt_store_event(
815 &mut self,
816 project_entity_id: EntityId,
817 event: &RelatedExcerptStoreEvent,
818 ) {
819 if let Some(project_state) = self.projects.get(&project_entity_id) {
820 if let Some(debug_tx) = project_state.debug_tx.clone() {
821 match event {
822 RelatedExcerptStoreEvent::StartedRefresh => {
823 debug_tx
824 .unbounded_send(DebugEvent::ContextRetrievalStarted(
825 ContextRetrievalStartedDebugEvent {
826 project_entity_id: project_entity_id,
827 timestamp: Instant::now(),
828 search_prompt: String::new(),
829 },
830 ))
831 .ok();
832 }
833 RelatedExcerptStoreEvent::FinishedRefresh {
834 cache_hit_count,
835 cache_miss_count,
836 mean_definition_latency,
837 max_definition_latency,
838 } => {
839 debug_tx
840 .unbounded_send(DebugEvent::ContextRetrievalFinished(
841 ContextRetrievalFinishedDebugEvent {
842 project_entity_id: project_entity_id,
843 timestamp: Instant::now(),
844 metadata: vec![
845 (
846 "Cache Hits",
847 format!(
848 "{}/{}",
849 cache_hit_count,
850 cache_hit_count + cache_miss_count
851 )
852 .into(),
853 ),
854 (
855 "Max LSP Time",
856 format!("{} ms", max_definition_latency.as_millis())
857 .into(),
858 ),
859 (
860 "Mean LSP Time",
861 format!("{} ms", mean_definition_latency.as_millis())
862 .into(),
863 ),
864 ],
865 },
866 ))
867 .ok();
868 }
869 }
870 }
871 }
872 }
873
    /// Attaches a debug listener to the project and returns the event stream.
    /// Replaces any previously attached listener for that project.
    pub fn debug_info(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<DebugEvent> {
        let project_state = self.get_or_init_project(project, cx);
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        project_state.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
884
    /// Handles project events: keeps `recent_paths` ordered by recency, and
    /// refreshes predictions on diagnostics updates (zeta2 flag only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any earlier entry
                    // so the deque stays duplicate-free.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
917
    /// Registers (or returns the existing registration for) a buffer:
    /// installs a license watcher for the buffer's worktree, subscribes to
    /// buffer edits, and arranges cleanup when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree; dropped when the worktree is
        // released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Record every edit into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
983
    /// Records a buffer edit into the project's event history. Consecutive
    /// edits to the same buffer within `CHANGE_GROUPING_LINE_SPAN` lines are
    /// coalesced into the in-progress `last_event`; otherwise the previous
    /// event is finalized and a new one begins.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // No version change means no edits to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only when this snapshot directly follows the one the
            // in-progress event ends at...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and the new edit lands close (in rows) to the previous one.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused long enough, remember the pre-edit
                // snapshot so the event can later be split at the pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Bound the history; the in-progress event counts toward the cap.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1059
    /// Returns the project's current prediction from this buffer's point of
    /// view: `Local` when it targets this buffer, `Jump` when it targets
    /// another buffer but should be surfaced here. Also records `position` as
    /// the buffer's last known cursor position.
    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the prediction,
            // or from any buffer when diagnostics triggered it.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
1099
    /// Accepts the current prediction for `project`: takes it out of the
    /// project state, cancels all still-pending prediction requests, and
    /// notifies the active model's backend of the acceptance.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.take() else {
            return;
        };

        // Accepting supersedes any in-flight requests.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            // Mercury has no acceptance feedback channel.
            EditPredictionModel::Mercury => {}
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                zeta2::edit_prediction_accepted(self, current_prediction, cx)
            }
        }
    }
1123
    /// Background loop that batches rejected-prediction reports and flushes
    /// them to the `/predict_edits/reject` endpoint.
    ///
    /// Rejections arrive on `rx`. To avoid one HTTP request per rejection,
    /// the loop accumulates entries until half of the per-request cap is
    /// reached, no further input is immediately available, or the debounce
    /// timer fires. Entries are removed from `batched` only after a
    /// successful send, so failed flushes are retried with later batches.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    // More input already queued: keep accumulating.
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    // Debounce elapsed with no new input: flush what we have.
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush at most the newest MAX_... entries per request.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Drop flushed entries only on success; on failure they remain
            // queued and are retried together with the next batch.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1182
1183 fn reject_current_prediction(
1184 &mut self,
1185 reason: EditPredictionRejectReason,
1186 project: &Entity<Project>,
1187 ) {
1188 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1189 project_state.pending_predictions.clear();
1190 if let Some(prediction) = project_state.current_prediction.take() {
1191 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1192 }
1193 };
1194 }
1195
1196 fn did_show_current_prediction(
1197 &mut self,
1198 project: &Entity<Project>,
1199 display_type: edit_prediction_types::SuggestionDisplayType,
1200 cx: &mut Context<Self>,
1201 ) {
1202 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1203 return;
1204 };
1205
1206 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1207 return;
1208 };
1209
1210 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1211 let previous_shown_with = current_prediction.shown_with;
1212
1213 if previous_shown_with.is_none() || !is_jump {
1214 current_prediction.shown_with = Some(display_type);
1215 }
1216
1217 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1218
1219 if is_first_non_jump_show {
1220 current_prediction.was_shown = true;
1221 }
1222
1223 let display_type_changed = previous_shown_with != Some(display_type);
1224
1225 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1226 sweep_ai::edit_prediction_shown(
1227 &self.sweep_ai,
1228 self.client.clone(),
1229 ¤t_prediction.prediction,
1230 display_type,
1231 cx,
1232 );
1233 }
1234
1235 if is_first_non_jump_show {
1236 self.shown_predictions
1237 .push_front(current_prediction.prediction.clone());
1238 if self.shown_predictions.len() > 50 {
1239 let completion = self.shown_predictions.pop_back().unwrap();
1240 self.rated_predictions.remove(&completion.id);
1241 }
1242 }
1243 }
1244
1245 fn reject_prediction(
1246 &mut self,
1247 prediction_id: EditPredictionId,
1248 reason: EditPredictionRejectReason,
1249 was_shown: bool,
1250 ) {
1251 match self.edit_prediction_model {
1252 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1253 if self.custom_predict_edits_url.is_some() {
1254 return;
1255 }
1256 }
1257 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1258 }
1259
1260 self.reject_predictions_tx
1261 .unbounded_send(EditPredictionRejection {
1262 request_id: prediction_id.to_string(),
1263 reason,
1264 was_shown,
1265 })
1266 .log_err();
1267 }
1268
1269 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1270 self.projects
1271 .get(&project.entity_id())
1272 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1273 }
1274
    /// Queues a (throttled) prediction refresh triggered by cursor activity in
    /// `buffer` at `position`. The buffer's entity id is used as the throttle
    /// key, so rapid edits in the same buffer are coalesced.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                // The store was dropped; nothing to do.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with its originating buffer so `prediction_at`
                // can decide whether to surface a jump affordance.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1310
    /// Queues a prediction refresh in response to a diagnostics update.
    ///
    /// Does nothing when a prediction already exists (buffer-triggered
    /// predictions are preferred). Otherwise, locates the nearest relevant
    /// diagnostic from the active buffer's cursor and requests a prediction
    /// there. If a buffer-triggered prediction lands first, the diagnostics
    /// result is reported as rejected with `CurrentPreferred`.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        // Throttle on the project entity id, since this isn't tied to one buffer.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // Empty search range: any diagnostic qualifies as a jump target.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A prediction may have arrived while we awaited; in
                        // that case mark ours rejected in favor of the current.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1397
1398 fn predictions_enabled_at(
1399 snapshot: &BufferSnapshot,
1400 position: Option<language::Anchor>,
1401 cx: &App,
1402 ) -> bool {
1403 let file = snapshot.file();
1404 let all_settings = all_language_settings(file, cx);
1405 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1406 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1407 {
1408 return false;
1409 }
1410
1411 if let Some(last_position) = position {
1412 let settings = snapshot.settings_at(last_position, cx);
1413
1414 if !settings.edit_predictions_disabled_in.is_empty()
1415 && let Some(scope) = snapshot.language_scope_at(last_position)
1416 && let Some(scope_name) = scope.override_name()
1417 && settings
1418 .edit_predictions_disabled_in
1419 .iter()
1420 .any(|s| s == scope_name)
1421 {
1422 return false;
1423 }
1424 }
1425
1426 true
1427 }
1428
    /// Minimum spacing between prediction refreshes triggered by the same entity.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Tests run unthrottled so they don't have to wait on real timers.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1433
1434 fn queue_prediction_refresh(
1435 &mut self,
1436 project: Entity<Project>,
1437 throttle_entity: EntityId,
1438 cx: &mut Context<Self>,
1439 do_refresh: impl FnOnce(
1440 WeakEntity<Self>,
1441 &mut AsyncApp,
1442 )
1443 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1444 + 'static,
1445 ) {
1446 let project_state = self.get_or_init_project(&project, cx);
1447 let pending_prediction_id = project_state.next_pending_prediction_id;
1448 project_state.next_pending_prediction_id += 1;
1449 let last_request = project_state.last_prediction_refresh;
1450
1451 let task = cx.spawn(async move |this, cx| {
1452 if let Some((last_entity, last_timestamp)) = last_request
1453 && throttle_entity == last_entity
1454 && let Some(timeout) =
1455 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1456 {
1457 cx.background_executor().timer(timeout).await;
1458 }
1459
1460 // If this task was cancelled before the throttle timeout expired,
1461 // do not perform a request.
1462 let mut is_cancelled = true;
1463 this.update(cx, |this, cx| {
1464 let project_state = this.get_or_init_project(&project, cx);
1465 if !project_state
1466 .cancelled_predictions
1467 .remove(&pending_prediction_id)
1468 {
1469 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1470 is_cancelled = false;
1471 }
1472 })
1473 .ok();
1474 if is_cancelled {
1475 return None;
1476 }
1477
1478 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1479 let new_prediction_id = new_prediction_result
1480 .as_ref()
1481 .map(|(prediction, _)| prediction.id.clone());
1482
1483 // When a prediction completes, remove it from the pending list, and cancel
1484 // any pending predictions that were enqueued before it.
1485 this.update(cx, |this, cx| {
1486 let project_state = this.get_or_init_project(&project, cx);
1487
1488 let is_cancelled = project_state
1489 .cancelled_predictions
1490 .remove(&pending_prediction_id);
1491
1492 let new_current_prediction = if !is_cancelled
1493 && let Some((prediction_result, requested_by)) = new_prediction_result
1494 {
1495 match prediction_result.prediction {
1496 Ok(prediction) => {
1497 let new_prediction = CurrentEditPrediction {
1498 requested_by,
1499 prediction,
1500 was_shown: false,
1501 shown_with: None,
1502 };
1503
1504 if let Some(current_prediction) =
1505 project_state.current_prediction.as_ref()
1506 {
1507 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1508 {
1509 this.reject_current_prediction(
1510 EditPredictionRejectReason::Replaced,
1511 &project,
1512 );
1513
1514 Some(new_prediction)
1515 } else {
1516 this.reject_prediction(
1517 new_prediction.prediction.id,
1518 EditPredictionRejectReason::CurrentPreferred,
1519 false,
1520 );
1521 None
1522 }
1523 } else {
1524 Some(new_prediction)
1525 }
1526 }
1527 Err(reject_reason) => {
1528 this.reject_prediction(prediction_result.id, reject_reason, false);
1529 None
1530 }
1531 }
1532 } else {
1533 None
1534 };
1535
1536 let project_state = this.get_or_init_project(&project, cx);
1537
1538 if let Some(new_prediction) = new_current_prediction {
1539 project_state.current_prediction = Some(new_prediction);
1540 }
1541
1542 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1543 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1544 if pending_prediction.id == pending_prediction_id {
1545 pending_predictions.remove(ix);
1546 for pending_prediction in pending_predictions.drain(0..ix) {
1547 project_state.cancel_pending_prediction(pending_prediction, cx)
1548 }
1549 break;
1550 }
1551 }
1552 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1553 cx.notify();
1554 })
1555 .ok();
1556
1557 new_prediction_id
1558 });
1559
1560 if project_state.pending_predictions.len() <= 1 {
1561 project_state.pending_predictions.push(PendingPrediction {
1562 id: pending_prediction_id,
1563 task,
1564 });
1565 } else if project_state.pending_predictions.len() == 2 {
1566 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1567 project_state.pending_predictions.push(PendingPrediction {
1568 id: pending_prediction_id,
1569 task,
1570 });
1571 project_state.cancel_pending_prediction(pending_prediction, cx);
1572 }
1573 }
1574
    /// Requests an edit prediction for `position` in `active_buffer`.
    ///
    /// Jumping to a nearby diagnostic when the model returns no prediction is
    /// gated on the Zeta2 feature flag.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.request_prediction_internal(
            project.clone(),
            active_buffer.clone(),
            position,
            trigger,
            cx.has_flag::<Zeta2FeatureFlag>(),
            cx,
        )
    }
1592
    /// Builds the model input for a prediction request, optionally captures a
    /// training example, dispatches to the configured model backend, and —
    /// when `allow_jump` is set and no prediction was produced — retries once
    /// at the nearest out-of-range diagnostic location.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor are considered part
        // of the current request's context.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        // Example capture requires both the file and every recorded event to
        // come from open-source, collection-enabled sources.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                cx,
            ) {
                // Fire-and-forget: capture failures are only logged.
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Retry once at a diagnostic outside the original search
                // range; `allow_jump: false` below prevents further recursion.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1704
    /// Finds the next diagnostic location worth jumping to.
    ///
    /// First looks in the active buffer for the diagnostic closest to the
    /// cursor (by row distance) that lies outside
    /// `active_buffer_diagnostic_search_range`. Failing that, picks another
    /// buffer in the project with diagnostics — preferring the path sharing
    /// the most leading directory components with the active buffer — opens
    /// it, and returns its most severe diagnostic's start.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request's context.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // NOTE(review): min_by_key on severity appears to pick the
                // most severe entry (lower value = more severe) — confirm
                // against the severity ordering.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1783
    /// Sends a raw OpenAI-style completion request to the
    /// `/predict_edits/raw` endpoint.
    ///
    /// With the `cli-support` feature, responses can be served from / written
    /// to an [`EvalCache`] keyed by a hash of the URL and serialized request,
    /// so evals are reproducible without network access.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/raw", &[])?;

        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key: hash of endpoint URL + pretty-printed request JSON.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: skip the network entirely (no usage header available).
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        // Populate the cache on a miss so the next identical request is served locally.
        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1838
    /// Post-processes an API response: records usage data in the user store on
    /// success, and on a [`ZedUpdateRequiredError`] flags the store and shows
    /// an app notification prompting the user to update Zed.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    // Best-effort: the store may already have been dropped.
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                // The error is propagated even after being surfaced to the user.
                Err(err)
            }
        }
    }
1882
    /// Sends an authenticated JSON request to the LLM API and deserializes the
    /// response body into `Res`.
    ///
    /// `build` receives a builder pre-populated with method/headers and must
    /// attach the URI and body. When `require_auth` is false, a missing token
    /// is tolerated and the request is sent unauthenticated.
    ///
    /// Enforces the server's minimum-version header (failing with
    /// [`ZedUpdateRequiredError`]), and retries exactly once after refreshing
    /// an expired LLM token.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Server-driven forced upgrade: refuse to proceed on old clients.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired mid-session: refresh once and loop to retry.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1958
1959 pub fn refresh_context(
1960 &mut self,
1961 project: &Entity<Project>,
1962 buffer: &Entity<language::Buffer>,
1963 cursor_position: language::Anchor,
1964 cx: &mut Context<Self>,
1965 ) {
1966 if self.use_context {
1967 self.get_or_init_project(project, cx)
1968 .context
1969 .update(cx, |store, cx| {
1970 store.refresh(buffer.clone(), cursor_position, cx);
1971 });
1972 }
1973 }
1974
1975 #[cfg(feature = "cli-support")]
1976 pub fn set_context_for_buffer(
1977 &mut self,
1978 project: &Entity<Project>,
1979 related_files: Vec<RelatedFile>,
1980 cx: &mut Context<Self>,
1981 ) {
1982 self.get_or_init_project(project, cx)
1983 .context
1984 .update(cx, |store, _| {
1985 store.set_related_files(related_files);
1986 });
1987 }
1988
1989 fn is_file_open_source(
1990 &self,
1991 project: &Entity<Project>,
1992 file: &Arc<dyn File>,
1993 cx: &App,
1994 ) -> bool {
1995 if !file.is_local() || file.is_private() {
1996 return false;
1997 }
1998 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1999 return false;
2000 };
2001 project_state
2002 .license_detection_watchers
2003 .get(&file.worktree_id(cx))
2004 .as_ref()
2005 .is_some_and(|watcher| watcher.is_project_open_source())
2006 }
2007
    /// Whether data collection is both enabled by the user and permitted for
    /// this file (local, non-private, open-source worktree).
    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
    }
2011
2012 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2013 if !self.data_collection_choice.is_enabled(cx) {
2014 return false;
2015 }
2016 events.iter().all(|event| {
2017 matches!(
2018 event.as_ref(),
2019 zeta_prompt::Event::BufferChange {
2020 in_open_source_repo: true,
2021 ..
2022 }
2023 )
2024 })
2025 }
2026
2027 fn load_data_collection_choice() -> DataCollectionChoice {
2028 let choice = KEY_VALUE_STORE
2029 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2030 .log_err()
2031 .flatten();
2032
2033 match choice.as_deref() {
2034 Some("true") => DataCollectionChoice::Enabled,
2035 Some("false") => DataCollectionChoice::Disabled,
2036 Some(_) => {
2037 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2038 DataCollectionChoice::NotAnswered
2039 }
2040 None => DataCollectionChoice::NotAnswered,
2041 }
2042 }
2043
    /// Flips the data-collection choice and persists the new value ("true" /
    /// "false") to the key-value store asynchronously.
    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
        self.data_collection_choice = self.data_collection_choice.toggle();
        let new_choice = self.data_collection_choice;
        // Note: persisted value reflects `is_enabled` (which includes the
        // staff override), not the raw enum variant.
        let is_enabled = new_choice.is_enabled(cx);
        db::write_and_log(cx, move || {
            KEY_VALUE_STORE.write_kvp(
                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                is_enabled.to_string(),
            )
        });
    }
2055
    /// Iterates over recently shown predictions, newest first.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2059
    /// Number of predictions currently retained in the shown-history buffer.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2063
    /// Whether the user has already submitted a rating for this prediction.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2067
    /// Records a user rating for `prediction`, emits the telemetry event
    /// (including the prediction's inputs and a unified diff of its output),
    /// and flushes telemetry immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
            feedback
        );
        // Flush eagerly so ratings aren't lost if the app exits soon after.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2088
    /// Recomputes whether context retrieval is active: requires both the Zeta2
    /// feature flag and the `use_context` edit-prediction setting.
    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
            && all_language_settings(None, cx).edit_predictions.use_context;
    }
2093}
2094
/// Error raised when the server's minimum-version header indicates this Zed
/// build is too old to use edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the server will accept.
    minimum_version: Version,
}
2102
2103#[cfg(feature = "cli-support")]
2104pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2105
2106#[cfg(feature = "cli-support")]
2107#[derive(Debug, Clone, Copy, PartialEq)]
2108pub enum EvalCacheEntryKind {
2109 Context,
2110 Search,
2111 Prediction,
2112}
2113
2114#[cfg(feature = "cli-support")]
2115impl std::fmt::Display for EvalCacheEntryKind {
2116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2117 match self {
2118 EvalCacheEntryKind::Search => write!(f, "search"),
2119 EvalCacheEntryKind::Context => write!(f, "context"),
2120 EvalCacheEntryKind::Prediction => write!(f, "prediction"),
2121 }
2122 }
2123}
2124
2125#[cfg(feature = "cli-support")]
2126pub trait EvalCache: Send + Sync {
2127 fn read(&self, key: EvalCacheKey) -> Option<String>;
2128 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
2129}
2130
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // Prompt has not been answered yet.
    NotAnswered,
    Enabled,
    Disabled,
}
2137
2138impl DataCollectionChoice {
2139 pub fn is_enabled(self, cx: &App) -> bool {
2140 if cx.is_staff() {
2141 return true;
2142 }
2143 match self {
2144 Self::Enabled => true,
2145 Self::NotAnswered | Self::Disabled => false,
2146 }
2147 }
2148
2149 #[must_use]
2150 pub fn toggle(&self) -> DataCollectionChoice {
2151 match self {
2152 Self::Enabled => Self::Disabled,
2153 Self::Disabled => Self::Enabled,
2154 Self::NotAnswered => Self::Enabled,
2155 }
2156 }
2157}
2158
2159impl From<bool> for DataCollectionChoice {
2160 fn from(value: bool) -> Self {
2161 match value {
2162 true => DataCollectionChoice::Enabled,
2163 false => DataCollectionChoice::Disabled,
2164 }
2165 }
2166}
2167
/// Marker type tracking dismissal of the edit-prediction upsell.
struct ZedPredictUpsell;

impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    /// Dismissed if either the legacy onboarding was completed (a data
    /// collection choice exists in the database) or the upsell's own
    /// dismissal key has been written.
    fn dismissed() -> bool {
        // To make this backwards compatible with older versions of Zed, we
        // check if the user has seen the previous Edit Prediction Onboarding
        // before, by checking the data collection choice which was written to
        // the database once the user clicked on "Accept and Enable"
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}
2192
/// Whether the edit-prediction upsell modal should still be shown (i.e. the
/// user hasn't dismissed it or completed the legacy onboarding).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2196
/// Registers workspace actions for edit predictions: opening the onboarding
/// modal, and resetting the onboarding by clearing the configured provider.
pub fn init(cx: &mut App) {
    // Register on every workspace as it is created.
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            // Resetting sets the provider back to None so onboarding re-triggers.
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}