use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
    AsyncReadExt as _, FutureExt as _, StreamExt as _,
    channel::mpsc::{self, UnboundedReceiver},
    select_biased,
};
use gpui::BackgroundExecutor;
use gpui::http_client::Url;
use gpui::{
    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
    http_client::{self, AsyncBody, Method},
    prelude::*,
};
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::{Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
use serde::de::DeserializeOwned;
use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use text::Edit;
use workspace::Workspace;

use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use std::{env, mem};
use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod cursor_excerpt;
pub mod example_spec;
mod license_detection;
pub mod mercury;
mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;

pub mod udiff;

mod capture_example;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;

#[cfg(test)]
mod edit_prediction_tests;

use crate::capture_example::should_sample_edit_prediction_example_capture;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use capture_example::capture_example;
pub use language_model::ApiKeyState;
pub use telemetry_events::EditPredictionRating;
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;

actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);

pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));

static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}

#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}

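/// Central store for edit prediction state, shared process-wide via a GPUI global.
///
/// Tracks per-project edit history and registered buffers, routes prediction requests to
/// the configured model backend, and reports accepted and rejected predictions.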
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    custom_predict_edits_url: Option<Arc<Url>>,
}

#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}

pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}

#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

const USER_ACTION_HISTORY_SIZE: usize = 16;

#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    pub timestamp_epoch_ms: u64,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    pub old_snapshot: TextBufferSnapshot,
}

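/// Per-project state tracked by [`EditPredictionStore`]: edit history, registered buffers,
/// the current and pending predictions, and context retrieval state.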
struct ProjectState {
    events: VecDeque<StoredEvent>,
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    user_actions: VecDeque<UserActionRecord>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    fn record_user_action(&mut self, action: UserActionRecord) {
        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
            self.user_actions.pop_front();
        }
        self.user_actions.push_back(action);
    }

    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }

    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}

impl PredictionRequestedBy {
    pub fn buffer_id(&self) -> Option<EntityId> {
        match self {
            PredictionRequestedBy::DiagnosticsUpdate => None,
            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
        }
    }
}

#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}

struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

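/// The most recent, still-coalescing buffer change. It is folded into the event history once
/// a new change can no longer be grouped with it.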
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}

impl LastEvent {
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    // TODO: Actually detect if this edit was predicted or not
                    predicted: false,
                }),
                old_snapshot: self.old_snapshot.clone(),
            })
        }
    }

    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            edit_range: None,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            edit_range: None,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}

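/// Computes a unified diff between two snapshots of the same buffer, restricted to the edited
/// region plus a few lines of surrounding context. Returns `None` if there are no edits.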
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}

fn buffer_path_with_id_fallback(
    file: Option<&Arc<dyn File>>,
    snapshot: &TextBufferSnapshot,
    cx: &App,
) -> Arc<Path> {
    if let Some(file) = file {
        file.full_path(cx).into()
    } else {
        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
    }
}

impl EditPredictionStore {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
                ep_store
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }

    #[cfg(test)]
    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
        self.custom_predict_edits_url = Some(url.into());
    }

    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }

    pub fn clear_history(&mut self) {
        for project_state in self.projects.values_mut() {
            project_state.events.clear();
            project_state.last_event.take();
        }
    }

    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.events.clear();
            project_state.last_event.take();
        }
    }

    pub fn edit_history_for_project(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Vec<StoredEvent> {
        self.projects
            .get(&project.entity_id())
            .map(|project_state| project_state.events(cx))
            .unwrap_or_default()
    }

    pub fn edit_history_for_project_with_pause_split_last_event(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Vec<StoredEvent> {
        self.projects
            .get(&project.entity_id())
            .map(|project_state| project_state.events_split_by_pause(cx))
            .unwrap_or_default()
    }

    pub fn context_for_project<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Arc<[RelatedFile]> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files())
            .unwrap_or_else(|| vec![].into())
    }

    pub fn context_for_project_with_buffers<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files_with_buffers())
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }

    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }

    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }

    fn handle_excerpt_store_event(
        &mut self,
        project_entity_id: EntityId,
        event: &RelatedExcerptStoreEvent,
    ) {
        if let Some(project_state) = self.projects.get(&project_entity_id) {
            if let Some(debug_tx) = project_state.debug_tx.clone() {
                match event {
                    RelatedExcerptStoreEvent::StartedRefresh => {
                        debug_tx
                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
                                ContextRetrievalStartedDebugEvent {
                                    project_entity_id,
                                    timestamp: Instant::now(),
                                    search_prompt: String::new(),
                                },
                            ))
                            .ok();
                    }
                    RelatedExcerptStoreEvent::FinishedRefresh {
                        cache_hit_count,
                        cache_miss_count,
                        mean_definition_latency,
                        max_definition_latency,
                    } => {
                        debug_tx
                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
                                ContextRetrievalFinishedDebugEvent {
                                    project_entity_id,
                                    timestamp: Instant::now(),
                                    metadata: vec![
                                        (
                                            "Cache Hits",
                                            format!(
                                                "{}/{}",
                                                cache_hit_count,
                                                cache_hit_count + cache_miss_count
                                            )
                                            .into(),
                                        ),
                                        (
                                            "Max LSP Time",
                                            format!("{} ms", max_definition_latency.as_millis())
                                                .into(),
                                        ),
                                        (
                                            "Mean LSP Time",
                                            format!("{} ms", mean_definition_latency.as_millis())
                                                .into(),
                                        ),
                                    ],
                                },
                            ))
                            .ok();
                    }
                }
            }
        }
    }

    pub fn debug_info(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<DebugEvent> {
        let project_state = self.get_or_init_project(project, cx);
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        project_state.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }

    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }

    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.take() else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {}
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                zeta2::edit_prediction_accepted(self, current_prediction, cx)
            }
        }
    }

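    /// Background task that batches rejection reports and flushes them to the
    /// `/predict_edits/reject` endpoint, debounced so rejections are not sent one at a time.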
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }

    fn reject_current_prediction(
        &mut self,
        reason: EditPredictionRejectReason,
        project: &Entity<Project>,
    ) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.pending_predictions.clear();
            if let Some(prediction) = project_state.current_prediction.take() {
                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
            }
        };
    }

    fn did_show_current_prediction(
        &mut self,
        project: &Entity<Project>,
        display_type: edit_prediction_types::SuggestionDisplayType,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
            return;
        };

        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
        let previous_shown_with = current_prediction.shown_with;

        if previous_shown_with.is_none() || !is_jump {
            current_prediction.shown_with = Some(display_type);
        }

        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;

        if is_first_non_jump_show {
            current_prediction.was_shown = true;
        }

        let display_type_changed = previous_shown_with != Some(display_type);

        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
            sweep_ai::edit_prediction_shown(
                &self.sweep_ai,
                self.client.clone(),
                &current_prediction.prediction,
                display_type,
                cx,
            );
        }

        if is_first_non_jump_show {
            self.shown_predictions
                .push_front(current_prediction.prediction.clone());
            if self.shown_predictions.len() > 50 {
                let completion = self.shown_predictions.pop_back().unwrap();
                self.rated_predictions.remove(&completion.id);
            }
        }
    }

    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                if self.custom_predict_edits_url.is_some() {
                    return;
                }
            }
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }

    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
        self.projects
            .get(&project.entity_id())
            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
    }

    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }

    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }

    fn predictions_enabled_at(
        snapshot: &BufferSnapshot,
        position: Option<language::Anchor>,
        cx: &App,
    ) -> bool {
        let file = snapshot.file();
        let all_settings = all_language_settings(file, cx);
        if !all_settings.show_edit_predictions(snapshot.language(), cx)
            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
        {
            return false;
        }

        if let Some(last_position) = position {
            let settings = snapshot.settings_at(last_position, cx);

            if !settings.edit_predictions_disabled_in.is_empty()
                && let Some(scope) = snapshot.language_scope_at(last_position)
                && let Some(scope_name) = scope.override_name()
                && settings
                    .edit_predictions_disabled_in
                    .iter()
                    .any(|s| s == scope_name)
            {
                return false;
            }
        }

        true
    }

    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;

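    /// Enqueues a prediction refresh, throttling repeated requests for the same entity and
    /// keeping at most two in-flight predictions per project; when a prediction completes,
    /// any pending predictions enqueued before it are cancelled and reported as rejected.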
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.request_prediction_internal(
            project.clone(),
            active_buffer.clone(),
            position,
            trigger,
            cx.has_flag::<Zeta2FeatureFlag>(),
            cx,
        )
    }

    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }

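    /// Finds the closest diagnostic outside the range already covered by the last request,
    /// preferring the active buffer and falling back to the buffer with diagnostics whose
    /// path shares the most parent directories with the active buffer.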
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/raw", &[])?;

        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

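    /// Unwraps an API response, recording any returned usage on the user store. If the
    /// error indicates the running Zed version is too old, marks the store as requiring an
    /// update and shows a notification linking to the releases page.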
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

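    /// Sends an authenticated request to the Zed LLM service and deserializes the JSON
    /// response. Retries once with a refreshed LLM token if the server reports an expired
    /// token, and fails with [`ZedUpdateRequiredError`] when the server requires a newer
    /// app version than the one running.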
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

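    /// Refreshes the related-excerpt context for the given buffer and cursor position, if
    /// context retrieval is enabled.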
    pub fn refresh_context(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if self.use_context {
            self.get_or_init_project(project, cx)
                .context
                .update(cx, |store, cx| {
                    store.refresh(buffer.clone(), cursor_position, cx);
                });
        }
    }

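    /// Overrides the related files used as context for predictions, bypassing retrieval
    /// (only compiled with the `cli-support` feature).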
    #[cfg(feature = "cli-support")]
    pub fn set_context_for_buffer(
        &mut self,
        project: &Entity<Project>,
        related_files: Vec<RelatedFile>,
        cx: &mut Context<Self>,
    ) {
        self.get_or_init_project(project, cx)
            .context
            .update(cx, |store, _| {
                store.set_related_files(related_files);
            });
    }

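    /// Returns whether the file is local, non-private, and belongs to a worktree whose
    /// license detection watcher considers the project open source.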
    fn is_file_open_source(
        &self,
        project: &Entity<Project>,
        file: &Arc<dyn File>,
        cx: &App,
    ) -> bool {
        if !file.is_local() || file.is_private() {
            return false;
        }
        let Some(project_state) = self.projects.get(&project.entity_id()) else {
            return false;
        };
        project_state
            .license_detection_watchers
            .get(&file.worktree_id(cx))
            .as_ref()
            .is_some_and(|watcher| watcher.is_project_open_source())
    }

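    /// Returns whether data collection is enabled and the file is eligible for collection.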
    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
    }

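    /// Returns whether data collection is enabled and every event is a buffer change made
    /// in an open source repository.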
    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
        if !self.data_collection_choice.is_enabled(cx) {
            return false;
        }
        events.iter().all(|event| {
            matches!(
                event.as_ref(),
                zeta_prompt::Event::BufferChange {
                    in_open_source_repo: true,
                    ..
                }
            )
        })
    }

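    /// Loads the persisted data collection choice from the key-value store, defaulting to
    /// `NotAnswered` when unset or unrecognized.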
    fn load_data_collection_choice() -> DataCollectionChoice {
        let choice = KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .flatten();

        match choice.as_deref() {
            Some("true") => DataCollectionChoice::Enabled,
            Some("false") => DataCollectionChoice::Disabled,
            Some(_) => {
                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
                DataCollectionChoice::NotAnswered
            }
            None => DataCollectionChoice::NotAnswered,
        }
    }

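    /// Toggles the data collection choice and persists the new value to the key-value
    /// store.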
    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
        self.data_collection_choice = self.data_collection_choice.toggle();
        let new_choice = self.data_collection_choice;
        let is_enabled = new_choice.is_enabled(cx);
        db::write_and_log(cx, move || {
            KEY_VALUE_STORE.write_kvp(
                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                is_enabled.to_string(),
            )
        });
    }

    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }

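    /// Records a user rating for a prediction and reports it via telemetry, along with the
    /// prediction inputs, the resulting unified diff, and any freeform feedback.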
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }

    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
            && all_language_settings(None, cx).edit_predictions.use_context;
    }
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}

#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

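/// The kind of request an [`EvalCacheKey`] refers to.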
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

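/// Cache of request/response pairs used by the eval tooling to reuse LLM responses across
/// runs instead of re-sending identical requests.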
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

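/// The user's choice about whether edit prediction data may be collected.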
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    pub fn is_enabled(self, cx: &App) -> bool {
        if cx.is_staff() {
            return true;
        }
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}

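/// Marker type tracking whether the edit prediction upsell has been dismissed.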
struct ZedPredictUpsell;

impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    fn dismissed() -> bool {
        // For backwards compatibility with older versions of Zed, also treat the upsell as
        // dismissed if the user went through the previous edit prediction onboarding, which
        // wrote the data collection choice to the database when they clicked "Accept and
        // Enable".
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}

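/// Returns whether the edit prediction upsell modal should still be shown.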
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}

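/// Registers workspace actions for edit predictions: opening the Zed Predict onboarding
/// modal and resetting the onboarding state by clearing the configured provider.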
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}