1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::Copilot;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use project::{Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41use zeta_prompt::ZetaPromptInput;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59mod onboarding_modal;
60pub mod open_ai_response;
61mod prediction;
62pub mod sweep_ai;
63
64pub mod udiff;
65
66mod capture_example;
67mod zed_edit_prediction_delegate;
68pub mod zeta1;
69pub mod zeta2;
70
71#[cfg(test)]
72mod edit_prediction_tests;
73
74use crate::capture_example::should_sample_edit_prediction_example_capture;
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::mercury::Mercury;
77use crate::onboarding_modal::ZedPredictModal;
78pub use crate::prediction::EditPrediction;
79pub use crate::prediction::EditPredictionId;
80use crate::prediction::EditPredictionResult;
81pub use crate::sweep_ai::SweepAi;
82pub use capture_example::capture_example;
83pub use language_model::ApiKeyState;
84pub use telemetry_events::EditPredictionRating;
85pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
86
// Actions registered under the `edit_prediction` namespace. The `///` lines inside the
// invocation are part of the macro's input, so they are kept exactly as written.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
96
/// Maximum number of events to track.
///
/// This counts both the finalized per-project events and the in-progress one.
const EVENT_COUNT_MAX: usize = 6;
/// Edits to the same buffer whose row ranges are at most this many rows apart are
/// coalesced into a single history event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// A gap of at least this long between two coalesced edits is recorded as an
/// editing pause, which is where `LastEvent::split_by_pause` splits.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);
/// Key under which the data collection choice is presumably persisted
/// (its readers are not visible in this section).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions.
///
/// NOTE(review): the use of this value is not visible in this section.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_millis(15_000);
103
/// Feature flag named `"sweep-ai"`.
///
/// Presumably gates the Sweep AI prediction provider; its checks are not visible here.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
109
/// Feature flag named `"mercury"`.
///
/// Presumably gates the Mercury prediction provider; its checks are not visible here.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
115
/// Model identifier override, read once from the `ZED_ZETA_MODEL` environment variable.
///
/// This is `None` when the variable is unset or is not valid Unicode.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var_os("ZED_ZETA_MODEL").and_then(|raw_value| raw_value.into_string().ok())
});
118
119pub struct Zeta2FeatureFlag;
120
121impl FeatureFlag for Zeta2FeatureFlag {
122 const NAME: &'static str = "zeta2";
123
124 fn enabled_for_staff() -> bool {
125 true
126 }
127}
128
129pub struct EditPredictionExampleCaptureFeatureFlag;
130
131impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
132 const NAME: &'static str = "edit-prediction-example-capture";
133
134 fn enabled_for_staff() -> bool {
135 true
136 }
137}
138
/// gpui global holding the app-wide `EditPredictionStore` entity.
///
/// It is installed by `EditPredictionStore::global` and read by `try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
143
/// Central edit prediction state.
///
/// It holds the per-project edit history and context, the selected prediction model,
/// and resources shared across projects: the LLM token, rejection reporting and the
/// data collection choice.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever `RefreshLlmTokenListener` emits (subscription set up in `new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Set via `set_use_context`; its consumers are not visible in this section.
    use_context: bool,
    update_required: bool,
    // Starts as `EditPredictionModel::Zeta2` in `new`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Feeds the detached background task spawned in `new` that handles rejections.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Endpoint override parsed from `ZED_PREDICT_EDITS_URL` when it is set and valid.
    custom_predict_edits_url: Option<Arc<Url>>,
}
161
/// Backend used to produce edit predictions.
///
/// `Zeta1` is the `Default` variant, although `EditPredictionStore::new` starts out
/// with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
172
/// Inputs gathered for a single edit prediction request.
///
/// NOTE(review): the code that fills this struct is not visible in this section. The
/// field notes below are inferred from names and types; confirm against the producer.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Presumably the cursor position the prediction is requested at.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
186
/// Diagnostic event sent on a project's debug channel.
///
/// The channel is created by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
194
/// Emitted when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}
201
/// Emitted when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Label/value pairs for display, e.g. ("Cache Hits", "3/5").
    pub metadata: Vec<(&'static str, SharedString)>,
}
208
/// Debug event for the start of a prediction request.
///
/// NOTE(review): the sender is not visible in this section.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // Presumably the prompt sent to the model, when one was built.
    pub prompt: Option<String>,
}
215
/// Debug event for the completion of a prediction request.
///
/// NOTE(review): the sender is not visible in this section.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // Presumably the raw model output, when available.
    pub model_output: Option<String>,
}
222
/// Maximum number of entries kept in `ProjectState::user_actions`; the oldest entry is evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;
224
/// A single user action recorded in a project's bounded action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    // Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    // `Point::row` of `offset` in the buffer.
    pub line_number: u32,
    // For edits recorded by `report_changes_for_buffer`: the end offset of the last edit.
    pub offset: usize,
    // Milliseconds since the UNIX epoch (0 if the system clock is before the epoch).
    pub timestamp_epoch_ms: u64,
}
233
/// Classification of a recorded user action.
///
/// `report_changes_for_buffer` derives the edit variants from how much each batch of
/// edits inserted and deleted. `CursorMovement` is not produced in this section.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    // Buffer snapshot from before the event's edits were applied.
    pub old_snapshot: TextBufferSnapshot,
}
249
/// Per-project edit prediction state owned by `EditPredictionStore`.
struct ProjectState {
    // Finalized edit events, oldest first.
    // Together with `last_event` they stay within `EVENT_COUNT_MAX`.
    events: VecDeque<StoredEvent>,
    // In-progress event that later edits to the same region may be coalesced into.
    last_event: Option<LastEvent>,
    // Paths of recently active entries, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    // Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // In-flight requests; the `ArrayVec` capacity limits this to two entries.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Set by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree that has a registered buffer.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded by `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    // Subscription to the project's events.
    _subscription: gpui::Subscription,
    // Created on demand by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
267
268impl ProjectState {
269 fn record_user_action(&mut self, action: UserActionRecord) {
270 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
271 self.user_actions.pop_front();
272 }
273 self.user_actions.push_back(action);
274 }
275
276 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
277 self.events
278 .iter()
279 .cloned()
280 .chain(
281 self.last_event
282 .as_ref()
283 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
284 )
285 .collect()
286 }
287
288 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
289 self.events
290 .iter()
291 .cloned()
292 .chain(self.last_event.as_ref().iter().flat_map(|event| {
293 let (one, two) = event.split_by_pause();
294 let one = one.finalize(&self.license_detection_watchers, cx);
295 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
296 one.into_iter().chain(two)
297 }))
298 .collect()
299 }
300
301 fn cancel_pending_prediction(
302 &mut self,
303 pending_prediction: PendingPrediction,
304 cx: &mut Context<EditPredictionStore>,
305 ) {
306 self.cancelled_predictions.insert(pending_prediction.id);
307
308 cx.spawn(async move |this, cx| {
309 let Some(prediction_id) = pending_prediction.task.await else {
310 return;
311 };
312
313 this.update(cx, |this, _cx| {
314 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
315 })
316 .ok();
317 })
318 .detach()
319 }
320
321 fn active_buffer(
322 &self,
323 project: &Entity<Project>,
324 cx: &App,
325 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
326 let project = project.read(cx);
327 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
328 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
329 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
330 Some((active_buffer, registered_buffer.last_position))
331 }
332}
333
/// The prediction currently offered for a project, together with how it was requested
/// and whether it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Presumably set once the prediction was displayed to the user.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
341
342impl CurrentEditPrediction {
343 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
344 let Some(new_edits) = self
345 .prediction
346 .interpolate(&self.prediction.buffer.read(cx))
347 else {
348 return false;
349 };
350
351 if self.prediction.buffer != old_prediction.prediction.buffer {
352 return true;
353 }
354
355 let Some(old_edits) = old_prediction
356 .prediction
357 .interpolate(&old_prediction.prediction.buffer.read(cx))
358 else {
359 return true;
360 };
361
362 let requested_by_buffer_id = self.requested_by.buffer_id();
363
364 // This reduces the occurrence of UI thrash from replacing edits
365 //
366 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
367 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
368 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
369 && old_edits.len() == 1
370 && new_edits.len() == 1
371 {
372 let (old_range, old_text) = &old_edits[0];
373 let (new_range, new_text) = &new_edits[0];
374 new_range == old_range && new_text.starts_with(old_text.as_ref())
375 } else {
376 true
377 }
378 }
379}
380
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    // Requested after the project's diagnostics changed; no single buffer.
    DiagnosticsUpdate,
    // Requested on behalf of the buffer with this entity id.
    Buffer(EntityId),
}
386
387impl PredictionRequestedBy {
388 pub fn buffer_id(&self) -> Option<EntityId> {
389 match self {
390 PredictionRequestedBy::DiagnosticsUpdate => None,
391 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
392 }
393 }
394}
395
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Recorded in `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    // Resolves to the produced prediction's id.
    // `None` means there is nothing to reject on cancellation.
    task: Task<Option<EditPredictionId>>,
}
401
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): presumably `Local` is a prediction whose edits are in this buffer and
/// `Jump` is one located elsewhere; the constructing code is not visible here.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
408
/// Test-only convenience: both variants expose the wrapped prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
420
/// Tracking state for one buffer registered with a project.
struct RegisteredBuffer {
    // File as of the last reported change; used as the next event's `old_file`.
    file: Option<Arc<dyn File>>,
    // Snapshot as of the last reported change; edits are computed relative to it.
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    // Edit subscription and release observer created in `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
427
/// The in-progress (not yet finalized) edit event of a project.
///
/// Nearby consecutive edits to the same buffer are coalesced into it.
#[derive(Clone)]
struct LastEvent {
    // Buffer state when the event started.
    old_snapshot: TextBufferSnapshot,
    // Buffer state after the most recent coalesced edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Anchor range spanning the most recent batch of edits; used for coalescing.
    edit_range: Option<Range<Anchor>>,
    // Snapshot taken just before the first edit that followed a pause of at least
    // `LAST_CHANGE_GROUPING_TIME`; used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
438
439impl LastEvent {
440 pub fn finalize(
441 &self,
442 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
443 cx: &App,
444 ) -> Option<StoredEvent> {
445 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
446 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
447
448 let in_open_source_repo =
449 [self.new_file.as_ref(), self.old_file.as_ref()]
450 .iter()
451 .all(|file| {
452 file.is_some_and(|file| {
453 license_detection_watchers
454 .get(&file.worktree_id(cx))
455 .is_some_and(|watcher| watcher.is_project_open_source())
456 })
457 });
458
459 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
460
461 if path == old_path && diff.is_empty() {
462 None
463 } else {
464 Some(StoredEvent {
465 event: Arc::new(zeta_prompt::Event::BufferChange {
466 old_path,
467 path,
468 diff,
469 in_open_source_repo,
470 // TODO: Actually detect if this edit was predicted or not
471 predicted: false,
472 }),
473 old_snapshot: self.old_snapshot.clone(),
474 })
475 }
476 }
477
478 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
479 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
480 return (self.clone(), None);
481 };
482
483 let before = LastEvent {
484 old_snapshot: self.old_snapshot.clone(),
485 new_snapshot: boundary_snapshot.clone(),
486 old_file: self.old_file.clone(),
487 new_file: self.new_file.clone(),
488 edit_range: None,
489 snapshot_after_last_editing_pause: None,
490 last_edit_time: self.last_edit_time,
491 };
492
493 let after = LastEvent {
494 old_snapshot: boundary_snapshot.clone(),
495 new_snapshot: self.new_snapshot.clone(),
496 old_file: self.old_file.clone(),
497 new_file: self.new_file.clone(),
498 edit_range: None,
499 snapshot_after_last_editing_pause: None,
500 last_edit_time: self.last_edit_time,
501 };
502
503 (before, Some(after))
504 }
505}
506
507pub(crate) fn compute_diff_between_snapshots(
508 old_snapshot: &TextBufferSnapshot,
509 new_snapshot: &TextBufferSnapshot,
510) -> Option<String> {
511 let edits: Vec<Edit<usize>> = new_snapshot
512 .edits_since::<usize>(&old_snapshot.version)
513 .collect();
514
515 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
516
517 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
518 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
519 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
520 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
521
522 const CONTEXT_LINES: u32 = 3;
523
524 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
525 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
526 let old_context_end_row =
527 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
528 let new_context_end_row =
529 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
530
531 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
532 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
533 let old_end_line_offset = old_snapshot
534 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
535 let new_end_line_offset = new_snapshot
536 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
537 let old_edit_range = old_start_line_offset..old_end_line_offset;
538 let new_edit_range = new_start_line_offset..new_end_line_offset;
539
540 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
541 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
542
543 let diff = language::unified_diff_with_offsets(
544 &old_region_text,
545 &new_region_text,
546 old_context_start_row,
547 new_context_start_row,
548 );
549
550 Some(diff)
551}
552
553fn buffer_path_with_id_fallback(
554 file: Option<&Arc<dyn File>>,
555 snapshot: &TextBufferSnapshot,
556 cx: &App,
557) -> Arc<Path> {
558 if let Some(file) = file {
559 file.full_path(cx).into()
560 } else {
561 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
562 }
563}
564
565impl EditPredictionStore {
566 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
567 cx.try_global::<EditPredictionStoreGlobal>()
568 .map(|global| global.0.clone())
569 }
570
571 pub fn global(
572 client: &Arc<Client>,
573 user_store: &Entity<UserStore>,
574 cx: &mut App,
575 ) -> Entity<Self> {
576 cx.try_global::<EditPredictionStoreGlobal>()
577 .map(|global| global.0.clone())
578 .unwrap_or_else(|| {
579 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
580 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
581 ep_store
582 })
583 }
584
585 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
586 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
587 let data_collection_choice = Self::load_data_collection_choice();
588
589 let llm_token = LlmApiToken::default();
590
591 let (reject_tx, reject_rx) = mpsc::unbounded();
592 cx.background_spawn({
593 let client = client.clone();
594 let llm_token = llm_token.clone();
595 let app_version = AppVersion::global(cx);
596 let background_executor = cx.background_executor().clone();
597 async move {
598 Self::handle_rejected_predictions(
599 reject_rx,
600 client,
601 llm_token,
602 app_version,
603 background_executor,
604 )
605 .await
606 }
607 })
608 .detach();
609
610 let mut this = Self {
611 projects: HashMap::default(),
612 client,
613 user_store,
614 use_context: false,
615 llm_token,
616 _llm_token_subscription: cx.subscribe(
617 &refresh_llm_token_listener,
618 |this, _listener, _event, cx| {
619 let client = this.client.clone();
620 let llm_token = this.llm_token.clone();
621 cx.spawn(async move |_this, _cx| {
622 llm_token.refresh(&client).await?;
623 anyhow::Ok(())
624 })
625 .detach_and_log_err(cx);
626 },
627 ),
628 update_required: false,
629 edit_prediction_model: EditPredictionModel::Zeta2 {
630 version: Default::default(),
631 },
632 sweep_ai: SweepAi::new(cx),
633 mercury: Mercury::new(cx),
634
635 data_collection_choice,
636 reject_predictions_tx: reject_tx,
637 rated_predictions: Default::default(),
638 shown_predictions: Default::default(),
639 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
640 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
641 Err(_) => None,
642 },
643 };
644
645 this.configure_context_retrieval(cx);
646 let weak_this = cx.weak_entity();
647 cx.on_flags_ready(move |_, cx| {
648 weak_this
649 .update(cx, |this, cx| this.configure_context_retrieval(cx))
650 .ok();
651 })
652 .detach();
653 cx.observe_global::<SettingsStore>(|this, cx| {
654 this.configure_context_retrieval(cx);
655 })
656 .detach();
657
658 this
659 }
660
661 #[cfg(test)]
662 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
663 self.custom_predict_edits_url = Some(url.into());
664 }
665
    /// Sets `edit_prediction_model`, which selects the prediction backend.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
669
    /// Returns whether the Sweep provider's API key state currently holds a key.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
673
    /// Returns whether the Mercury provider's API key state currently holds a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
677
    /// Sets the `use_context` flag (its consumers are not visible in this section).
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
681
682 pub fn clear_history(&mut self) {
683 for project_state in self.projects.values_mut() {
684 project_state.events.clear();
685 project_state.last_event.take();
686 }
687 }
688
689 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
690 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
691 project_state.events.clear();
692 project_state.last_event.take();
693 }
694 }
695
696 pub fn edit_history_for_project(
697 &self,
698 project: &Entity<Project>,
699 cx: &App,
700 ) -> Vec<StoredEvent> {
701 self.projects
702 .get(&project.entity_id())
703 .map(|project_state| project_state.events(cx))
704 .unwrap_or_default()
705 }
706
707 pub fn edit_history_for_project_with_pause_split_last_event(
708 &self,
709 project: &Entity<Project>,
710 cx: &App,
711 ) -> Vec<StoredEvent> {
712 self.projects
713 .get(&project.entity_id())
714 .map(|project_state| project_state.events_split_by_pause(cx))
715 .unwrap_or_default()
716 }
717
718 pub fn context_for_project<'a>(
719 &'a self,
720 project: &Entity<Project>,
721 cx: &'a mut App,
722 ) -> Vec<RelatedFile> {
723 self.projects
724 .get(&project.entity_id())
725 .map(|project| {
726 project
727 .context
728 .update(cx, |context, cx| context.related_files(cx))
729 })
730 .unwrap_or_default()
731 }
732
733 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
734 self.projects
735 .get(&project.entity_id())
736 .and_then(|project| project.copilot.clone())
737 }
738
739 pub fn start_copilot_for_project(
740 &mut self,
741 project: &Entity<Project>,
742 cx: &mut Context<Self>,
743 ) -> Option<Entity<Copilot>> {
744 let state = self.get_or_init_project(project, cx);
745
746 if state.copilot.is_some() {
747 return state.copilot.clone();
748 }
749 let _project = project.clone();
750 let project = project.read(cx);
751
752 let node = project.node_runtime().cloned();
753 if let Some(node) = node {
754 let next_id = project.languages().next_language_server_id();
755 let fs = project.fs().clone();
756
757 let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx));
758 state.copilot = Some(copilot.clone());
759 Some(copilot)
760 } else {
761 None
762 }
763 }
764
765 pub fn context_for_project_with_buffers<'a>(
766 &'a self,
767 project: &Entity<Project>,
768 cx: &'a mut App,
769 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
770 self.projects
771 .get(&project.entity_id())
772 .map(|project| {
773 project
774 .context
775 .update(cx, |context, cx| context.related_files_with_buffers(cx))
776 })
777 .unwrap_or_default()
778 }
779
780 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
781 if matches!(
782 self.edit_prediction_model,
783 EditPredictionModel::Zeta2 { .. }
784 ) {
785 self.user_store.read(cx).edit_prediction_usage()
786 } else {
787 None
788 }
789 }
790
791 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
792 self.get_or_init_project(project, cx);
793 }
794
795 pub fn register_buffer(
796 &mut self,
797 buffer: &Entity<Buffer>,
798 project: &Entity<Project>,
799 cx: &mut Context<Self>,
800 ) {
801 let project_state = self.get_or_init_project(project, cx);
802 Self::register_buffer_impl(project_state, buffer, project, cx);
803 }
804
805 fn get_or_init_project(
806 &mut self,
807 project: &Entity<Project>,
808 cx: &mut Context<Self>,
809 ) -> &mut ProjectState {
810 let entity_id = project.entity_id();
811 self.projects
812 .entry(entity_id)
813 .or_insert_with(|| ProjectState {
814 context: {
815 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
816 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
817 this.handle_excerpt_store_event(entity_id, event);
818 })
819 .detach();
820 related_excerpt_store
821 },
822 events: VecDeque::new(),
823 last_event: None,
824 recent_paths: VecDeque::new(),
825 debug_tx: None,
826 registered_buffers: HashMap::default(),
827 current_prediction: None,
828 cancelled_predictions: HashSet::default(),
829 pending_predictions: ArrayVec::new(),
830 next_pending_prediction_id: 0,
831 last_prediction_refresh: None,
832 license_detection_watchers: HashMap::default(),
833 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
834 _subscription: cx.subscribe(&project, Self::handle_project_event),
835 copilot: None,
836 })
837 }
838
    /// Stops tracking `project`, dropping its `ProjectState` and the subscriptions it owns.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
842
843 fn handle_excerpt_store_event(
844 &mut self,
845 project_entity_id: EntityId,
846 event: &RelatedExcerptStoreEvent,
847 ) {
848 if let Some(project_state) = self.projects.get(&project_entity_id) {
849 if let Some(debug_tx) = project_state.debug_tx.clone() {
850 match event {
851 RelatedExcerptStoreEvent::StartedRefresh => {
852 debug_tx
853 .unbounded_send(DebugEvent::ContextRetrievalStarted(
854 ContextRetrievalStartedDebugEvent {
855 project_entity_id: project_entity_id,
856 timestamp: Instant::now(),
857 search_prompt: String::new(),
858 },
859 ))
860 .ok();
861 }
862 RelatedExcerptStoreEvent::FinishedRefresh {
863 cache_hit_count,
864 cache_miss_count,
865 mean_definition_latency,
866 max_definition_latency,
867 } => {
868 debug_tx
869 .unbounded_send(DebugEvent::ContextRetrievalFinished(
870 ContextRetrievalFinishedDebugEvent {
871 project_entity_id: project_entity_id,
872 timestamp: Instant::now(),
873 metadata: vec![
874 (
875 "Cache Hits",
876 format!(
877 "{}/{}",
878 cache_hit_count,
879 cache_hit_count + cache_miss_count
880 )
881 .into(),
882 ),
883 (
884 "Max LSP Time",
885 format!("{} ms", max_definition_latency.as_millis())
886 .into(),
887 ),
888 (
889 "Mean LSP Time",
890 format!("{} ms", mean_definition_latency.as_millis())
891 .into(),
892 ),
893 ],
894 },
895 ))
896 .ok();
897 }
898 }
899 }
900 }
901 }
902
903 pub fn debug_info(
904 &mut self,
905 project: &Entity<Project>,
906 cx: &mut Context<Self>,
907 ) -> mpsc::UnboundedReceiver<DebugEvent> {
908 let project_state = self.get_or_init_project(project, cx);
909 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
910 project_state.debug_tx = Some(debug_watch_tx);
911 debug_watch_rx
912 }
913
914 fn handle_project_event(
915 &mut self,
916 project: Entity<Project>,
917 event: &project::Event,
918 cx: &mut Context<Self>,
919 ) {
920 // TODO [zeta2] init with recent paths
921 match event {
922 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
923 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
924 return;
925 };
926 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
927 if let Some(path) = path {
928 if let Some(ix) = project_state
929 .recent_paths
930 .iter()
931 .position(|probe| probe == &path)
932 {
933 project_state.recent_paths.remove(ix);
934 }
935 project_state.recent_paths.push_front(path);
936 }
937 }
938 project::Event::DiagnosticsUpdated { .. } => {
939 if cx.has_flag::<Zeta2FeatureFlag>() {
940 self.refresh_prediction_from_diagnostics(project, cx);
941 }
942 }
943 _ => (),
944 }
945 }
946
    /// Ensures `buffer` is registered in `project_state` and returns its registration.
    ///
    /// The first time a file from a worktree is seen, a `LicenseDetectionWatcher` is
    /// created for that worktree and removed again when the worktree is released.
    ///
    /// A newly registered buffer starts from its current snapshot and file. Its edits
    /// are reported through `report_changes_for_buffer`, and its registration is removed
    /// when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop this worktree's watcher when the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report edits. Only a weak project handle is held; edits are
                        // ignored once it can no longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the registration once the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1012
1013 fn report_changes_for_buffer(
1014 &mut self,
1015 buffer: &Entity<Buffer>,
1016 project: &Entity<Project>,
1017 cx: &mut Context<Self>,
1018 ) {
1019 let project_state = self.get_or_init_project(project, cx);
1020 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1021
1022 let buf = buffer.read(cx);
1023 let new_file = buf.file().cloned();
1024 let new_snapshot = buf.text_snapshot();
1025 if new_snapshot.version == registered_buffer.snapshot.version {
1026 return;
1027 }
1028
1029 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1030 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1031 let mut num_edits = 0usize;
1032 let mut total_deleted = 0usize;
1033 let mut total_inserted = 0usize;
1034 let mut edit_range: Option<Range<Anchor>> = None;
1035 let mut last_offset: Option<usize> = None;
1036
1037 for (edit, anchor_range) in
1038 new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1039 {
1040 num_edits += 1;
1041 total_deleted += edit.old.len();
1042 total_inserted += edit.new.len();
1043 edit_range = Some(match edit_range {
1044 None => anchor_range,
1045 Some(acc) => acc.start..anchor_range.end,
1046 });
1047 last_offset = Some(edit.new.end);
1048 }
1049
1050 if num_edits > 0 {
1051 let action_type = match (total_deleted, total_inserted, num_edits) {
1052 (0, ins, n) if ins == n => UserActionType::InsertChar,
1053 (0, _, _) => UserActionType::InsertSelection,
1054 (del, 0, n) if del == n => UserActionType::DeleteChar,
1055 (_, 0, _) => UserActionType::DeleteSelection,
1056 (_, ins, n) if ins == n => UserActionType::InsertChar,
1057 (_, _, _) => UserActionType::InsertSelection,
1058 };
1059
1060 if let Some(offset) = last_offset {
1061 let point = new_snapshot.offset_to_point(offset);
1062 let timestamp_epoch_ms = SystemTime::now()
1063 .duration_since(UNIX_EPOCH)
1064 .map(|d| d.as_millis() as u64)
1065 .unwrap_or(0);
1066 project_state.record_user_action(UserActionRecord {
1067 action_type,
1068 buffer_id: buffer.entity_id(),
1069 line_number: point.row,
1070 offset,
1071 timestamp_epoch_ms,
1072 });
1073 }
1074 }
1075
1076 let events = &mut project_state.events;
1077
1078 let now = cx.background_executor().now();
1079 if let Some(last_event) = project_state.last_event.as_mut() {
1080 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1081 == last_event.new_snapshot.remote_id()
1082 && old_snapshot.version == last_event.new_snapshot.version;
1083
1084 let should_coalesce = is_next_snapshot_of_same_buffer
1085 && edit_range
1086 .as_ref()
1087 .zip(last_event.edit_range.as_ref())
1088 .is_some_and(|(a, b)| {
1089 let a = a.to_point(&new_snapshot);
1090 let b = b.to_point(&new_snapshot);
1091 if a.start > b.end {
1092 a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
1093 } else if b.start > a.end {
1094 b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
1095 } else {
1096 true
1097 }
1098 });
1099
1100 if should_coalesce {
1101 let pause_elapsed = last_event
1102 .last_edit_time
1103 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1104 .unwrap_or(false);
1105 if pause_elapsed {
1106 last_event.snapshot_after_last_editing_pause =
1107 Some(last_event.new_snapshot.clone());
1108 }
1109
1110 last_event.edit_range = edit_range;
1111 last_event.new_snapshot = new_snapshot;
1112 last_event.last_edit_time = Some(now);
1113 return;
1114 }
1115 }
1116
1117 if let Some(event) = project_state.last_event.take() {
1118 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1119 if events.len() + 1 >= EVENT_COUNT_MAX {
1120 events.pop_front();
1121 }
1122 events.push_back(event);
1123 }
1124 }
1125
1126 project_state.last_event = Some(LastEvent {
1127 old_file,
1128 new_file,
1129 old_snapshot,
1130 new_snapshot,
1131 edit_range,
1132 snapshot_after_last_editing_pause: None,
1133 last_edit_time: Some(now),
1134 });
1135 }
1136
1137 fn prediction_at(
1138 &mut self,
1139 buffer: &Entity<Buffer>,
1140 position: Option<language::Anchor>,
1141 project: &Entity<Project>,
1142 cx: &App,
1143 ) -> Option<BufferEditPrediction<'_>> {
1144 let project_state = self.projects.get_mut(&project.entity_id())?;
1145 if let Some(position) = position
1146 && let Some(buffer) = project_state
1147 .registered_buffers
1148 .get_mut(&buffer.entity_id())
1149 {
1150 buffer.last_position = Some(position);
1151 }
1152
1153 let CurrentEditPrediction {
1154 requested_by,
1155 prediction,
1156 ..
1157 } = project_state.current_prediction.as_ref()?;
1158
1159 if prediction.targets_buffer(buffer.read(cx)) {
1160 Some(BufferEditPrediction::Local { prediction })
1161 } else {
1162 let show_jump = match requested_by {
1163 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1164 requested_by_buffer_id == &buffer.entity_id()
1165 }
1166 PredictionRequestedBy::DiagnosticsUpdate => true,
1167 };
1168
1169 if show_jump {
1170 Some(BufferEditPrediction::Jump { prediction })
1171 } else {
1172 None
1173 }
1174 }
1175 }
1176
1177 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1178 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1179 return;
1180 };
1181
1182 let Some(current_prediction) = project_state.current_prediction.take() else {
1183 return;
1184 };
1185
1186 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1187 project_state.cancel_pending_prediction(pending_prediction, cx);
1188 }
1189
1190 match self.edit_prediction_model {
1191 EditPredictionModel::Sweep => {
1192 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1193 }
1194 EditPredictionModel::Mercury => {}
1195 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1196 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1197 }
1198 }
1199 }
1200
/// Long-running loop that batches rejected-prediction reports received on `rx`
/// and POSTs them to the `/predict_edits/reject` endpoint.
///
/// A batch is flushed in any of these cases:
/// - it holds at least half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries;
/// - no further rejection arrives within `REJECT_REQUEST_DEBOUNCE` of the previous one;
/// - the channel has been closed.
///
/// Each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the
/// most recently queued entries. Entries are removed only after a successful request,
/// so a failed batch is retried, together with newer entries, once the next rejection
/// arrives. The loop ends when `rx` yields `None`.
///
/// NOTE(review): if requests keep failing, `batched` grows without bound. The
/// `.unwrap()` on the URL below panics if the URL cannot be built.
async fn handle_rejected_predictions(
    rx: UnboundedReceiver<EditPredictionRejection>,
    client: Arc<Client>,
    llm_token: LlmApiToken,
    app_version: Version,
    background_executor: BackgroundExecutor,
) {
    let mut rx = std::pin::pin!(rx.peekable());
    let mut batched = Vec::new();

    while let Some(rejection) = rx.next().await {
        batched.push(rejection);

        // Small batch: prefer waiting briefly for more rejections. If one is already
        // available (peek resolves with `Some`), loop around and add it to the batch.
        // Otherwise fall through to sending, either on debounce timeout or when the
        // channel has closed (peek resolves with `None`).
        if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
            select_biased! {
                next = rx.as_mut().peek().fuse() => {
                    if next.is_some() {
                        continue;
                    }
                }
                () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
            }
        }

        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/reject", &[])
            .unwrap();

        let flush_count = batched
            .len()
            // in case items have accumulated after failure
            .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
        // Send the newest `flush_count` entries.
        let start = batched.len() - flush_count;

        let body = RejectEditPredictionsBodyRef {
            rejections: &batched[start..],
        };

        let result = Self::send_api_request::<()>(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&body)?.into());
                anyhow::Ok(req?)
            },
            client.clone(),
            llm_token.clone(),
            app_version.clone(),
            true,
        )
        .await;

        // Forget the sent entries only once the request succeeded; failures are logged.
        if result.log_err().is_some() {
            batched.drain(start..);
        }
    }
}
1259
1260 fn reject_current_prediction(
1261 &mut self,
1262 reason: EditPredictionRejectReason,
1263 project: &Entity<Project>,
1264 ) {
1265 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1266 project_state.pending_predictions.clear();
1267 if let Some(prediction) = project_state.current_prediction.take() {
1268 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1269 }
1270 };
1271 }
1272
1273 fn did_show_current_prediction(
1274 &mut self,
1275 project: &Entity<Project>,
1276 display_type: edit_prediction_types::SuggestionDisplayType,
1277 cx: &mut Context<Self>,
1278 ) {
1279 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1280 return;
1281 };
1282
1283 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1284 return;
1285 };
1286
1287 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1288 let previous_shown_with = current_prediction.shown_with;
1289
1290 if previous_shown_with.is_none() || !is_jump {
1291 current_prediction.shown_with = Some(display_type);
1292 }
1293
1294 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1295
1296 if is_first_non_jump_show {
1297 current_prediction.was_shown = true;
1298 }
1299
1300 let display_type_changed = previous_shown_with != Some(display_type);
1301
1302 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1303 sweep_ai::edit_prediction_shown(
1304 &self.sweep_ai,
1305 self.client.clone(),
1306 ¤t_prediction.prediction,
1307 display_type,
1308 cx,
1309 );
1310 }
1311
1312 if is_first_non_jump_show {
1313 self.shown_predictions
1314 .push_front(current_prediction.prediction.clone());
1315 if self.shown_predictions.len() > 50 {
1316 let completion = self.shown_predictions.pop_back().unwrap();
1317 self.rated_predictions.remove(&completion.id);
1318 }
1319 }
1320 }
1321
1322 fn reject_prediction(
1323 &mut self,
1324 prediction_id: EditPredictionId,
1325 reason: EditPredictionRejectReason,
1326 was_shown: bool,
1327 ) {
1328 match self.edit_prediction_model {
1329 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1330 if self.custom_predict_edits_url.is_some() {
1331 return;
1332 }
1333 }
1334 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1335 }
1336
1337 self.reject_predictions_tx
1338 .unbounded_send(EditPredictionRejection {
1339 request_id: prediction_id.to_string(),
1340 reason,
1341 was_shown,
1342 })
1343 .log_err();
1344 }
1345
1346 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1347 self.projects
1348 .get(&project.entity_id())
1349 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1350 }
1351
1352 pub fn refresh_prediction_from_buffer(
1353 &mut self,
1354 project: Entity<Project>,
1355 buffer: Entity<Buffer>,
1356 position: language::Anchor,
1357 cx: &mut Context<Self>,
1358 ) {
1359 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1360 let Some(request_task) = this
1361 .update(cx, |this, cx| {
1362 this.request_prediction(
1363 &project,
1364 &buffer,
1365 position,
1366 PredictEditsRequestTrigger::Other,
1367 cx,
1368 )
1369 })
1370 .log_err()
1371 else {
1372 return Task::ready(anyhow::Ok(None));
1373 };
1374
1375 cx.spawn(async move |_cx| {
1376 request_task.await.map(|prediction_result| {
1377 prediction_result.map(|prediction_result| {
1378 (
1379 prediction_result,
1380 PredictionRequestedBy::Buffer(buffer.entity_id()),
1381 )
1382 })
1383 })
1384 })
1385 })
1386 }
1387
/// Queues a prediction refresh triggered by a diagnostics update in `project`.
///
/// Nothing is queued if the project already has a current prediction. Otherwise,
/// after throttling (keyed on the project entity), the task:
/// - takes the project's active buffer and cursor, bailing out if predictions are
///   disabled there;
/// - looks up the next diagnostic location via `next_diagnostic_location`;
/// - requests a prediction there with the `Diagnostics` trigger.
///
/// If a current prediction appeared while the request was in flight, the result
/// is turned into a `CurrentPreferred` rejection instead.
pub fn refresh_prediction_from_diagnostics(
    &mut self,
    project: Entity<Project>,
    cx: &mut Context<Self>,
) {
    let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
        return;
    };

    // Prefer predictions from buffer
    if project_state.current_prediction.is_some() {
        return;
    };

    self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
        let Some((active_buffer, snapshot, cursor_point)) = this
            .read_with(cx, |this, cx| {
                let project_state = this.projects.get(&project.entity_id())?;
                let (buffer, position) = project_state.active_buffer(&project, cx)?;
                let snapshot = buffer.read(cx).snapshot();

                if !Self::predictions_enabled_at(&snapshot, position, cx) {
                    return None;
                }

                // Without a known cursor position, search from the start of the buffer.
                let cursor_point = position
                    .map(|pos| pos.to_point(&snapshot))
                    .unwrap_or_default();

                Some((buffer, snapshot, cursor_point))
            })
            .log_err()
            .flatten()
        else {
            return Task::ready(anyhow::Ok(None));
        };

        cx.spawn(async move |cx| {
            // `Default::default()` is passed as the excluded search range. Presumably
            // this is an empty range, so no diagnostic in the active buffer is
            // excluded; NOTE(review): confirm `overlaps` semantics for empty ranges.
            let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                active_buffer,
                &snapshot,
                Default::default(),
                cursor_point,
                &project,
                cx,
            )
            .await?
            else {
                return anyhow::Ok(None);
            };

            let Some(prediction_result) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &jump_buffer,
                        jump_position,
                        PredictEditsRequestTrigger::Diagnostics,
                        cx,
                    )
                })?
                .await?
            else {
                return anyhow::Ok(None);
            };

            // A buffer-driven prediction may have landed meanwhile; it takes precedence.
            this.update(cx, |this, cx| {
                Some((
                    if this
                        .get_or_init_project(&project, cx)
                        .current_prediction
                        .is_none()
                    {
                        prediction_result
                    } else {
                        EditPredictionResult {
                            id: prediction_result.id,
                            prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                        }
                    },
                    PredictionRequestedBy::DiagnosticsUpdate,
                ))
            })
        })
    });
}
1474
1475 fn predictions_enabled_at(
1476 snapshot: &BufferSnapshot,
1477 position: Option<language::Anchor>,
1478 cx: &App,
1479 ) -> bool {
1480 let file = snapshot.file();
1481 let all_settings = all_language_settings(file, cx);
1482 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1483 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1484 {
1485 return false;
1486 }
1487
1488 if let Some(last_position) = position {
1489 let settings = snapshot.settings_at(last_position, cx);
1490
1491 if !settings.edit_predictions_disabled_in.is_empty()
1492 && let Some(scope) = snapshot.language_scope_at(last_position)
1493 && let Some(scope_name) = scope.override_name()
1494 && settings
1495 .edit_predictions_disabled_in
1496 .iter()
1497 .any(|s| s == scope_name)
1498 {
1499 return false;
1500 }
1501 }
1502
1503 true
1504 }
1505
/// Minimum delay between two consecutive prediction refreshes that share a
/// throttle entity (see `queue_prediction_refresh`).
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
/// Tests disable throttling so refreshes run immediately.
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1510
1511 fn queue_prediction_refresh(
1512 &mut self,
1513 project: Entity<Project>,
1514 throttle_entity: EntityId,
1515 cx: &mut Context<Self>,
1516 do_refresh: impl FnOnce(
1517 WeakEntity<Self>,
1518 &mut AsyncApp,
1519 )
1520 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1521 + 'static,
1522 ) {
1523 let project_state = self.get_or_init_project(&project, cx);
1524 let pending_prediction_id = project_state.next_pending_prediction_id;
1525 project_state.next_pending_prediction_id += 1;
1526 let last_request = project_state.last_prediction_refresh;
1527
1528 let task = cx.spawn(async move |this, cx| {
1529 if let Some((last_entity, last_timestamp)) = last_request
1530 && throttle_entity == last_entity
1531 && let Some(timeout) =
1532 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1533 {
1534 cx.background_executor().timer(timeout).await;
1535 }
1536
1537 // If this task was cancelled before the throttle timeout expired,
1538 // do not perform a request.
1539 let mut is_cancelled = true;
1540 this.update(cx, |this, cx| {
1541 let project_state = this.get_or_init_project(&project, cx);
1542 if !project_state
1543 .cancelled_predictions
1544 .remove(&pending_prediction_id)
1545 {
1546 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1547 is_cancelled = false;
1548 }
1549 })
1550 .ok();
1551 if is_cancelled {
1552 return None;
1553 }
1554
1555 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1556 let new_prediction_id = new_prediction_result
1557 .as_ref()
1558 .map(|(prediction, _)| prediction.id.clone());
1559
1560 // When a prediction completes, remove it from the pending list, and cancel
1561 // any pending predictions that were enqueued before it.
1562 this.update(cx, |this, cx| {
1563 let project_state = this.get_or_init_project(&project, cx);
1564
1565 let is_cancelled = project_state
1566 .cancelled_predictions
1567 .remove(&pending_prediction_id);
1568
1569 let new_current_prediction = if !is_cancelled
1570 && let Some((prediction_result, requested_by)) = new_prediction_result
1571 {
1572 match prediction_result.prediction {
1573 Ok(prediction) => {
1574 let new_prediction = CurrentEditPrediction {
1575 requested_by,
1576 prediction,
1577 was_shown: false,
1578 shown_with: None,
1579 };
1580
1581 if let Some(current_prediction) =
1582 project_state.current_prediction.as_ref()
1583 {
1584 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1585 {
1586 this.reject_current_prediction(
1587 EditPredictionRejectReason::Replaced,
1588 &project,
1589 );
1590
1591 Some(new_prediction)
1592 } else {
1593 this.reject_prediction(
1594 new_prediction.prediction.id,
1595 EditPredictionRejectReason::CurrentPreferred,
1596 false,
1597 );
1598 None
1599 }
1600 } else {
1601 Some(new_prediction)
1602 }
1603 }
1604 Err(reject_reason) => {
1605 this.reject_prediction(prediction_result.id, reject_reason, false);
1606 None
1607 }
1608 }
1609 } else {
1610 None
1611 };
1612
1613 let project_state = this.get_or_init_project(&project, cx);
1614
1615 if let Some(new_prediction) = new_current_prediction {
1616 project_state.current_prediction = Some(new_prediction);
1617 }
1618
1619 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1620 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1621 if pending_prediction.id == pending_prediction_id {
1622 pending_predictions.remove(ix);
1623 for pending_prediction in pending_predictions.drain(0..ix) {
1624 project_state.cancel_pending_prediction(pending_prediction, cx)
1625 }
1626 break;
1627 }
1628 }
1629 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1630 cx.notify();
1631 })
1632 .ok();
1633
1634 new_prediction_id
1635 });
1636
1637 if project_state.pending_predictions.len() <= 1 {
1638 project_state.pending_predictions.push(PendingPrediction {
1639 id: pending_prediction_id,
1640 task,
1641 });
1642 } else if project_state.pending_predictions.len() == 2 {
1643 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1644 project_state.pending_predictions.push(PendingPrediction {
1645 id: pending_prediction_id,
1646 task,
1647 });
1648 project_state.cancel_pending_prediction(pending_prediction, cx);
1649 }
1650 }
1651
1652 pub fn request_prediction(
1653 &mut self,
1654 project: &Entity<Project>,
1655 active_buffer: &Entity<Buffer>,
1656 position: language::Anchor,
1657 trigger: PredictEditsRequestTrigger,
1658 cx: &mut Context<Self>,
1659 ) -> Task<Result<Option<EditPredictionResult>>> {
1660 self.request_prediction_internal(
1661 project.clone(),
1662 active_buffer.clone(),
1663 position,
1664 trigger,
1665 cx.has_flag::<Zeta2FeatureFlag>(),
1666 cx,
1667 )
1668 }
1669
/// Builds the model input for a prediction at `position` in `active_buffer` and
/// dispatches it to the configured `edit_prediction_model`.
///
/// The input includes:
/// - the project's recorded edit events;
/// - its recorded user actions, plus a synthetic `CursorMovement` entry when the
///   last recorded action was in this buffer at a different offset;
/// - related-file context when `use_context` is set;
/// - a diagnostic search range of `DIAGNOSTIC_LINES_RANGE` rows on either side of
///   the cursor row.
///
/// When data collection permits it and the request is sampled, an example is
/// also captured and emitted as telemetry in the background.
///
/// If the model returns no prediction, `allow_jump` is set and there are recorded
/// events, this looks for a diagnostic outside the search range (see
/// `next_diagnostic_location`). It then retries once at that location with
/// `allow_jump = false`.
fn request_prediction_internal(
    &mut self,
    project: Entity<Project>,
    active_buffer: Entity<Buffer>,
    position: language::Anchor,
    trigger: PredictEditsRequestTrigger,
    allow_jump: bool,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPredictionResult>>> {
    const DIAGNOSTIC_LINES_RANGE: u32 = 20;

    // Ensure the project state exists so the lookup below cannot fail.
    self.get_or_init_project(&project, cx);
    let project_state = self.projects.get(&project.entity_id()).unwrap();
    let stored_events = project_state.events(cx);
    let has_events = !stored_events.is_empty();
    let events: Vec<Arc<zeta_prompt::Event>> =
        stored_events.into_iter().map(|e| e.event).collect();
    let debug_tx = project_state.debug_tx.clone();

    let snapshot = active_buffer.read(cx).snapshot();
    let cursor_point = position.to_point(&snapshot);
    let current_offset = position.to_offset(&snapshot);

    let mut user_actions: Vec<UserActionRecord> =
        project_state.user_actions.iter().cloned().collect();

    // If the cursor moved within the same buffer since the last recorded action,
    // append a cursor-movement record so the model sees the current position.
    if let Some(last_action) = user_actions.last() {
        if last_action.buffer_id == active_buffer.entity_id()
            && current_offset != last_action.offset
        {
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            user_actions.push(UserActionRecord {
                action_type: UserActionType::CursorMovement,
                buffer_id: active_buffer.entity_id(),
                line_number: cursor_point.row,
                offset: current_offset,
                timestamp_epoch_ms,
            });
        }
    }
    // Rows within this window are sent to the model. Diagnostics outside it are
    // candidates for the jump fallback below.
    let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
    let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
    let diagnostic_search_range =
        Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

    let related_files = if self.use_context {
        self.context_for_project(&project, cx)
    } else {
        Vec::new()
    };

    let inputs = EditPredictionModelInput {
        project: project.clone(),
        buffer: active_buffer.clone(),
        snapshot: snapshot.clone(),
        position,
        events,
        related_files,
        recent_paths: project_state.recent_paths.clone(),
        trigger,
        diagnostic_search_range: diagnostic_search_range.clone(),
        debug_tx,
        user_actions,
    };

    // Example capture requires a collectable file and collectable events.
    let can_collect_example = snapshot
        .file()
        .is_some_and(|file| self.can_collect_file(&project, file, cx))
        && self.can_collect_events(&inputs.events, cx);

    if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
        let events_for_capture =
            self.edit_history_for_project_with_pause_split_last_event(&project, cx);
        if let Some(example_task) = capture_example::capture_example(
            project.clone(),
            active_buffer.clone(),
            position,
            events_for_capture,
            false,
            cx,
        ) {
            cx.spawn(async move |_this, _cx| {
                let example = example_task.await?;
                telemetry::event!("Edit Prediction Example Captured", example = example);
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }
    let task = match self.edit_prediction_model {
        EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
        EditPredictionModel::Zeta2 { version } => {
            zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
        }
        EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
        EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
    };

    cx.spawn(async move |this, cx| {
        let prediction = task.await?;

        // No suggestion near the cursor: optionally retry once at a nearby diagnostic.
        if prediction.is_none() && allow_jump {
            let cursor_point = position.to_point(&snapshot);
            if has_events
                && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer.clone(),
                    &snapshot,
                    diagnostic_search_range,
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
            {
                return this
                    .update(cx, |this, cx| {
                        this.request_prediction_internal(
                            project,
                            jump_buffer,
                            jump_position,
                            trigger,
                            false,
                            cx,
                        )
                    })?
                    .await;
            }

            return anyhow::Ok(None);
        }

        Ok(prediction)
    })
}
1807
/// Finds a diagnostic location to jump to for a follow-up prediction.
///
/// First, among the active buffer's diagnostic groups whose primary entry does not
/// overlap `active_buffer_diagnostic_search_range`, picks the one whose start row
/// is closest to the cursor row.
///
/// If there is none, it falls back to another project path with diagnostic
/// summaries, choosing the one that shares the longest leading run of path
/// components with the active buffer's path. That buffer is opened, and the start
/// of its diagnostic with the minimum `severity` value is chosen. Presumably this
/// is the most severe one if severity is ordered like LSP's (error lowest);
/// NOTE(review): confirm.
///
/// Returns `Ok(None)` when no location is found. Errors from opening the fallback
/// buffer are propagated.
async fn next_diagnostic_location(
    active_buffer: Entity<Buffer>,
    active_buffer_snapshot: &BufferSnapshot,
    active_buffer_diagnostic_search_range: Range<Point>,
    active_buffer_cursor_point: Point,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
    // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
    let mut jump_location = active_buffer_snapshot
        .diagnostic_groups(None)
        .into_iter()
        .filter_map(|(_, group)| {
            let range = &group.entries[group.primary_ix]
                .range
                .to_point(&active_buffer_snapshot);
            if range.overlaps(&active_buffer_diagnostic_search_range) {
                None
            } else {
                Some(range.start)
            }
        })
        .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
        .map(|position| {
            (
                active_buffer.clone(),
                active_buffer_snapshot.anchor_before(position),
            )
        });

    // Fallback: look in a different file.
    if jump_location.is_none() {
        let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
            let file = buffer.file()?;

            Some(ProjectPath {
                worktree_id: file.worktree_id(cx),
                path: file.path().clone(),
            })
        });

        let buffer_task = project.update(cx, |project, cx| {
            let (path, _, _) = project
                .diagnostic_summaries(false, cx)
                .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                .max_by_key(|(path, _, _)| {
                    // find the buffer with errors that shares most parent directories
                    path.path
                        .components()
                        .zip(
                            active_buffer_path
                                .as_ref()
                                .map(|p| p.path.components())
                                .unwrap_or_default(),
                        )
                        .take_while(|(a, b)| a == b)
                        .count()
                })?;

            Some(project.open_buffer(path, cx))
        });

        if let Some(buffer_task) = buffer_task {
            let closest_buffer = buffer_task.await?;

            jump_location = closest_buffer
                .read_with(cx, |buffer, _cx| {
                    buffer
                        .buffer_diagnostics(None)
                        .into_iter()
                        .min_by_key(|entry| entry.diagnostic.severity)
                        .map(|entry| entry.range.start)
                })
                .map(|position| (closest_buffer, position));
        }
    }

    anyhow::Ok(jump_location)
}
1886
1887 async fn send_raw_llm_request(
1888 request: RawCompletionRequest,
1889 client: Arc<Client>,
1890 custom_url: Option<Arc<Url>>,
1891 llm_token: LlmApiToken,
1892 app_version: Version,
1893 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1894 let url = if let Some(custom_url) = custom_url {
1895 custom_url.as_ref().clone()
1896 } else {
1897 client
1898 .http_client()
1899 .build_zed_llm_url("/predict_edits/raw", &[])?
1900 };
1901
1902 Self::send_api_request(
1903 |builder| {
1904 let req = builder
1905 .uri(url.as_ref())
1906 .body(serde_json::to_string(&request)?.into());
1907 Ok(req?)
1908 },
1909 client,
1910 llm_token,
1911 app_version,
1912 true,
1913 )
1914 .await
1915 }
1916
1917 pub(crate) async fn send_v3_request(
1918 input: ZetaPromptInput,
1919 prompt_version: ZetaVersion,
1920 client: Arc<Client>,
1921 llm_token: LlmApiToken,
1922 app_version: Version,
1923 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1924 let url = client
1925 .http_client()
1926 .build_zed_llm_url("/predict_edits/v3", &[])?;
1927
1928 let request = PredictEditsV3Request {
1929 input,
1930 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1931 prompt_version,
1932 };
1933
1934 Self::send_api_request(
1935 |builder| {
1936 let req = builder
1937 .uri(url.as_ref())
1938 .body(serde_json::to_string(&request)?.into());
1939 Ok(req?)
1940 },
1941 client,
1942 llm_token,
1943 app_version,
1944 true,
1945 )
1946 .await
1947 }
1948
/// Unwraps the result of an API request made on behalf of the store behind `this`.
///
/// On success, any reported usage is forwarded to the user store (ignored if the
/// store entity is gone), and the data is returned.
///
/// On a `ZedUpdateRequiredError`, sets `update_required` and shows an app
/// notification with an "Update Zed" link. The error is returned either way.
fn handle_api_response<T>(
    this: &WeakEntity<Self>,
    response: Result<(T, Option<EditPredictionUsage>)>,
    cx: &mut gpui::AsyncApp,
) -> Result<T> {
    match response {
        Ok((data, usage)) => {
            if let Some(usage) = usage {
                this.update(cx, |this, cx| {
                    this.user_store.update(cx, |user_store, cx| {
                        user_store.update_edit_prediction_usage(usage, cx);
                    });
                })
                .ok();
            }
            Ok(data)
        }
        Err(err) => {
            if err.is::<ZedUpdateRequiredError>() {
                cx.update(|cx| {
                    this.update(cx, |this, _cx| {
                        this.update_required = true;
                    })
                    .ok();

                    let error_message: SharedString = err.to_string().into();
                    // A `unique` id per error type keeps at most one such notification.
                    show_app_notification(
                        NotificationId::unique::<ZedUpdateRequiredError>(),
                        cx,
                        move |cx| {
                            cx.new(|cx| {
                                ErrorMessagePrompt::new(error_message.clone(), cx)
                                    .with_link_button("Update Zed", "https://zed.dev/releases")
                            })
                        },
                    );
                });
            }
            Err(err)
        }
    }
}
1991
/// Sends a JSON POST request built by `build` and deserializes the JSON response
/// body into `Res`, together with edit-prediction usage parsed from the response
/// headers (when present).
///
/// Authentication: the `ZED_PREDICT_EDITS_TOKEN` environment variable, if set, is
/// used as the bearer token. Otherwise a token is acquired from `llm_token`. When
/// `require_auth` is false, a failed acquisition simply omits the `Authorization`
/// header.
///
/// `build` may be called more than once, because the request is retried exactly
/// once with a refreshed token when the server marks the token as expired.
///
/// # Errors
/// - `ZedUpdateRequiredError` if the response advertises a minimum version newer
///   than `app_version`; this check runs before the status is inspected.
/// - Network, body-read or JSON errors.
/// - Non-success statuses, with the status and body in the message.
///
/// NOTE(review): the success body is always parsed as JSON, so even `Res = ()`
/// needs a JSON body such as `null`. Confirm that endpoints never reply with an
/// empty body.
async fn send_api_request<Res>(
    build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
    client: Arc<Client>,
    llm_token: LlmApiToken,
    app_version: Version,
    require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
    Res: DeserializeOwned,
{
    let http_client = client.http_client();

    let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
        Some(custom_token)
    } else if require_auth {
        Some(llm_token.acquire(&client).await?)
    } else {
        llm_token.acquire(&client).await.ok()
    };
    let mut did_retry = false;

    loop {
        let request_builder = http_client::Request::builder().method(Method::POST);

        let mut request_builder = request_builder
            .header("Content-Type", "application/json")
            .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

        // Only add Authorization header if we have a token
        if let Some(ref token_value) = token {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", token_value));
        }

        let request = build(request_builder)?;

        let mut response = http_client.send(request).await?;

        // An unparsable header value is ignored.
        if let Some(minimum_required_version) = response
            .headers()
            .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
            .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
        {
            anyhow::ensure!(
                app_version >= minimum_required_version,
                ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }
            );
        }

        if response.status().is_success() {
            let usage = EditPredictionUsage::from_headers(response.headers()).ok();

            let mut body = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;
            return Ok((serde_json::from_slice(&body)?, usage));
        } else if !did_retry
            && token.is_some()
            && response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
        {
            // Expired token: refresh once and loop to resend.
            did_retry = true;
            token = Some(llm_token.refresh(&client).await?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "Request failed with status: {:?}\nBody: {}",
                response.status(),
                body
            );
        }
    }
}
2069
2070 pub fn refresh_context(
2071 &mut self,
2072 project: &Entity<Project>,
2073 buffer: &Entity<language::Buffer>,
2074 cursor_position: language::Anchor,
2075 cx: &mut Context<Self>,
2076 ) {
2077 if self.use_context {
2078 self.get_or_init_project(project, cx)
2079 .context
2080 .update(cx, |store, cx| {
2081 store.refresh(buffer.clone(), cursor_position, cx);
2082 });
2083 }
2084 }
2085
/// Replaces the related-file context of `project` with `related_files`.
/// Only available with the `cli-support` feature.
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    project_state.context.update(cx, |store, cx| {
        store.set_related_files(related_files, cx);
    });
}
2099
/// Overwrites the recent paths recorded for `project` with `paths`.
/// Only available with the `cli-support` feature.
#[cfg(feature = "cli-support")]
pub fn set_recent_paths_for_project(
    &mut self,
    project: &Entity<Project>,
    paths: impl IntoIterator<Item = project::ProjectPath>,
    cx: &mut Context<Self>,
) {
    self.get_or_init_project(project, cx).recent_paths = paths.into_iter().collect();
}
2110
2111 fn is_file_open_source(
2112 &self,
2113 project: &Entity<Project>,
2114 file: &Arc<dyn File>,
2115 cx: &App,
2116 ) -> bool {
2117 if !file.is_local() || file.is_private() {
2118 return false;
2119 }
2120 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2121 return false;
2122 };
2123 project_state
2124 .license_detection_watchers
2125 .get(&file.worktree_id(cx))
2126 .as_ref()
2127 .is_some_and(|watcher| watcher.is_project_open_source())
2128 }
2129
2130 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2131 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2132 }
2133
2134 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2135 if !self.data_collection_choice.is_enabled(cx) {
2136 return false;
2137 }
2138 events.iter().all(|event| {
2139 matches!(
2140 event.as_ref(),
2141 zeta_prompt::Event::BufferChange {
2142 in_open_source_repo: true,
2143 ..
2144 }
2145 )
2146 })
2147 }
2148
2149 fn load_data_collection_choice() -> DataCollectionChoice {
2150 let choice = KEY_VALUE_STORE
2151 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2152 .log_err()
2153 .flatten();
2154
2155 match choice.as_deref() {
2156 Some("true") => DataCollectionChoice::Enabled,
2157 Some("false") => DataCollectionChoice::Disabled,
2158 Some(_) => {
2159 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2160 DataCollectionChoice::NotAnswered
2161 }
2162 None => DataCollectionChoice::NotAnswered,
2163 }
2164 }
2165
2166 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2167 self.data_collection_choice = self.data_collection_choice.toggle();
2168 let new_choice = self.data_collection_choice;
2169 let is_enabled = new_choice.is_enabled(cx);
2170 db::write_and_log(cx, move || {
2171 KEY_VALUE_STORE.write_kvp(
2172 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2173 is_enabled.to_string(),
2174 )
2175 });
2176 }
2177
2178 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2179 self.shown_predictions.iter()
2180 }
2181
2182 pub fn shown_completions_len(&self) -> usize {
2183 self.shown_predictions.len()
2184 }
2185
2186 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2187 self.rated_predictions.contains(id)
2188 }
2189
/// Records a user rating for `prediction`.
///
/// Marks the prediction as rated, then emits an `Edit Prediction Rated` telemetry
/// event containing the rating, the prediction inputs, a unified diff of the
/// predicted edits and the free-form `feedback`. Finally it flushes pending
/// telemetry events and notifies observers.
pub fn rate_prediction(
    &mut self,
    prediction: &EditPrediction,
    rating: EditPredictionRating,
    feedback: String,
    cx: &mut Context<Self>,
) {
    self.rated_predictions.insert(prediction.id.clone());
    telemetry::event!(
        "Edit Prediction Rated",
        rating,
        inputs = prediction.inputs,
        output = prediction
            .edit_preview
            .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
        feedback
    );
    // Send right away rather than waiting for the next periodic flush.
    self.client.telemetry().flush_events().detach();
    cx.notify();
}
2210
2211 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2212 if cfg!(feature = "cli-support") {
2213 return;
2214 }
2215 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2216 && all_language_settings(None, cx).edit_predictions.use_context;
2217 }
2218}
2219
2220pub(crate) fn filter_redundant_excerpts(
2221 mut related_files: Vec<RelatedFile>,
2222 cursor_path: &Path,
2223 cursor_row_range: Range<u32>,
2224) -> Vec<RelatedFile> {
2225 for file in &mut related_files {
2226 if file.path.as_ref() == cursor_path {
2227 file.excerpts.retain(|excerpt| {
2228 excerpt.row_range.start < cursor_row_range.start
2229 || excerpt.row_range.end > cursor_row_range.end
2230 });
2231 }
2232 }
2233 related_files.retain(|file| !file.excerpts.is_empty());
2234 related_files
2235}
2236
/// Error reporting that this Zed build is older than the minimum version
/// required to keep using edit predictions.
///
/// The user-facing message comes from the `#[error(...)]` attribute below and
/// interpolates `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version that can still use edit predictions.
    minimum_version: Version,
}
2244
/// The user's answer to whether edit-prediction data may be collected.
///
/// `NotAnswered` means no recognised choice has been recorded. Note that
/// `DataCollectionChoice::is_enabled` can report `true` regardless of the
/// variant for staff accounts, so compare variants directly (now possible via
/// `PartialEq`) when the user's own choice is what matters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2251
2252impl DataCollectionChoice {
2253 pub fn is_enabled(self, cx: &App) -> bool {
2254 if cx.is_staff() {
2255 return true;
2256 }
2257 match self {
2258 Self::Enabled => true,
2259 Self::NotAnswered | Self::Disabled => false,
2260 }
2261 }
2262
2263 #[must_use]
2264 pub fn toggle(&self) -> DataCollectionChoice {
2265 match self {
2266 Self::Enabled => Self::Disabled,
2267 Self::Disabled => Self::Enabled,
2268 Self::NotAnswered => Self::Enabled,
2269 }
2270 }
2271}
2272
2273impl From<bool> for DataCollectionChoice {
2274 fn from(value: bool) -> Self {
2275 match value {
2276 true => DataCollectionChoice::Enabled,
2277 false => DataCollectionChoice::Disabled,
2278 }
2279 }
2280}
2281
2282struct ZedPredictUpsell;
2283
2284impl Dismissable for ZedPredictUpsell {
2285 const KEY: &'static str = "dismissed-edit-predict-upsell";
2286
2287 fn dismissed() -> bool {
2288 // To make this backwards compatible with older versions of Zed, we
2289 // check if the user has seen the previous Edit Prediction Onboarding
2290 // before, by checking the data collection choice which was written to
2291 // the database once the user clicked on "Accept and Enable"
2292 if KEY_VALUE_STORE
2293 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2294 .log_err()
2295 .is_some_and(|s| s.is_some())
2296 {
2297 return true;
2298 }
2299
2300 KEY_VALUE_STORE
2301 .read_kvp(Self::KEY)
2302 .log_err()
2303 .is_some_and(|s| s.is_some())
2304 }
2305}
2306
2307pub fn should_show_upsell_modal() -> bool {
2308 !ZedPredictUpsell::dismissed()
2309}
2310
/// Registers edit-prediction workspace actions on each newly created `Workspace`.
///
/// - `zed_actions::OpenZedPredictOnboarding` toggles the `ZedPredictModal`,
///   handing it the workspace's user store and client.
/// - `ResetOnboarding` updates the settings file so that the edit-prediction
///   provider is `EditPredictionProvider::None`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): presumably clearing the provider is what makes the
        // onboarding flow show again; confirm against the onboarding UI.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}