1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use db::kvp::{Dismissable, KEY_VALUE_STORE};
14use edit_prediction_context::EditPredictionExcerptOptions;
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use project::{Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41
42use std::ops::Range;
43use std::path::Path;
44use std::rc::Rc;
45use std::str::FromStr as _;
46use std::sync::{Arc, LazyLock};
47use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
48use std::{env, mem};
49use thiserror::Error;
50use util::{RangeExt as _, ResultExt as _};
51use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
52
53pub mod cursor_excerpt;
54pub mod example_spec;
55mod license_detection;
56pub mod mercury;
57mod onboarding_modal;
58pub mod open_ai_response;
59mod prediction;
60pub mod sweep_ai;
61
62pub mod udiff;
63
64mod capture_example;
65mod zed_edit_prediction_delegate;
66pub mod zeta1;
67pub mod zeta2;
68
69#[cfg(test)]
70mod edit_prediction_tests;
71
72use crate::capture_example::should_sample_edit_prediction_example_capture;
73use crate::license_detection::LicenseDetectionWatcher;
74use crate::mercury::Mercury;
75use crate::onboarding_modal::ZedPredictModal;
76pub use crate::prediction::EditPrediction;
77pub use crate::prediction::EditPredictionId;
78use crate::prediction::EditPredictionResult;
79pub use crate::sweep_ai::SweepAi;
80pub use capture_example::capture_example;
81pub use language_model::ApiKeyState;
82pub use telemetry_events::EditPredictionRating;
83pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
84
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Upper bound on the edit events tracked per project. The queue of finalized events is
/// trimmed to `EVENT_COUNT_MAX - 1` entries, leaving room for the in-progress `last_event`.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits to the same buffer whose row ranges overlap or lie within this many
/// rows of each other are coalesced into a single event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// When a coalesced edit arrives at least this long after the previous one, the snapshot
/// from before it is remembered as an editing-pause boundary (see `LastEvent::split_by_pause`).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// NOTE(review): presumably the key-value-store key under which the user's data
/// collection choice is persisted — usage is outside this section, confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// NOTE(review): presumably the debounce interval for batching rejected-prediction
/// reports — usage is outside this section, confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
101
/// Feature flag named `sweep-ai` (presumably gates the Sweep AI prediction provider —
/// where it is checked is outside this section).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag named `mercury` (presumably gates the Mercury prediction provider —
/// where it is checked is outside this section).
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
113
/// Initial `ZetaOptions` for a new store: cursor excerpts between 128 and 512 bytes, with
/// a target of half the excerpt placed before the cursor, and the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
122
/// Whether Zeta2 should target a local Ollama server. Enabled by setting
/// `ZED_ZETA2_OLLAMA` to any non-empty value; an empty value counts as unset.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));

/// Model identifier used for Zeta2 prediction requests.
///
/// Resolution order:
/// - `ZED_ZETA2_MODEL=zeta2-exp` selects the fine-tuned model hosted on Baseten;
/// - any other non-empty `ZED_ZETA2_MODEL` value is used verbatim;
/// - otherwise (unset, empty, or not valid Unicode) `qwen3-coder:30b` when
///   `USE_OLLAMA` is enabled, else the vanilla qwen3-coder deployment on Baseten.
///
/// An empty `ZED_ZETA2_MODEL` is treated as unset, matching how `USE_OLLAMA`
/// interprets `ZED_ZETA2_OLLAMA`, so an empty model id is never produced.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        // An empty override would otherwise be sent as an empty model id.
        Ok(model) if !model.is_empty() => model,
        _ if *USE_OLLAMA => "qwen3-coder:30b",
        _ => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
135
/// Feature flag named `zeta2`. Within this section it gates refreshing predictions on
/// project diagnostics updates (see `handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Opt staff into this flag.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// Feature flag named `edit-prediction-example-capture` (presumably gates capturing edit
/// prediction examples — where it is checked is outside this section).
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    // Opt staff into this flag.
    fn enabled_for_staff() -> bool {
        true
    }
}
155
/// Newtype storing the app-wide `EditPredictionStore` as a gpui global
/// (read by `EditPredictionStore::try_global` / installed by `EditPredictionStore::global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
160
/// Central edit prediction state: per-project edit history and context, the selected
/// model and provider clients, and the channel used to report rejected predictions.
pub struct EditPredictionStore {
    /// Client used for LLM token refreshes and by the rejected-predictions task.
    client: Arc<Client>,
    /// Source of edit prediction usage data (see `usage`).
    user_store: Entity<UserStore>,
    /// LLM API token, refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for as long as the store exists.
    _llm_token_subscription: Subscription,
    /// Tracking state for every registered project, keyed by project entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Toggled via `set_use_context`; `false` on construction.
    use_context: bool,
    /// Zeta prompt/excerpt options; `DEFAULT_OPTIONS` on construction.
    options: ZetaOptions,
    /// `false` on construction. NOTE(review): presumably set when the server reports that a
    /// newer client version is required — confirm.
    update_required: bool,
    /// Evaluation cache installed via `with_eval_cache` (cli-support builds only).
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Which model/provider serves predictions; `new` selects `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Sweep AI provider client (holds its API token).
    pub sweep_ai: SweepAi,
    /// Mercury provider client (holds its API token).
    pub mercury: Mercury,
    /// The user's data collection choice, loaded when the store is created.
    data_collection_choice: DataCollectionChoice,
    /// Sender feeding the background task (spawned in `new`) that reports rejections.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// NOTE(review): presumably predictions that were displayed to the user — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    /// Prediction endpoint override: `ZED_PREDICT_EDITS_URL`, or a local Ollama endpoint
    /// when `USE_OLLAMA` is enabled and no override is set.
    custom_predict_edits_url: Option<Arc<Url>>,
}
181
/// The model or provider used to produce edit predictions.
///
/// NOTE(review): `Default` yields `Zeta1`, but `EditPredictionStore::new` explicitly selects
/// `Zeta2` — confirm which default other callers rely on.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
190
/// Inputs handed to a prediction backend for a single prediction request.
pub struct EditPredictionModelInput {
    /// Project the request belongs to.
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` to predict against (presumably taken at request time).
    snapshot: BufferSnapshot,
    /// Anchor the prediction is requested at (presumably the cursor position).
    position: Anchor,
    /// Recent edit events for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files gathered as context.
    related_files: Arc<[RelatedFile]>,
    /// Recently active project paths (presumably most recent first, as kept in `ProjectState`).
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered this request.
    trigger: PredictEditsRequestTrigger,
    /// NOTE(review): presumably the row range searched for diagnostics — confirm.
    diagnostic_search_range: Range<Point>,
    /// The project's debug channel, if one was opened via `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Recent user actions recorded for the project.
    pub user_actions: Vec<UserActionRecord>,
}
204
/// Tunable options for Zeta predictions (see `DEFAULT_OPTIONS`).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How the excerpt around the cursor is sized.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format used for requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
210
/// Diagnostic events sent on a project's debug channel (opened via `debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when a project's related-excerpt store starts refreshing its context.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    /// When the event was emitted.
    pub timestamp: Instant,
    /// Left empty by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Emitted when a project's related-excerpt store finishes refreshing its context.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    /// When the event was emitted.
    pub timestamp: Instant,
    /// Label/value pairs for display (e.g. cache hit ratio, LSP latencies).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
246
/// Debug information attached to prediction requests (the v3 protocol type).
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Maximum number of `UserActionRecord`s kept per project; older records are evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded buffer change, summarized for prediction requests.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the changed buffer.
    pub buffer_id: EntityId,
    /// Row (`Point::row`) of `offset` in the post-change snapshot.
    pub line_number: u32,
    /// Offset in the post-change snapshot at the end of the last edit in the change.
    pub offset: usize,
    /// Wall-clock milliseconds since the Unix epoch (0 if the clock reads before the epoch).
    pub timestamp_epoch_ms: u64,
}

/// Coarse classification of a user action.
///
/// For recorded edits, the `*Char` variants are used when every edit in the change inserted
/// (or deleted) exactly one unit of offset, and the `*Selection` variants otherwise.
/// NOTE(review): offsets are byte-based, so a single multi-byte character appears to be
/// classified as a selection — confirm this is intended.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
268
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event as reported to the model.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
}

/// Edit-prediction tracking state for one project.
struct ProjectState {
    /// Finalized edit events, oldest first (bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<StoredEvent>,
    /// The event still being coalesced from incoming edits, if any.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// NOTE(review): presumably the prediction currently offered to the user — confirm.
    current_prediction: Option<CurrentEditPrediction>,
    /// NOTE(review): presumably the id handed to the next pending prediction — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug channel opened via `debug_info`, if any.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the buffer and time of the last prediction refresh — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Store gathering related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; an entry is removed when its worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recent user actions, oldest first (bounded by `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    /// Subscription to the project's events.
    _subscription: gpui::Subscription,
}
292
293impl ProjectState {
294 fn record_user_action(&mut self, action: UserActionRecord) {
295 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
296 self.user_actions.pop_front();
297 }
298 self.user_actions.push_back(action);
299 }
300
301 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
302 self.events
303 .iter()
304 .cloned()
305 .chain(
306 self.last_event
307 .as_ref()
308 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
309 )
310 .collect()
311 }
312
313 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
314 self.events
315 .iter()
316 .cloned()
317 .chain(self.last_event.as_ref().iter().flat_map(|event| {
318 let (one, two) = event.split_by_pause();
319 let one = one.finalize(&self.license_detection_watchers, cx);
320 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
321 one.into_iter().chain(two)
322 }))
323 .collect()
324 }
325
326 fn cancel_pending_prediction(
327 &mut self,
328 pending_prediction: PendingPrediction,
329 cx: &mut Context<EditPredictionStore>,
330 ) {
331 self.cancelled_predictions.insert(pending_prediction.id);
332
333 cx.spawn(async move |this, cx| {
334 let Some(prediction_id) = pending_prediction.task.await else {
335 return;
336 };
337
338 this.update(cx, |this, _cx| {
339 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
340 })
341 .ok();
342 })
343 .detach()
344 }
345
346 fn active_buffer(
347 &self,
348 project: &Entity<Project>,
349 cx: &App,
350 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
351 let project = project.read(cx);
352 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
353 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
354 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
355 Some((active_buffer, registered_buffer.last_position))
356 }
357}
358
/// A prediction that has been accepted as the project's current one, plus display state.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What caused the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// NOTE(review): presumably whether the prediction was displayed to the user — confirm.
    pub was_shown: bool,
    /// NOTE(review): presumably how it was displayed, if at all — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
366
367impl CurrentEditPrediction {
368 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
369 let Some(new_edits) = self
370 .prediction
371 .interpolate(&self.prediction.buffer.read(cx))
372 else {
373 return false;
374 };
375
376 if self.prediction.buffer != old_prediction.prediction.buffer {
377 return true;
378 }
379
380 let Some(old_edits) = old_prediction
381 .prediction
382 .interpolate(&old_prediction.prediction.buffer.read(cx))
383 else {
384 return true;
385 };
386
387 let requested_by_buffer_id = self.requested_by.buffer_id();
388
389 // This reduces the occurrence of UI thrash from replacing edits
390 //
391 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
392 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
393 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
394 && old_edits.len() == 1
395 && new_edits.len() == 1
396 {
397 let (old_range, old_text) = &old_edits[0];
398 let (new_range, new_text) = &new_edits[0];
399 new_range == old_range && new_text.starts_with(old_text.as_ref())
400 } else {
401 true
402 }
403 }
404}
405
406#[derive(Debug, Clone)]
407enum PredictionRequestedBy {
408 DiagnosticsUpdate,
409 Buffer(EntityId),
410}
411
412impl PredictionRequestedBy {
413 pub fn buffer_id(&self) -> Option<EntityId> {
414 match self {
415 PredictionRequestedBy::DiagnosticsUpdate => None,
416 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
417 }
418 }
419}
420
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier used to track cancellation (see `ProjectState::cancelled_predictions`).
    id: usize,
    /// Resolves to the produced prediction's id, or `None` when no prediction was produced.
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
/// NOTE(review): presumably `Local` when the prediction edits this buffer and `Jump` when it
/// points elsewhere — confirm against the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
433
/// Test-only convenience: both variants wrap an `EditPrediction`, so expose it directly.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
445
/// Per-buffer tracking state for a registered buffer.
struct RegisteredBuffer {
    /// The buffer's file as of the last processed change (`None` for untitled buffers).
    file: Option<Arc<dyn File>>,
    /// Snapshot as of the last processed change; a change is detected by comparing versions.
    snapshot: TextBufferSnapshot,
    /// NOTE(review): presumably the last cursor/request position in this buffer — confirm.
    last_position: Option<Anchor>,
    /// Subscriptions to the buffer's edit events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}

/// An edit event that is still being built up by coalescing nearby edits.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the first coalesced edit.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent coalesced edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range of the most recent edits, used to decide whether the next edit coalesces.
    edit_range: Option<Range<Anchor>>,
    /// Snapshot at the most recent editing pause (at least `LAST_CHANGE_GROUPING_TIME`
    /// between edits), used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    /// When the most recent coalesced edit happened.
    last_edit_time: Option<Instant>,
}
463
464impl LastEvent {
465 pub fn finalize(
466 &self,
467 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
468 cx: &App,
469 ) -> Option<StoredEvent> {
470 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
471 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
472
473 let in_open_source_repo =
474 [self.new_file.as_ref(), self.old_file.as_ref()]
475 .iter()
476 .all(|file| {
477 file.is_some_and(|file| {
478 license_detection_watchers
479 .get(&file.worktree_id(cx))
480 .is_some_and(|watcher| watcher.is_project_open_source())
481 })
482 });
483
484 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
485
486 if path == old_path && diff.is_empty() {
487 None
488 } else {
489 Some(StoredEvent {
490 event: Arc::new(zeta_prompt::Event::BufferChange {
491 old_path,
492 path,
493 diff,
494 in_open_source_repo,
495 // TODO: Actually detect if this edit was predicted or not
496 predicted: false,
497 }),
498 old_snapshot: self.old_snapshot.clone(),
499 })
500 }
501 }
502
503 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
504 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
505 return (self.clone(), None);
506 };
507
508 let before = LastEvent {
509 old_snapshot: self.old_snapshot.clone(),
510 new_snapshot: boundary_snapshot.clone(),
511 old_file: self.old_file.clone(),
512 new_file: self.new_file.clone(),
513 edit_range: None,
514 snapshot_after_last_editing_pause: None,
515 last_edit_time: self.last_edit_time,
516 };
517
518 let after = LastEvent {
519 old_snapshot: boundary_snapshot.clone(),
520 new_snapshot: self.new_snapshot.clone(),
521 old_file: self.old_file.clone(),
522 new_file: self.new_file.clone(),
523 edit_range: None,
524 snapshot_after_last_editing_pause: None,
525 last_edit_time: self.last_edit_time,
526 };
527
528 (before, Some(after))
529 }
530}
531
532pub(crate) fn compute_diff_between_snapshots(
533 old_snapshot: &TextBufferSnapshot,
534 new_snapshot: &TextBufferSnapshot,
535) -> Option<String> {
536 let edits: Vec<Edit<usize>> = new_snapshot
537 .edits_since::<usize>(&old_snapshot.version)
538 .collect();
539
540 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
541
542 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
543 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
544 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
545 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
546
547 const CONTEXT_LINES: u32 = 3;
548
549 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
550 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
551 let old_context_end_row =
552 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
553 let new_context_end_row =
554 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
555
556 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
557 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
558 let old_end_line_offset = old_snapshot
559 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
560 let new_end_line_offset = new_snapshot
561 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
562 let old_edit_range = old_start_line_offset..old_end_line_offset;
563 let new_edit_range = new_start_line_offset..new_end_line_offset;
564
565 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
566 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
567
568 let diff = language::unified_diff_with_offsets(
569 &old_region_text,
570 &new_region_text,
571 old_context_start_row,
572 new_context_start_row,
573 );
574
575 Some(diff)
576}
577
578fn buffer_path_with_id_fallback(
579 file: Option<&Arc<dyn File>>,
580 snapshot: &TextBufferSnapshot,
581 cx: &App,
582) -> Arc<Path> {
583 if let Some(file) = file {
584 file.full_path(cx).into()
585 } else {
586 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
587 }
588}
589
590impl EditPredictionStore {
591 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
592 cx.try_global::<EditPredictionStoreGlobal>()
593 .map(|global| global.0.clone())
594 }
595
596 pub fn global(
597 client: &Arc<Client>,
598 user_store: &Entity<UserStore>,
599 cx: &mut App,
600 ) -> Entity<Self> {
601 cx.try_global::<EditPredictionStoreGlobal>()
602 .map(|global| global.0.clone())
603 .unwrap_or_else(|| {
604 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
605 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
606 ep_store
607 })
608 }
609
    /// Creates a new store.
    ///
    /// Besides initializing its fields, this:
    /// - spawns a detached background task (`handle_rejected_predictions`) that consumes
    ///   rejections sent through `reject_predictions_tx`;
    /// - subscribes to the global `RefreshLlmTokenListener` and refreshes `llm_token` on each
    ///   event, logging refresh failures;
    /// - configures context retrieval immediately, again once feature flags are ready, and on
    ///   every `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections flow through an unbounded channel to a single background task, which
        // owns its own clones of the client and token.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Differs from `EditPredictionModel::default()`, which is `Zeta1`.
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // `ZED_PREDICT_EDITS_URL` overrides the endpoint (an unparsable value goes through
            // `log_err` and leaves `None`). If it is unset, `USE_OLLAMA` selects a local
            // Ollama chat-completions endpoint.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        // Configure again once feature flags are ready.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        // ...and whenever settings change.
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
695
696 #[cfg(test)]
697 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
698 self.custom_predict_edits_url = Some(url.into());
699 }
700
    /// Selects the model/provider used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
704
    /// Whether an API key is available for the Sweep AI provider.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
708
    /// Whether an API key is available for the Mercury provider.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
712
    /// Installs an evaluation cache (cli-support builds only), replacing any previous one.
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
717
    /// The current Zeta options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
721
    /// Replaces the Zeta options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
725
    /// Sets the `use_context` flag (which is `false` on construction).
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
729
730 pub fn clear_history(&mut self) {
731 for project_state in self.projects.values_mut() {
732 project_state.events.clear();
733 project_state.last_event.take();
734 }
735 }
736
737 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
738 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
739 project_state.events.clear();
740 project_state.last_event.take();
741 }
742 }
743
744 pub fn edit_history_for_project(
745 &self,
746 project: &Entity<Project>,
747 cx: &App,
748 ) -> Vec<StoredEvent> {
749 self.projects
750 .get(&project.entity_id())
751 .map(|project_state| project_state.events(cx))
752 .unwrap_or_default()
753 }
754
755 pub fn edit_history_for_project_with_pause_split_last_event(
756 &self,
757 project: &Entity<Project>,
758 cx: &App,
759 ) -> Vec<StoredEvent> {
760 self.projects
761 .get(&project.entity_id())
762 .map(|project_state| project_state.events_split_by_pause(cx))
763 .unwrap_or_default()
764 }
765
766 pub fn context_for_project<'a>(
767 &'a self,
768 project: &Entity<Project>,
769 cx: &'a App,
770 ) -> Arc<[RelatedFile]> {
771 self.projects
772 .get(&project.entity_id())
773 .map(|project| project.context.read(cx).related_files())
774 .unwrap_or_else(|| vec![].into())
775 }
776
777 pub fn context_for_project_with_buffers<'a>(
778 &'a self,
779 project: &Entity<Project>,
780 cx: &'a App,
781 ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
782 self.projects
783 .get(&project.entity_id())
784 .map(|project| project.context.read(cx).related_files_with_buffers())
785 }
786
787 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
788 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
789 self.user_store.read(cx).edit_prediction_usage()
790 } else {
791 None
792 }
793 }
794
    /// Starts tracking `project`, initializing its state if it isn't tracked yet.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
798
799 pub fn register_buffer(
800 &mut self,
801 buffer: &Entity<Buffer>,
802 project: &Entity<Project>,
803 cx: &mut Context<Self>,
804 ) {
805 let project_state = self.get_or_init_project(project, cx);
806 Self::register_buffer_impl(project_state, buffer, project, cx);
807 }
808
809 fn get_or_init_project(
810 &mut self,
811 project: &Entity<Project>,
812 cx: &mut Context<Self>,
813 ) -> &mut ProjectState {
814 let entity_id = project.entity_id();
815 self.projects
816 .entry(entity_id)
817 .or_insert_with(|| ProjectState {
818 context: {
819 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
820 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
821 this.handle_excerpt_store_event(entity_id, event);
822 })
823 .detach();
824 related_excerpt_store
825 },
826 events: VecDeque::new(),
827 last_event: None,
828 recent_paths: VecDeque::new(),
829 debug_tx: None,
830 registered_buffers: HashMap::default(),
831 current_prediction: None,
832 cancelled_predictions: HashSet::default(),
833 pending_predictions: ArrayVec::new(),
834 next_pending_prediction_id: 0,
835 last_prediction_refresh: None,
836 license_detection_watchers: HashMap::default(),
837 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
838 _subscription: cx.subscribe(&project, Self::handle_project_event),
839 })
840 }
841
    /// Stops tracking `project`, dropping all of its tracked state.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
845
846 fn handle_excerpt_store_event(
847 &mut self,
848 project_entity_id: EntityId,
849 event: &RelatedExcerptStoreEvent,
850 ) {
851 if let Some(project_state) = self.projects.get(&project_entity_id) {
852 if let Some(debug_tx) = project_state.debug_tx.clone() {
853 match event {
854 RelatedExcerptStoreEvent::StartedRefresh => {
855 debug_tx
856 .unbounded_send(DebugEvent::ContextRetrievalStarted(
857 ContextRetrievalStartedDebugEvent {
858 project_entity_id: project_entity_id,
859 timestamp: Instant::now(),
860 search_prompt: String::new(),
861 },
862 ))
863 .ok();
864 }
865 RelatedExcerptStoreEvent::FinishedRefresh {
866 cache_hit_count,
867 cache_miss_count,
868 mean_definition_latency,
869 max_definition_latency,
870 } => {
871 debug_tx
872 .unbounded_send(DebugEvent::ContextRetrievalFinished(
873 ContextRetrievalFinishedDebugEvent {
874 project_entity_id: project_entity_id,
875 timestamp: Instant::now(),
876 metadata: vec![
877 (
878 "Cache Hits",
879 format!(
880 "{}/{}",
881 cache_hit_count,
882 cache_hit_count + cache_miss_count
883 )
884 .into(),
885 ),
886 (
887 "Max LSP Time",
888 format!("{} ms", max_definition_latency.as_millis())
889 .into(),
890 ),
891 (
892 "Mean LSP Time",
893 format!("{} ms", mean_definition_latency.as_millis())
894 .into(),
895 ),
896 ],
897 },
898 ))
899 .ok();
900 }
901 }
902 }
903 }
904 }
905
906 pub fn debug_info(
907 &mut self,
908 project: &Entity<Project>,
909 cx: &mut Context<Self>,
910 ) -> mpsc::UnboundedReceiver<DebugEvent> {
911 let project_state = self.get_or_init_project(project, cx);
912 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
913 project_state.debug_tx = Some(debug_watch_tx);
914 debug_watch_rx
915 }
916
917 fn handle_project_event(
918 &mut self,
919 project: Entity<Project>,
920 event: &project::Event,
921 cx: &mut Context<Self>,
922 ) {
923 // TODO [zeta2] init with recent paths
924 match event {
925 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
926 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
927 return;
928 };
929 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
930 if let Some(path) = path {
931 if let Some(ix) = project_state
932 .recent_paths
933 .iter()
934 .position(|probe| probe == &path)
935 {
936 project_state.recent_paths.remove(ix);
937 }
938 project_state.recent_paths.push_front(path);
939 }
940 }
941 project::Event::DiagnosticsUpdated { .. } => {
942 if cx.has_flag::<Zeta2FeatureFlag>() {
943 self.refresh_prediction_from_diagnostics(project, cx);
944 }
945 }
946 _ => (),
947 }
948 }
949
    /// Ensures `buffer` is registered in `project_state` and returns its registration.
    ///
    /// If the buffer's file lives in one of `project`'s worktrees and that worktree has no
    /// `LicenseDetectionWatcher` yet, one is created and removed again when the worktree is
    /// released. A new registration records the buffer's current text snapshot and file, and
    /// subscribes to:
    /// - `BufferEvent::Edited`, which calls `report_changes_for_buffer` while the project is
    ///   still alive;
    /// - the buffer's release, which removes the registration.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Held weakly; edits are ignored once the project is gone.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer itself is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1015
    /// Records a change to `buffer` in the project's edit history.
    ///
    /// The buffer is registered with the project state if needed, and its current text
    /// snapshot is compared with the snapshot stored on that registration. If the version is
    /// unchanged this returns early. Otherwise:
    /// - the registration's file and snapshot are replaced with the current ones;
    /// - the edits since the old version are summarized into a single `UserActionRecord`
    ///   positioned at the end of the last edit;
    /// - the change is either coalesced into `project_state.last_event` (when it continues
    ///   directly from that event's snapshot of the same buffer and lies within
    ///   `CHANGE_GROUPING_LINE_SPAN` rows of its edit range), or the previous `last_event`
    ///   is finalized into `events` and a new `LastEvent` is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Aggregate statistics over every edit between the old and new versions.
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            // Span from the start of the first edit to the end of the latest one.
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Classify the change. When the total inserted (or deleted) length equals the
            // number of edits, it is treated as per-character typing (one unit per edit,
            // e.g. one per cursor); otherwise it counts as a selection-sized insert/delete.
            // NOTE(review): lengths are buffer offsets (presumably bytes), so a single
            // multi-byte character is classified as a selection; confirm this is intended.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Wall-clock time; falls back to 0 if the clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change picks up exactly where the last event left off in the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Coalesce when the two edit ranges overlap or are within
            // CHANGE_GROUPING_LINE_SPAN rows of each other.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If the user paused for at least LAST_CHANGE_GROUPING_TIME before this edit,
                // remember the snapshot from just before it so the event can be split there.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event into the bounded history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // NOTE(review): this keeps at most EVENT_COUNT_MAX - 1 finalized events,
                // presumably leaving room for the in-progress `last_event`; confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1139
1140 fn prediction_at(
1141 &mut self,
1142 buffer: &Entity<Buffer>,
1143 position: Option<language::Anchor>,
1144 project: &Entity<Project>,
1145 cx: &App,
1146 ) -> Option<BufferEditPrediction<'_>> {
1147 let project_state = self.projects.get_mut(&project.entity_id())?;
1148 if let Some(position) = position
1149 && let Some(buffer) = project_state
1150 .registered_buffers
1151 .get_mut(&buffer.entity_id())
1152 {
1153 buffer.last_position = Some(position);
1154 }
1155
1156 let CurrentEditPrediction {
1157 requested_by,
1158 prediction,
1159 ..
1160 } = project_state.current_prediction.as_ref()?;
1161
1162 if prediction.targets_buffer(buffer.read(cx)) {
1163 Some(BufferEditPrediction::Local { prediction })
1164 } else {
1165 let show_jump = match requested_by {
1166 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1167 requested_by_buffer_id == &buffer.entity_id()
1168 }
1169 PredictionRequestedBy::DiagnosticsUpdate => true,
1170 };
1171
1172 if show_jump {
1173 Some(BufferEditPrediction::Jump { prediction })
1174 } else {
1175 None
1176 }
1177 }
1178 }
1179
1180 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1181 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1182 return;
1183 };
1184
1185 let Some(current_prediction) = project_state.current_prediction.take() else {
1186 return;
1187 };
1188
1189 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1190 project_state.cancel_pending_prediction(pending_prediction, cx);
1191 }
1192
1193 match self.edit_prediction_model {
1194 EditPredictionModel::Sweep => {
1195 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1196 }
1197 EditPredictionModel::Mercury => {}
1198 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1199 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1200 }
1201 }
1202 }
1203
1204 async fn handle_rejected_predictions(
1205 rx: UnboundedReceiver<EditPredictionRejection>,
1206 client: Arc<Client>,
1207 llm_token: LlmApiToken,
1208 app_version: Version,
1209 background_executor: BackgroundExecutor,
1210 ) {
1211 let mut rx = std::pin::pin!(rx.peekable());
1212 let mut batched = Vec::new();
1213
1214 while let Some(rejection) = rx.next().await {
1215 batched.push(rejection);
1216
1217 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1218 select_biased! {
1219 next = rx.as_mut().peek().fuse() => {
1220 if next.is_some() {
1221 continue;
1222 }
1223 }
1224 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1225 }
1226 }
1227
1228 let url = client
1229 .http_client()
1230 .build_zed_llm_url("/predict_edits/reject", &[])
1231 .unwrap();
1232
1233 let flush_count = batched
1234 .len()
1235 // in case items have accumulated after failure
1236 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1237 let start = batched.len() - flush_count;
1238
1239 let body = RejectEditPredictionsBodyRef {
1240 rejections: &batched[start..],
1241 };
1242
1243 let result = Self::send_api_request::<()>(
1244 |builder| {
1245 let req = builder
1246 .uri(url.as_ref())
1247 .body(serde_json::to_string(&body)?.into());
1248 anyhow::Ok(req?)
1249 },
1250 client.clone(),
1251 llm_token.clone(),
1252 app_version.clone(),
1253 true,
1254 )
1255 .await;
1256
1257 if result.log_err().is_some() {
1258 batched.drain(start..);
1259 }
1260 }
1261 }
1262
1263 fn reject_current_prediction(
1264 &mut self,
1265 reason: EditPredictionRejectReason,
1266 project: &Entity<Project>,
1267 ) {
1268 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1269 project_state.pending_predictions.clear();
1270 if let Some(prediction) = project_state.current_prediction.take() {
1271 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1272 }
1273 };
1274 }
1275
1276 fn did_show_current_prediction(
1277 &mut self,
1278 project: &Entity<Project>,
1279 display_type: edit_prediction_types::SuggestionDisplayType,
1280 cx: &mut Context<Self>,
1281 ) {
1282 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1283 return;
1284 };
1285
1286 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1287 return;
1288 };
1289
1290 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1291 let previous_shown_with = current_prediction.shown_with;
1292
1293 if previous_shown_with.is_none() || !is_jump {
1294 current_prediction.shown_with = Some(display_type);
1295 }
1296
1297 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1298
1299 if is_first_non_jump_show {
1300 current_prediction.was_shown = true;
1301 }
1302
1303 let display_type_changed = previous_shown_with != Some(display_type);
1304
1305 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1306 sweep_ai::edit_prediction_shown(
1307 &self.sweep_ai,
1308 self.client.clone(),
1309 ¤t_prediction.prediction,
1310 display_type,
1311 cx,
1312 );
1313 }
1314
1315 if is_first_non_jump_show {
1316 self.shown_predictions
1317 .push_front(current_prediction.prediction.clone());
1318 if self.shown_predictions.len() > 50 {
1319 let completion = self.shown_predictions.pop_back().unwrap();
1320 self.rated_predictions.remove(&completion.id);
1321 }
1322 }
1323 }
1324
1325 fn reject_prediction(
1326 &mut self,
1327 prediction_id: EditPredictionId,
1328 reason: EditPredictionRejectReason,
1329 was_shown: bool,
1330 ) {
1331 match self.edit_prediction_model {
1332 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1333 if self.custom_predict_edits_url.is_some() {
1334 return;
1335 }
1336 }
1337 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1338 }
1339
1340 self.reject_predictions_tx
1341 .unbounded_send(EditPredictionRejection {
1342 request_id: prediction_id.to_string(),
1343 reason,
1344 was_shown,
1345 })
1346 .log_err();
1347 }
1348
1349 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1350 self.projects
1351 .get(&project.entity_id())
1352 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1353 }
1354
1355 pub fn refresh_prediction_from_buffer(
1356 &mut self,
1357 project: Entity<Project>,
1358 buffer: Entity<Buffer>,
1359 position: language::Anchor,
1360 cx: &mut Context<Self>,
1361 ) {
1362 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1363 let Some(request_task) = this
1364 .update(cx, |this, cx| {
1365 this.request_prediction(
1366 &project,
1367 &buffer,
1368 position,
1369 PredictEditsRequestTrigger::Other,
1370 cx,
1371 )
1372 })
1373 .log_err()
1374 else {
1375 return Task::ready(anyhow::Ok(None));
1376 };
1377
1378 cx.spawn(async move |_cx| {
1379 request_task.await.map(|prediction_result| {
1380 prediction_result.map(|prediction_result| {
1381 (
1382 prediction_result,
1383 PredictionRequestedBy::Buffer(buffer.entity_id()),
1384 )
1385 })
1386 })
1387 })
1388 })
1389 }
1390
    /// Queues a prediction request driven by a diagnostics update.
    ///
    /// Does nothing if the project has no state yet, or if it already has a current
    /// prediction (buffer-driven predictions take precedence). Otherwise it queues a refresh
    /// throttled on the project's entity id. The refresh:
    /// 1. reads the project's active buffer and cursor, and bails out if predictions are
    ///    disabled there;
    /// 2. asks `next_diagnostic_location` for a diagnostic to jump to, using a default
    ///    search range;
    /// 3. requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the result is
    /// turned into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Snapshot the active buffer and cursor; `None` means there is nothing to do.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to the default point.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-driven prediction may have arrived meanwhile; it wins.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1477
1478 fn predictions_enabled_at(
1479 snapshot: &BufferSnapshot,
1480 position: Option<language::Anchor>,
1481 cx: &App,
1482 ) -> bool {
1483 let file = snapshot.file();
1484 let all_settings = all_language_settings(file, cx);
1485 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1486 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1487 {
1488 return false;
1489 }
1490
1491 if let Some(last_position) = position {
1492 let settings = snapshot.settings_at(last_position, cx);
1493
1494 if !settings.edit_predictions_disabled_in.is_empty()
1495 && let Some(scope) = snapshot.language_scope_at(last_position)
1496 && let Some(scope_name) = scope.override_name()
1497 && settings
1498 .edit_predictions_disabled_in
1499 .iter()
1500 .any(|s| s == scope_name)
1501 {
1502 return false;
1503 }
1504 }
1505
1506 true
1507 }
1508
    /// Minimum delay between consecutive prediction refreshes for the same throttle
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Test builds disable throttling.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1513
1514 fn queue_prediction_refresh(
1515 &mut self,
1516 project: Entity<Project>,
1517 throttle_entity: EntityId,
1518 cx: &mut Context<Self>,
1519 do_refresh: impl FnOnce(
1520 WeakEntity<Self>,
1521 &mut AsyncApp,
1522 )
1523 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1524 + 'static,
1525 ) {
1526 let project_state = self.get_or_init_project(&project, cx);
1527 let pending_prediction_id = project_state.next_pending_prediction_id;
1528 project_state.next_pending_prediction_id += 1;
1529 let last_request = project_state.last_prediction_refresh;
1530
1531 let task = cx.spawn(async move |this, cx| {
1532 if let Some((last_entity, last_timestamp)) = last_request
1533 && throttle_entity == last_entity
1534 && let Some(timeout) =
1535 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1536 {
1537 cx.background_executor().timer(timeout).await;
1538 }
1539
1540 // If this task was cancelled before the throttle timeout expired,
1541 // do not perform a request.
1542 let mut is_cancelled = true;
1543 this.update(cx, |this, cx| {
1544 let project_state = this.get_or_init_project(&project, cx);
1545 if !project_state
1546 .cancelled_predictions
1547 .remove(&pending_prediction_id)
1548 {
1549 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1550 is_cancelled = false;
1551 }
1552 })
1553 .ok();
1554 if is_cancelled {
1555 return None;
1556 }
1557
1558 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1559 let new_prediction_id = new_prediction_result
1560 .as_ref()
1561 .map(|(prediction, _)| prediction.id.clone());
1562
1563 // When a prediction completes, remove it from the pending list, and cancel
1564 // any pending predictions that were enqueued before it.
1565 this.update(cx, |this, cx| {
1566 let project_state = this.get_or_init_project(&project, cx);
1567
1568 let is_cancelled = project_state
1569 .cancelled_predictions
1570 .remove(&pending_prediction_id);
1571
1572 let new_current_prediction = if !is_cancelled
1573 && let Some((prediction_result, requested_by)) = new_prediction_result
1574 {
1575 match prediction_result.prediction {
1576 Ok(prediction) => {
1577 let new_prediction = CurrentEditPrediction {
1578 requested_by,
1579 prediction,
1580 was_shown: false,
1581 shown_with: None,
1582 };
1583
1584 if let Some(current_prediction) =
1585 project_state.current_prediction.as_ref()
1586 {
1587 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1588 {
1589 this.reject_current_prediction(
1590 EditPredictionRejectReason::Replaced,
1591 &project,
1592 );
1593
1594 Some(new_prediction)
1595 } else {
1596 this.reject_prediction(
1597 new_prediction.prediction.id,
1598 EditPredictionRejectReason::CurrentPreferred,
1599 false,
1600 );
1601 None
1602 }
1603 } else {
1604 Some(new_prediction)
1605 }
1606 }
1607 Err(reject_reason) => {
1608 this.reject_prediction(prediction_result.id, reject_reason, false);
1609 None
1610 }
1611 }
1612 } else {
1613 None
1614 };
1615
1616 let project_state = this.get_or_init_project(&project, cx);
1617
1618 if let Some(new_prediction) = new_current_prediction {
1619 project_state.current_prediction = Some(new_prediction);
1620 }
1621
1622 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1623 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1624 if pending_prediction.id == pending_prediction_id {
1625 pending_predictions.remove(ix);
1626 for pending_prediction in pending_predictions.drain(0..ix) {
1627 project_state.cancel_pending_prediction(pending_prediction, cx)
1628 }
1629 break;
1630 }
1631 }
1632 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1633 cx.notify();
1634 })
1635 .ok();
1636
1637 new_prediction_id
1638 });
1639
1640 if project_state.pending_predictions.len() <= 1 {
1641 project_state.pending_predictions.push(PendingPrediction {
1642 id: pending_prediction_id,
1643 task,
1644 });
1645 } else if project_state.pending_predictions.len() == 2 {
1646 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1647 project_state.pending_predictions.push(PendingPrediction {
1648 id: pending_prediction_id,
1649 task,
1650 });
1651 project_state.cancel_pending_prediction(pending_prediction, cx);
1652 }
1653 }
1654
1655 pub fn request_prediction(
1656 &mut self,
1657 project: &Entity<Project>,
1658 active_buffer: &Entity<Buffer>,
1659 position: language::Anchor,
1660 trigger: PredictEditsRequestTrigger,
1661 cx: &mut Context<Self>,
1662 ) -> Task<Result<Option<EditPredictionResult>>> {
1663 self.request_prediction_internal(
1664 project.clone(),
1665 active_buffer.clone(),
1666 position,
1667 trigger,
1668 cx.has_flag::<Zeta2FeatureFlag>(),
1669 cx,
1670 )
1671 }
1672
    /// Builds model inputs for a prediction at `position` in `active_buffer` and dispatches
    /// the request to the configured model.
    ///
    /// The inputs include:
    /// - the project's stored edit events;
    /// - its recorded user actions, plus a synthetic `CursorMovement` when the last action
    ///   was in this buffer at a different offset;
    /// - related-file context when `use_context` is set;
    /// - a diagnostic search range of `DIAGNOSTIC_LINES_RANGE` rows on either side of the
    ///   cursor.
    ///
    /// When data collection is permitted and sampling selects this request, an example is
    /// also captured and sent as telemetry.
    ///
    /// If the model returns no prediction, `allow_jump` is true, and there are edit events,
    /// this looks for a diagnostic outside the search range and requests a prediction there
    /// once, with `allow_jump` set to false so it cannot chain further.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Ensure the project state exists, then borrow it immutably.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the cursor moved away from the last recorded action in this buffer, append a
        // synthetic cursor-movement action (only to this request's copy of the list).
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        // Rows [cursor - 20, cursor + 20) (clamped at 0) whose diagnostics are considered
        // close enough to the cursor; a jump only targets diagnostics outside this range.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // Example capture requires both the file and every event to be collectable.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction here: optionally retry once at a nearby diagnostic.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1808
    /// Finds a location to jump to for a diagnostic-driven prediction.
    ///
    /// The active buffer is searched first. Among diagnostic groups whose primary entry does
    /// not overlap `active_buffer_diagnostic_search_range`, it picks the one that starts
    /// closest (by row) to `active_buffer_cursor_point`.
    ///
    /// If there is none, it considers the other paths in the project's diagnostic summaries
    /// and opens the one sharing the most leading path components with the active buffer's
    /// path. Within that buffer it picks the diagnostic with the minimum `severity` value
    /// (presumably the most severe; confirm the ordering).
    ///
    /// Returns `Ok(None)` when no candidate is found. Errors from opening the other buffer
    /// are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file that has diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer has no file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1887
    /// Sends a raw completion request and returns the response plus any usage info.
    ///
    /// The request goes to `custom_url` when provided, otherwise to Zed's
    /// `/predict_edits/raw` endpoint. Authentication is required (see `send_api_request`).
    ///
    /// With the `cli-support` feature and an `eval_cache`, responses are cached under
    /// `(eval_cache_kind, hash(url, pretty-printed request))`. A cache hit returns the
    /// cached response with no usage info and skips the network call. A miss stores the
    /// pretty-printed response after a successful request.
    async fn send_raw_llm_request(
        request: RawCompletionRequest,
        client: Arc<Client>,
        custom_url: Option<Arc<Url>>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
        let url = if let Some(custom_url) = custom_url {
            custom_url.as_ref().clone()
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // The key covers both the target URL and the exact request payload.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1947
1948 fn handle_api_response<T>(
1949 this: &WeakEntity<Self>,
1950 response: Result<(T, Option<EditPredictionUsage>)>,
1951 cx: &mut gpui::AsyncApp,
1952 ) -> Result<T> {
1953 match response {
1954 Ok((data, usage)) => {
1955 if let Some(usage) = usage {
1956 this.update(cx, |this, cx| {
1957 this.user_store.update(cx, |user_store, cx| {
1958 user_store.update_edit_prediction_usage(usage, cx);
1959 });
1960 })
1961 .ok();
1962 }
1963 Ok(data)
1964 }
1965 Err(err) => {
1966 if err.is::<ZedUpdateRequiredError>() {
1967 cx.update(|cx| {
1968 this.update(cx, |this, _cx| {
1969 this.update_required = true;
1970 })
1971 .ok();
1972
1973 let error_message: SharedString = err.to_string().into();
1974 show_app_notification(
1975 NotificationId::unique::<ZedUpdateRequiredError>(),
1976 cx,
1977 move |cx| {
1978 cx.new(|cx| {
1979 ErrorMessagePrompt::new(error_message.clone(), cx)
1980 .with_link_button("Update Zed", "https://zed.dev/releases")
1981 })
1982 },
1983 );
1984 });
1985 }
1986 Err(err)
1987 }
1988 }
1989 }
1990
    /// POSTs a JSON request and deserializes a successful JSON response as `Res`.
    ///
    /// The request is produced by `build` from a builder preloaded with the method,
    /// `Content-Type`, Zed version, and (if a token is available) `Authorization` headers.
    /// The usage info returned alongside the response is parsed from the response headers
    /// when present.
    ///
    /// - With `require_auth`, failing to acquire an LLM token is an error. Otherwise the
    ///   request is sent without an `Authorization` header in that case.
    /// - If the server advertises a minimum required version above `app_version`, this
    ///   fails with `ZedUpdateRequiredError`.
    /// - If the token is reported as expired, it is refreshed and the request retried once.
    ///   This is why `build` is `Fn` and may be called twice.
    /// - Any other non-success status fails with the status and the response body.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Checked before the status so an outdated client is told to update even when
            // the request itself failed.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh it and retry exactly once.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2066
2067 pub fn refresh_context(
2068 &mut self,
2069 project: &Entity<Project>,
2070 buffer: &Entity<language::Buffer>,
2071 cursor_position: language::Anchor,
2072 cx: &mut Context<Self>,
2073 ) {
2074 if self.use_context {
2075 self.get_or_init_project(project, cx)
2076 .context
2077 .update(cx, |store, cx| {
2078 store.refresh(buffer.clone(), cursor_position, cx);
2079 });
2080 }
2081 }
2082
2083 #[cfg(feature = "cli-support")]
2084 pub fn set_context_for_buffer(
2085 &mut self,
2086 project: &Entity<Project>,
2087 related_files: Vec<RelatedFile>,
2088 cx: &mut Context<Self>,
2089 ) {
2090 self.get_or_init_project(project, cx)
2091 .context
2092 .update(cx, |store, _| {
2093 store.set_related_files(related_files);
2094 });
2095 }
2096
2097 fn is_file_open_source(
2098 &self,
2099 project: &Entity<Project>,
2100 file: &Arc<dyn File>,
2101 cx: &App,
2102 ) -> bool {
2103 if !file.is_local() || file.is_private() {
2104 return false;
2105 }
2106 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2107 return false;
2108 };
2109 project_state
2110 .license_detection_watchers
2111 .get(&file.worktree_id(cx))
2112 .as_ref()
2113 .is_some_and(|watcher| watcher.is_project_open_source())
2114 }
2115
2116 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2117 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2118 }
2119
2120 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2121 if !self.data_collection_choice.is_enabled(cx) {
2122 return false;
2123 }
2124 events.iter().all(|event| {
2125 matches!(
2126 event.as_ref(),
2127 zeta_prompt::Event::BufferChange {
2128 in_open_source_repo: true,
2129 ..
2130 }
2131 )
2132 })
2133 }
2134
2135 fn load_data_collection_choice() -> DataCollectionChoice {
2136 let choice = KEY_VALUE_STORE
2137 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2138 .log_err()
2139 .flatten();
2140
2141 match choice.as_deref() {
2142 Some("true") => DataCollectionChoice::Enabled,
2143 Some("false") => DataCollectionChoice::Disabled,
2144 Some(_) => {
2145 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2146 DataCollectionChoice::NotAnswered
2147 }
2148 None => DataCollectionChoice::NotAnswered,
2149 }
2150 }
2151
2152 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2153 self.data_collection_choice = self.data_collection_choice.toggle();
2154 let new_choice = self.data_collection_choice;
2155 let is_enabled = new_choice.is_enabled(cx);
2156 db::write_and_log(cx, move || {
2157 KEY_VALUE_STORE.write_kvp(
2158 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2159 is_enabled.to_string(),
2160 )
2161 });
2162 }
2163
2164 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2165 self.shown_predictions.iter()
2166 }
2167
2168 pub fn shown_completions_len(&self) -> usize {
2169 self.shown_predictions.len()
2170 }
2171
2172 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2173 self.rated_predictions.contains(id)
2174 }
2175
2176 pub fn rate_prediction(
2177 &mut self,
2178 prediction: &EditPrediction,
2179 rating: EditPredictionRating,
2180 feedback: String,
2181 cx: &mut Context<Self>,
2182 ) {
2183 self.rated_predictions.insert(prediction.id.clone());
2184 telemetry::event!(
2185 "Edit Prediction Rated",
2186 rating,
2187 inputs = prediction.inputs,
2188 output = prediction
2189 .edit_preview
2190 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2191 feedback
2192 );
2193 self.client.telemetry().flush_events().detach();
2194 cx.notify();
2195 }
2196
2197 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2198 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2199 && all_language_settings(None, cx).edit_predictions.use_context;
2200 }
2201}
2202
2203#[derive(Error, Debug)]
2204#[error(
2205 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2206)]
2207pub struct ZedUpdateRequiredError {
2208 minimum_version: Version,
2209}
2210
/// Key for [`EvalCache`] lookups.
///
/// It pairs the kind of cached entry with a `u64`. The `u64` is presumably a
/// hash of the entry's input; confirm at the call sites.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2213
/// The category of an eval cache entry; the first half of an `EvalCacheKey`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` because this type is used
/// as part of a cache key. This lets keys be stored directly in hashed
/// collections and satisfies clippy's `derive_partial_eq_without_eq`.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    /// Entry for retrieved context.
    Context,
    /// Entry for a search step.
    Search,
    /// Entry for a model prediction.
    Prediction,
}
2221
/// Formats a kind as its lowercase name: `context`, `search` or `prediction`.
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2232
/// Storage backend for cached eval results, keyed by [`EvalCacheKey`].
///
/// Implementors must be `Send + Sync`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Looks up the value stored under `key`, returning `None` when there is none.
    fn read(&self, key: EvalCacheKey) -> Option<String>;

    /// Stores `value` under `key`.
    ///
    /// `input` is presumably the request text that produced `value`, kept for
    /// inspection. Confirm this against the implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2238
/// The user's answer to the edit prediction data-collection prompt.
///
/// This is the stored preference only. `DataCollectionChoice::is_enabled`
/// gives the effective state, which also accounts for staff accounts.
/// `PartialEq`/`Eq` are derived so choices can be compared with `==` rather
/// than only via `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2245
2246impl DataCollectionChoice {
2247 pub fn is_enabled(self, cx: &App) -> bool {
2248 if cx.is_staff() {
2249 return true;
2250 }
2251 match self {
2252 Self::Enabled => true,
2253 Self::NotAnswered | Self::Disabled => false,
2254 }
2255 }
2256
2257 #[must_use]
2258 pub fn toggle(&self) -> DataCollectionChoice {
2259 match self {
2260 Self::Enabled => Self::Disabled,
2261 Self::Disabled => Self::Enabled,
2262 Self::NotAnswered => Self::Enabled,
2263 }
2264 }
2265}
2266
2267impl From<bool> for DataCollectionChoice {
2268 fn from(value: bool) -> Self {
2269 match value {
2270 true => DataCollectionChoice::Enabled,
2271 false => DataCollectionChoice::Disabled,
2272 }
2273 }
2274}
2275
/// Marker type whose `Dismissable` impl tracks whether the edit prediction
/// upsell has been dismissed.
struct ZedPredictUpsell;
2277
2278impl Dismissable for ZedPredictUpsell {
2279 const KEY: &'static str = "dismissed-edit-predict-upsell";
2280
2281 fn dismissed() -> bool {
2282 // To make this backwards compatible with older versions of Zed, we
2283 // check if the user has seen the previous Edit Prediction Onboarding
2284 // before, by checking the data collection choice which was written to
2285 // the database once the user clicked on "Accept and Enable"
2286 if KEY_VALUE_STORE
2287 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2288 .log_err()
2289 .is_some_and(|s| s.is_some())
2290 {
2291 return true;
2292 }
2293
2294 KEY_VALUE_STORE
2295 .read_kvp(Self::KEY)
2296 .log_err()
2297 .is_some_and(|s| s.is_some())
2298 }
2299}
2300
2301pub fn should_show_upsell_modal() -> bool {
2302 !ZedPredictUpsell::dismissed()
2303}
2304
2305pub fn init(cx: &mut App) {
2306 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2307 workspace.register_action(
2308 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2309 ZedPredictModal::toggle(
2310 workspace,
2311 workspace.user_store().clone(),
2312 workspace.client().clone(),
2313 window,
2314 cx,
2315 )
2316 },
2317 );
2318
2319 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2320 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2321 settings
2322 .project
2323 .all_languages
2324 .features
2325 .get_or_insert_default()
2326 .edit_prediction_provider = Some(EditPredictionProvider::None)
2327 });
2328 });
2329 })
2330 .detach();
2331}