1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41use zeta_prompt::ZetaPromptInput;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59mod onboarding_modal;
60pub mod open_ai_response;
61mod prediction;
62pub mod sweep_ai;
63
64pub mod udiff;
65
66mod capture_example;
67mod zed_edit_prediction_delegate;
68pub mod zeta1;
69pub mod zeta2;
70
71#[cfg(test)]
72mod edit_prediction_tests;
73
74use crate::capture_example::{
75 should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
76};
77use crate::license_detection::LicenseDetectionWatcher;
78use crate::mercury::Mercury;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// Declares the `ResetOnboarding` and `ClearHistory` actions in the
// `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
///
/// `report_changes_for_buffer` evicts the oldest stored event once
/// `events.len() + 1` reaches this value, which appears to leave room for the
/// in-progress `last_event`.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance between two non-overlapping edit ranges for them to
/// still be coalesced into the same event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Minimum time since the previous edit for a coalesced edit to be treated as
/// following an editing pause.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key under which the data collection choice is
// persisted; its use is outside this view.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied before rejected predictions
// are reported; its use is outside this view.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
106pub struct SweepFeatureFlag;
107
108impl FeatureFlag for SweepFeatureFlag {
109 const NAME: &str = "sweep-ai";
110}
111
112pub struct MercuryFeatureFlag;
113
114impl FeatureFlag for MercuryFeatureFlag {
115 const NAME: &str = "mercury";
116}
117
/// Value of the `ZED_ZETA_MODEL` environment variable, read lazily on first
/// access. `None` when the variable cannot be read (unset or not valid
/// Unicode).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
120
/// Feature flag registered under the name `zeta2`; reports itself as enabled
/// for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
130
/// Feature flag registered under the name `edit-prediction-example-capture`;
/// reports itself as enabled for staff.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}
140
/// Newtype that stores the shared `EditPredictionStore` entity as an app
/// global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
145
/// Central store for edit prediction state, holding per-project state plus
/// the clients and settings shared across projects.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription created in `new` alive.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    update_required: bool,
    // Currently selected prediction backend; `new` starts with `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sending half of the channel drained by the background task that `new`
    // spawns around `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Parsed from the `ZED_PREDICT_EDITS_URL` environment variable in `new`;
    // `None` when the variable is unset or fails to parse.
    custom_predict_edits_url: Option<Arc<Url>>,
}
163
/// Selects which prediction backend the store uses.
///
/// `Default` yields `Zeta1`, but `EditPredictionStore::new` initializes the
/// store with `Zeta2` at the default `ZetaVersion`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
174
/// Bundle of inputs describing a single prediction request.
// NOTE(review): the construction site is outside this view; the field notes
// below are inferred from types and names only.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
189
/// Events delivered over the debug channel returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
197
/// Payload sent when the related-excerpt store reports `StartedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // `handle_excerpt_store_event` always sends an empty string here.
    pub search_prompt: String,
}
204
/// Payload sent when the related-excerpt store reports `FinishedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Label/value pairs; `handle_excerpt_store_event` fills in cache hit and
    // LSP latency figures.
    pub metadata: Vec<(&'static str, SharedString)>,
}
211
/// Payload for `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
218
/// Payload for `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
225
/// Maximum number of `UserActionRecord`s kept per project; see
/// `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;
227
/// One recorded user edit, as produced by `report_changes_for_buffer`.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    // Entity id of the edited buffer.
    pub buffer_id: EntityId,
    // Row of `offset` in the post-edit snapshot.
    pub line_number: u32,
    // End offset of the last edit in the post-edit snapshot.
    pub offset: usize,
    // Wall-clock time in milliseconds since the Unix epoch (0 if the clock
    // reads earlier than the epoch).
    pub timestamp_epoch_ms: u64,
}
236
/// Coarse classification of a user action.
///
/// `report_changes_for_buffer` derives the insert/delete variants from the
/// total inserted and deleted lengths compared with the number of edits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
245
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    // Snapshot of the buffer text from before the event's changes.
    pub old_snapshot: TextBufferSnapshot,
}
252
/// Edit prediction state tracked for a single project.
struct ProjectState {
    // Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    // Most recent edit, still eligible for coalescing with subsequent edits.
    last_event: Option<LastEvent>,
    // Paths of recently active entries, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    // Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Set by `EditPredictionStore::debug_info`; receives `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One watcher per worktree; an entry is removed when its worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded history of user actions (at most `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    // Keeps the project event subscription alive.
    _subscription: gpui::Subscription,
    copilot: Option<Entity<Copilot>>,
}
270
271impl ProjectState {
272 fn record_user_action(&mut self, action: UserActionRecord) {
273 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
274 self.user_actions.pop_front();
275 }
276 self.user_actions.push_back(action);
277 }
278
279 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
280 self.events
281 .iter()
282 .cloned()
283 .chain(
284 self.last_event
285 .as_ref()
286 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
287 )
288 .collect()
289 }
290
291 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
292 self.events
293 .iter()
294 .cloned()
295 .chain(self.last_event.as_ref().iter().flat_map(|event| {
296 let (one, two) = event.split_by_pause();
297 let one = one.finalize(&self.license_detection_watchers, cx);
298 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
299 one.into_iter().chain(two)
300 }))
301 .collect()
302 }
303
304 fn cancel_pending_prediction(
305 &mut self,
306 pending_prediction: PendingPrediction,
307 cx: &mut Context<EditPredictionStore>,
308 ) {
309 self.cancelled_predictions.insert(pending_prediction.id);
310
311 cx.spawn(async move |this, cx| {
312 let Some(prediction_id) = pending_prediction.task.await else {
313 return;
314 };
315
316 this.update(cx, |this, _cx| {
317 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
318 })
319 .ok();
320 })
321 .detach()
322 }
323
324 fn active_buffer(
325 &self,
326 project: &Entity<Project>,
327 cx: &App,
328 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
329 let project = project.read(cx);
330 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
331 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
332 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
333 Some((active_buffer, registered_buffer.last_position))
334 }
335}
336
/// The prediction currently held for a project, plus bookkeeping about who
/// requested it and whether/how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
344
345impl CurrentEditPrediction {
346 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
347 let Some(new_edits) = self
348 .prediction
349 .interpolate(&self.prediction.buffer.read(cx))
350 else {
351 return false;
352 };
353
354 if self.prediction.buffer != old_prediction.prediction.buffer {
355 return true;
356 }
357
358 let Some(old_edits) = old_prediction
359 .prediction
360 .interpolate(&old_prediction.prediction.buffer.read(cx))
361 else {
362 return true;
363 };
364
365 let requested_by_buffer_id = self.requested_by.buffer_id();
366
367 // This reduces the occurrence of UI thrash from replacing edits
368 //
369 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
370 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
371 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
372 && old_edits.len() == 1
373 && new_edits.len() == 1
374 {
375 let (old_range, old_text) = &old_edits[0];
376 let (new_range, new_text) = &new_edits[0];
377 new_range == old_range && new_text.starts_with(old_text.as_ref())
378 } else {
379 true
380 }
381 }
382}
383
/// What triggered a prediction request: a diagnostics update, or a specific
/// buffer identified by its entity id.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
389
390impl PredictionRequestedBy {
391 pub fn buffer_id(&self) -> Option<EntityId> {
392 match self {
393 PredictionRequestedBy::DiagnosticsUpdate => None,
394 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
395 }
396 }
397}
398
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    // Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
}
404
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` edits the buffer in question while `Jump`
// targets another buffer; the construction site is outside this view.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
411
/// Test-only convenience: dereferences either variant to the wrapped
/// prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Local { prediction } | Self::Jump { prediction } => prediction,
        }
    }
}
423
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // File and text snapshot as of the last reported change; both are
    // replaced by `report_changes_for_buffer`.
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    // Buffer event subscription and release observer set up in
    // `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
430
/// The most recent, not yet finalized edit event of a project. Later edits
/// may be coalesced into it by `report_changes_for_buffer`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Anchor range covered by the most recently reported edits.
    edit_range: Option<Range<Anchor>>,
    // Snapshot captured when an edit was coalesced after a pause of at least
    // `LAST_CHANGE_GROUPING_TIME`; used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
441
442impl LastEvent {
443 pub fn finalize(
444 &self,
445 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
446 cx: &App,
447 ) -> Option<StoredEvent> {
448 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
449 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
450
451 let in_open_source_repo =
452 [self.new_file.as_ref(), self.old_file.as_ref()]
453 .iter()
454 .all(|file| {
455 file.is_some_and(|file| {
456 license_detection_watchers
457 .get(&file.worktree_id(cx))
458 .is_some_and(|watcher| watcher.is_project_open_source())
459 })
460 });
461
462 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
463
464 if path == old_path && diff.is_empty() {
465 None
466 } else {
467 Some(StoredEvent {
468 event: Arc::new(zeta_prompt::Event::BufferChange {
469 old_path,
470 path,
471 diff,
472 in_open_source_repo,
473 // TODO: Actually detect if this edit was predicted or not
474 predicted: false,
475 }),
476 old_snapshot: self.old_snapshot.clone(),
477 })
478 }
479 }
480
481 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
482 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
483 return (self.clone(), None);
484 };
485
486 let before = LastEvent {
487 old_snapshot: self.old_snapshot.clone(),
488 new_snapshot: boundary_snapshot.clone(),
489 old_file: self.old_file.clone(),
490 new_file: self.new_file.clone(),
491 edit_range: None,
492 snapshot_after_last_editing_pause: None,
493 last_edit_time: self.last_edit_time,
494 };
495
496 let after = LastEvent {
497 old_snapshot: boundary_snapshot.clone(),
498 new_snapshot: self.new_snapshot.clone(),
499 old_file: self.old_file.clone(),
500 new_file: self.new_file.clone(),
501 edit_range: None,
502 snapshot_after_last_editing_pause: None,
503 last_edit_time: self.last_edit_time,
504 };
505
506 (before, Some(after))
507 }
508}
509
510pub(crate) fn compute_diff_between_snapshots(
511 old_snapshot: &TextBufferSnapshot,
512 new_snapshot: &TextBufferSnapshot,
513) -> Option<String> {
514 let edits: Vec<Edit<usize>> = new_snapshot
515 .edits_since::<usize>(&old_snapshot.version)
516 .collect();
517
518 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
519
520 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
521 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
522 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
523 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
524
525 const CONTEXT_LINES: u32 = 3;
526
527 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
528 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
529 let old_context_end_row =
530 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
531 let new_context_end_row =
532 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
533
534 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
535 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
536 let old_end_line_offset = old_snapshot
537 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
538 let new_end_line_offset = new_snapshot
539 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
540 let old_edit_range = old_start_line_offset..old_end_line_offset;
541 let new_edit_range = new_start_line_offset..new_end_line_offset;
542
543 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
544 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
545
546 let diff = language::unified_diff_with_offsets(
547 &old_region_text,
548 &new_region_text,
549 old_context_start_row,
550 new_context_start_row,
551 );
552
553 Some(diff)
554}
555
556fn buffer_path_with_id_fallback(
557 file: Option<&Arc<dyn File>>,
558 snapshot: &TextBufferSnapshot,
559 cx: &App,
560) -> Arc<Path> {
561 if let Some(file) = file {
562 file.full_path(cx).into()
563 } else {
564 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
565 }
566}
567
568impl EditPredictionStore {
569 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
570 cx.try_global::<EditPredictionStoreGlobal>()
571 .map(|global| global.0.clone())
572 }
573
574 pub fn global(
575 client: &Arc<Client>,
576 user_store: &Entity<UserStore>,
577 cx: &mut App,
578 ) -> Entity<Self> {
579 cx.try_global::<EditPredictionStoreGlobal>()
580 .map(|global| global.0.clone())
581 .unwrap_or_else(|| {
582 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
583 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
584 ep_store
585 })
586 }
587
    /// Builds the store.
    ///
    /// Besides initializing fields, this spawns a detached background task
    /// that drains rejected predictions, subscribes to LLM token refresh
    /// events, and arranges for `configure_context_retrieval` to run now,
    /// when feature flags become ready, and whenever the `SettingsStore`
    /// global is observed.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are pushed onto `reject_tx` and handled off the main
        // thread by `handle_rejected_predictions`.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            use_context: false,
            llm_token,
            // Refresh the LLM token whenever the listener emits an event;
            // failures are logged rather than propagated.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            edit_prediction_model: EditPredictionModel::Zeta2 {
                version: Default::default(),
            },
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),

            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // Optional endpoint override; a value that fails to parse is
            // logged and treated as absent.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => None,
            },
        };

        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
663
664 #[cfg(test)]
665 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
666 self.custom_predict_edits_url = Some(url.into());
667 }
668
    /// Selects the prediction backend used by this store.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
672
673 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
674 self.sweep_ai.api_token.read(cx).has_key()
675 }
676
677 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
678 self.mercury.api_token.read(cx).has_key()
679 }
680
    /// Sets the `use_context` flag.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
684
685 pub fn clear_history(&mut self) {
686 for project_state in self.projects.values_mut() {
687 project_state.events.clear();
688 project_state.last_event.take();
689 }
690 }
691
692 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
693 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
694 project_state.events.clear();
695 project_state.last_event.take();
696 }
697 }
698
699 pub fn edit_history_for_project(
700 &self,
701 project: &Entity<Project>,
702 cx: &App,
703 ) -> Vec<StoredEvent> {
704 self.projects
705 .get(&project.entity_id())
706 .map(|project_state| project_state.events(cx))
707 .unwrap_or_default()
708 }
709
710 pub fn edit_history_for_project_with_pause_split_last_event(
711 &self,
712 project: &Entity<Project>,
713 cx: &App,
714 ) -> Vec<StoredEvent> {
715 self.projects
716 .get(&project.entity_id())
717 .map(|project_state| project_state.events_split_by_pause(cx))
718 .unwrap_or_default()
719 }
720
721 pub fn context_for_project<'a>(
722 &'a self,
723 project: &Entity<Project>,
724 cx: &'a mut App,
725 ) -> Vec<RelatedFile> {
726 self.projects
727 .get(&project.entity_id())
728 .map(|project| {
729 project
730 .context
731 .update(cx, |context, cx| context.related_files(cx))
732 })
733 .unwrap_or_default()
734 }
735
736 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
737 self.projects
738 .get(&project.entity_id())
739 .and_then(|project| project.copilot.clone())
740 }
741
742 pub fn start_copilot_for_project(
743 &mut self,
744 project: &Entity<Project>,
745 cx: &mut Context<Self>,
746 ) -> Option<Entity<Copilot>> {
747 let state = self.get_or_init_project(project, cx);
748
749 if state.copilot.is_some() {
750 return state.copilot.clone();
751 }
752 let _project = project.clone();
753 let project = project.read(cx);
754
755 let node = project.node_runtime().cloned();
756 if let Some(node) = node {
757 let next_id = project.languages().next_language_server_id();
758 let fs = project.fs().clone();
759
760 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
761 state.copilot = Some(copilot.clone());
762 Some(copilot)
763 } else {
764 None
765 }
766 }
767
768 pub fn context_for_project_with_buffers<'a>(
769 &'a self,
770 project: &Entity<Project>,
771 cx: &'a mut App,
772 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
773 self.projects
774 .get(&project.entity_id())
775 .map(|project| {
776 project
777 .context
778 .update(cx, |context, cx| context.related_files_with_buffers(cx))
779 })
780 .unwrap_or_default()
781 }
782
783 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
784 if matches!(
785 self.edit_prediction_model,
786 EditPredictionModel::Zeta2 { .. }
787 ) {
788 self.user_store.read(cx).edit_prediction_usage()
789 } else {
790 None
791 }
792 }
793
    /// Ensures per-project state exists for `project`, creating it if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
797
    /// Registers `buffer` with `project`'s state, initializing the project
    /// state first if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
807
808 fn get_or_init_project(
809 &mut self,
810 project: &Entity<Project>,
811 cx: &mut Context<Self>,
812 ) -> &mut ProjectState {
813 let entity_id = project.entity_id();
814 self.projects
815 .entry(entity_id)
816 .or_insert_with(|| ProjectState {
817 context: {
818 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
819 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
820 this.handle_excerpt_store_event(entity_id, event);
821 })
822 .detach();
823 related_excerpt_store
824 },
825 events: VecDeque::new(),
826 last_event: None,
827 recent_paths: VecDeque::new(),
828 debug_tx: None,
829 registered_buffers: HashMap::default(),
830 current_prediction: None,
831 cancelled_predictions: HashSet::default(),
832 pending_predictions: ArrayVec::new(),
833 next_pending_prediction_id: 0,
834 last_prediction_refresh: None,
835 license_detection_watchers: HashMap::default(),
836 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
837 _subscription: cx.subscribe(&project, Self::handle_project_event),
838 copilot: None,
839 })
840 }
841
    /// Drops all state held for `project`, if any.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
845
846 fn handle_excerpt_store_event(
847 &mut self,
848 project_entity_id: EntityId,
849 event: &RelatedExcerptStoreEvent,
850 ) {
851 if let Some(project_state) = self.projects.get(&project_entity_id) {
852 if let Some(debug_tx) = project_state.debug_tx.clone() {
853 match event {
854 RelatedExcerptStoreEvent::StartedRefresh => {
855 debug_tx
856 .unbounded_send(DebugEvent::ContextRetrievalStarted(
857 ContextRetrievalStartedDebugEvent {
858 project_entity_id: project_entity_id,
859 timestamp: Instant::now(),
860 search_prompt: String::new(),
861 },
862 ))
863 .ok();
864 }
865 RelatedExcerptStoreEvent::FinishedRefresh {
866 cache_hit_count,
867 cache_miss_count,
868 mean_definition_latency,
869 max_definition_latency,
870 } => {
871 debug_tx
872 .unbounded_send(DebugEvent::ContextRetrievalFinished(
873 ContextRetrievalFinishedDebugEvent {
874 project_entity_id: project_entity_id,
875 timestamp: Instant::now(),
876 metadata: vec![
877 (
878 "Cache Hits",
879 format!(
880 "{}/{}",
881 cache_hit_count,
882 cache_hit_count + cache_miss_count
883 )
884 .into(),
885 ),
886 (
887 "Max LSP Time",
888 format!("{} ms", max_definition_latency.as_millis())
889 .into(),
890 ),
891 (
892 "Mean LSP Time",
893 format!("{} ms", mean_definition_latency.as_millis())
894 .into(),
895 ),
896 ],
897 },
898 ))
899 .ok();
900 }
901 }
902 }
903 }
904 }
905
906 pub fn debug_info(
907 &mut self,
908 project: &Entity<Project>,
909 cx: &mut Context<Self>,
910 ) -> mpsc::UnboundedReceiver<DebugEvent> {
911 let project_state = self.get_or_init_project(project, cx);
912 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
913 project_state.debug_tx = Some(debug_watch_tx);
914 debug_watch_rx
915 }
916
917 fn handle_project_event(
918 &mut self,
919 project: Entity<Project>,
920 event: &project::Event,
921 cx: &mut Context<Self>,
922 ) {
923 // TODO [zeta2] init with recent paths
924 match event {
925 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
926 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
927 return;
928 };
929 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
930 if let Some(path) = path {
931 if let Some(ix) = project_state
932 .recent_paths
933 .iter()
934 .position(|probe| probe == &path)
935 {
936 project_state.recent_paths.remove(ix);
937 }
938 project_state.recent_paths.push_front(path);
939 }
940 }
941 project::Event::DiagnosticsUpdated { .. } => {
942 if cx.has_flag::<Zeta2FeatureFlag>() {
943 self.refresh_prediction_from_diagnostics(project, cx);
944 }
945 }
946 _ => (),
947 }
948 }
949
    /// Ensures `buffer` is tracked in `project_state` and returns its
    /// `RegisteredBuffer`.
    ///
    /// If the buffer has a file whose worktree is found in the project, a
    /// license detection watcher is created for that worktree unless one
    /// already exists. On first registration the buffer's current text
    /// snapshot and file are recorded, and subscriptions are installed to
    /// report edits and to unregister the buffer when it is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher once its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report changes on every `Edited` event, as long as
                        // the project is still alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1015
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// Updates the registered snapshot and file, appends a `UserActionRecord`
    /// describing the edits, and then either coalesces the change into the
    /// project's in-progress `LastEvent` or finalizes that event into the
    /// event history and starts a new one. Returns early when the buffer
    /// version is unchanged.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Accumulate edit statistics and the anchor range spanning from the
        // start of the first edit to the end of the last one.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // A total length equal to the number of edits is classified as a
            // "char" action; anything else is a "selection" action.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // Only coalesce when this change directly follows the in-progress
            // event on the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and when the new edit range overlaps the previous one or lies
            // within `CHANGE_GROUPING_LINE_SPAN` rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // Remember the pre-edit state as a split point when the user
                // paused for at least `LAST_CHANGE_GROUPING_TIME`.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the previous event into the history,
        // evicting the oldest entry when the history is at capacity.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1139
1140 fn prediction_at(
1141 &mut self,
1142 buffer: &Entity<Buffer>,
1143 position: Option<language::Anchor>,
1144 project: &Entity<Project>,
1145 cx: &App,
1146 ) -> Option<BufferEditPrediction<'_>> {
1147 let project_state = self.projects.get_mut(&project.entity_id())?;
1148 if let Some(position) = position
1149 && let Some(buffer) = project_state
1150 .registered_buffers
1151 .get_mut(&buffer.entity_id())
1152 {
1153 buffer.last_position = Some(position);
1154 }
1155
1156 let CurrentEditPrediction {
1157 requested_by,
1158 prediction,
1159 ..
1160 } = project_state.current_prediction.as_ref()?;
1161
1162 if prediction.targets_buffer(buffer.read(cx)) {
1163 Some(BufferEditPrediction::Local { prediction })
1164 } else {
1165 let show_jump = match requested_by {
1166 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1167 requested_by_buffer_id == &buffer.entity_id()
1168 }
1169 PredictionRequestedBy::DiagnosticsUpdate => true,
1170 };
1171
1172 if show_jump {
1173 Some(BufferEditPrediction::Jump { prediction })
1174 } else {
1175 None
1176 }
1177 }
1178 }
1179
1180 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1181 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1182 return;
1183 };
1184
1185 let Some(current_prediction) = project_state.current_prediction.take() else {
1186 return;
1187 };
1188
1189 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1190 project_state.cancel_pending_prediction(pending_prediction, cx);
1191 }
1192
1193 match self.edit_prediction_model {
1194 EditPredictionModel::Sweep => {
1195 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1196 }
1197 EditPredictionModel::Mercury => {}
1198 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1199 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1200 }
1201 }
1202 }
1203
1204 async fn handle_rejected_predictions(
1205 rx: UnboundedReceiver<EditPredictionRejection>,
1206 client: Arc<Client>,
1207 llm_token: LlmApiToken,
1208 app_version: Version,
1209 background_executor: BackgroundExecutor,
1210 ) {
1211 let mut rx = std::pin::pin!(rx.peekable());
1212 let mut batched = Vec::new();
1213
1214 while let Some(rejection) = rx.next().await {
1215 batched.push(rejection);
1216
1217 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1218 select_biased! {
1219 next = rx.as_mut().peek().fuse() => {
1220 if next.is_some() {
1221 continue;
1222 }
1223 }
1224 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1225 }
1226 }
1227
1228 let url = client
1229 .http_client()
1230 .build_zed_llm_url("/predict_edits/reject", &[])
1231 .unwrap();
1232
1233 let flush_count = batched
1234 .len()
1235 // in case items have accumulated after failure
1236 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1237 let start = batched.len() - flush_count;
1238
1239 let body = RejectEditPredictionsBodyRef {
1240 rejections: &batched[start..],
1241 };
1242
1243 let result = Self::send_api_request::<()>(
1244 |builder| {
1245 let req = builder
1246 .uri(url.as_ref())
1247 .body(serde_json::to_string(&body)?.into());
1248 anyhow::Ok(req?)
1249 },
1250 client.clone(),
1251 llm_token.clone(),
1252 app_version.clone(),
1253 true,
1254 )
1255 .await;
1256
1257 if result.log_err().is_some() {
1258 batched.drain(start..);
1259 }
1260 }
1261 }
1262
1263 fn reject_current_prediction(
1264 &mut self,
1265 reason: EditPredictionRejectReason,
1266 project: &Entity<Project>,
1267 ) {
1268 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1269 project_state.pending_predictions.clear();
1270 if let Some(prediction) = project_state.current_prediction.take() {
1271 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1272 }
1273 };
1274 }
1275
1276 fn did_show_current_prediction(
1277 &mut self,
1278 project: &Entity<Project>,
1279 display_type: edit_prediction_types::SuggestionDisplayType,
1280 cx: &mut Context<Self>,
1281 ) {
1282 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1283 return;
1284 };
1285
1286 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1287 return;
1288 };
1289
1290 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1291 let previous_shown_with = current_prediction.shown_with;
1292
1293 if previous_shown_with.is_none() || !is_jump {
1294 current_prediction.shown_with = Some(display_type);
1295 }
1296
1297 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1298
1299 if is_first_non_jump_show {
1300 current_prediction.was_shown = true;
1301 }
1302
1303 let display_type_changed = previous_shown_with != Some(display_type);
1304
1305 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1306 sweep_ai::edit_prediction_shown(
1307 &self.sweep_ai,
1308 self.client.clone(),
1309 ¤t_prediction.prediction,
1310 display_type,
1311 cx,
1312 );
1313 }
1314
1315 if is_first_non_jump_show {
1316 self.shown_predictions
1317 .push_front(current_prediction.prediction.clone());
1318 if self.shown_predictions.len() > 50 {
1319 let completion = self.shown_predictions.pop_back().unwrap();
1320 self.rated_predictions.remove(&completion.id);
1321 }
1322 }
1323 }
1324
1325 fn reject_prediction(
1326 &mut self,
1327 prediction_id: EditPredictionId,
1328 reason: EditPredictionRejectReason,
1329 was_shown: bool,
1330 ) {
1331 match self.edit_prediction_model {
1332 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1333 if self.custom_predict_edits_url.is_some() {
1334 return;
1335 }
1336 }
1337 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1338 }
1339
1340 self.reject_predictions_tx
1341 .unbounded_send(EditPredictionRejection {
1342 request_id: prediction_id.to_string(),
1343 reason,
1344 was_shown,
1345 })
1346 .log_err();
1347 }
1348
1349 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1350 self.projects
1351 .get(&project.entity_id())
1352 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1353 }
1354
1355 pub fn refresh_prediction_from_buffer(
1356 &mut self,
1357 project: Entity<Project>,
1358 buffer: Entity<Buffer>,
1359 position: language::Anchor,
1360 cx: &mut Context<Self>,
1361 ) {
1362 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1363 let Some(request_task) = this
1364 .update(cx, |this, cx| {
1365 this.request_prediction(
1366 &project,
1367 &buffer,
1368 position,
1369 PredictEditsRequestTrigger::Other,
1370 cx,
1371 )
1372 })
1373 .log_err()
1374 else {
1375 return Task::ready(anyhow::Ok(None));
1376 };
1377
1378 cx.spawn(async move |_cx| {
1379 request_task.await.map(|prediction_result| {
1380 prediction_result.map(|prediction_result| {
1381 (
1382 prediction_result,
1383 PredictionRequestedBy::Buffer(buffer.entity_id()),
1384 )
1385 })
1386 })
1387 })
1388 })
1389 }
1390
1391 pub fn refresh_prediction_from_diagnostics(
1392 &mut self,
1393 project: Entity<Project>,
1394 cx: &mut Context<Self>,
1395 ) {
1396 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1397 return;
1398 };
1399
1400 // Prefer predictions from buffer
1401 if project_state.current_prediction.is_some() {
1402 return;
1403 };
1404
1405 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1406 let Some((active_buffer, snapshot, cursor_point)) = this
1407 .read_with(cx, |this, cx| {
1408 let project_state = this.projects.get(&project.entity_id())?;
1409 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1410 let snapshot = buffer.read(cx).snapshot();
1411
1412 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1413 return None;
1414 }
1415
1416 let cursor_point = position
1417 .map(|pos| pos.to_point(&snapshot))
1418 .unwrap_or_default();
1419
1420 Some((buffer, snapshot, cursor_point))
1421 })
1422 .log_err()
1423 .flatten()
1424 else {
1425 return Task::ready(anyhow::Ok(None));
1426 };
1427
1428 cx.spawn(async move |cx| {
1429 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1430 active_buffer,
1431 &snapshot,
1432 Default::default(),
1433 cursor_point,
1434 &project,
1435 cx,
1436 )
1437 .await?
1438 else {
1439 return anyhow::Ok(None);
1440 };
1441
1442 let Some(prediction_result) = this
1443 .update(cx, |this, cx| {
1444 this.request_prediction(
1445 &project,
1446 &jump_buffer,
1447 jump_position,
1448 PredictEditsRequestTrigger::Diagnostics,
1449 cx,
1450 )
1451 })?
1452 .await?
1453 else {
1454 return anyhow::Ok(None);
1455 };
1456
1457 this.update(cx, |this, cx| {
1458 Some((
1459 if this
1460 .get_or_init_project(&project, cx)
1461 .current_prediction
1462 .is_none()
1463 {
1464 prediction_result
1465 } else {
1466 EditPredictionResult {
1467 id: prediction_result.id,
1468 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1469 }
1470 },
1471 PredictionRequestedBy::DiagnosticsUpdate,
1472 ))
1473 })
1474 })
1475 });
1476 }
1477
1478 fn predictions_enabled_at(
1479 snapshot: &BufferSnapshot,
1480 position: Option<language::Anchor>,
1481 cx: &App,
1482 ) -> bool {
1483 let file = snapshot.file();
1484 let all_settings = all_language_settings(file, cx);
1485 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1486 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1487 {
1488 return false;
1489 }
1490
1491 if let Some(last_position) = position {
1492 let settings = snapshot.settings_at(last_position, cx);
1493
1494 if !settings.edit_predictions_disabled_in.is_empty()
1495 && let Some(scope) = snapshot.language_scope_at(last_position)
1496 && let Some(scope_name) = scope.override_name()
1497 && settings
1498 .edit_predictions_disabled_in
1499 .iter()
1500 .any(|s| s == scope_name)
1501 {
1502 return false;
1503 }
1504 }
1505
1506 true
1507 }
1508
1509 #[cfg(not(test))]
1510 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1511 #[cfg(test)]
1512 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1513
1514 fn queue_prediction_refresh(
1515 &mut self,
1516 project: Entity<Project>,
1517 throttle_entity: EntityId,
1518 cx: &mut Context<Self>,
1519 do_refresh: impl FnOnce(
1520 WeakEntity<Self>,
1521 &mut AsyncApp,
1522 )
1523 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1524 + 'static,
1525 ) {
1526 let project_state = self.get_or_init_project(&project, cx);
1527 let pending_prediction_id = project_state.next_pending_prediction_id;
1528 project_state.next_pending_prediction_id += 1;
1529 let last_request = project_state.last_prediction_refresh;
1530
1531 let task = cx.spawn(async move |this, cx| {
1532 if let Some((last_entity, last_timestamp)) = last_request
1533 && throttle_entity == last_entity
1534 && let Some(timeout) =
1535 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1536 {
1537 cx.background_executor().timer(timeout).await;
1538 }
1539
1540 // If this task was cancelled before the throttle timeout expired,
1541 // do not perform a request.
1542 let mut is_cancelled = true;
1543 this.update(cx, |this, cx| {
1544 let project_state = this.get_or_init_project(&project, cx);
1545 if !project_state
1546 .cancelled_predictions
1547 .remove(&pending_prediction_id)
1548 {
1549 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1550 is_cancelled = false;
1551 }
1552 })
1553 .ok();
1554 if is_cancelled {
1555 return None;
1556 }
1557
1558 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1559 let new_prediction_id = new_prediction_result
1560 .as_ref()
1561 .map(|(prediction, _)| prediction.id.clone());
1562
1563 // When a prediction completes, remove it from the pending list, and cancel
1564 // any pending predictions that were enqueued before it.
1565 this.update(cx, |this, cx| {
1566 let project_state = this.get_or_init_project(&project, cx);
1567
1568 let is_cancelled = project_state
1569 .cancelled_predictions
1570 .remove(&pending_prediction_id);
1571
1572 let new_current_prediction = if !is_cancelled
1573 && let Some((prediction_result, requested_by)) = new_prediction_result
1574 {
1575 match prediction_result.prediction {
1576 Ok(prediction) => {
1577 let new_prediction = CurrentEditPrediction {
1578 requested_by,
1579 prediction,
1580 was_shown: false,
1581 shown_with: None,
1582 };
1583
1584 if let Some(current_prediction) =
1585 project_state.current_prediction.as_ref()
1586 {
1587 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1588 {
1589 this.reject_current_prediction(
1590 EditPredictionRejectReason::Replaced,
1591 &project,
1592 );
1593
1594 Some(new_prediction)
1595 } else {
1596 this.reject_prediction(
1597 new_prediction.prediction.id,
1598 EditPredictionRejectReason::CurrentPreferred,
1599 false,
1600 );
1601 None
1602 }
1603 } else {
1604 Some(new_prediction)
1605 }
1606 }
1607 Err(reject_reason) => {
1608 this.reject_prediction(prediction_result.id, reject_reason, false);
1609 None
1610 }
1611 }
1612 } else {
1613 None
1614 };
1615
1616 let project_state = this.get_or_init_project(&project, cx);
1617
1618 if let Some(new_prediction) = new_current_prediction {
1619 project_state.current_prediction = Some(new_prediction);
1620 }
1621
1622 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1623 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1624 if pending_prediction.id == pending_prediction_id {
1625 pending_predictions.remove(ix);
1626 for pending_prediction in pending_predictions.drain(0..ix) {
1627 project_state.cancel_pending_prediction(pending_prediction, cx)
1628 }
1629 break;
1630 }
1631 }
1632 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1633 cx.notify();
1634 })
1635 .ok();
1636
1637 new_prediction_id
1638 });
1639
1640 if project_state.pending_predictions.len() <= 1 {
1641 project_state.pending_predictions.push(PendingPrediction {
1642 id: pending_prediction_id,
1643 task,
1644 });
1645 } else if project_state.pending_predictions.len() == 2 {
1646 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1647 project_state.pending_predictions.push(PendingPrediction {
1648 id: pending_prediction_id,
1649 task,
1650 });
1651 project_state.cancel_pending_prediction(pending_prediction, cx);
1652 }
1653 }
1654
1655 pub fn request_prediction(
1656 &mut self,
1657 project: &Entity<Project>,
1658 active_buffer: &Entity<Buffer>,
1659 position: language::Anchor,
1660 trigger: PredictEditsRequestTrigger,
1661 cx: &mut Context<Self>,
1662 ) -> Task<Result<Option<EditPredictionResult>>> {
1663 self.request_prediction_internal(
1664 project.clone(),
1665 active_buffer.clone(),
1666 position,
1667 trigger,
1668 cx.has_flag::<Zeta2FeatureFlag>(),
1669 cx,
1670 )
1671 }
1672
1673 fn request_prediction_internal(
1674 &mut self,
1675 project: Entity<Project>,
1676 active_buffer: Entity<Buffer>,
1677 position: language::Anchor,
1678 trigger: PredictEditsRequestTrigger,
1679 allow_jump: bool,
1680 cx: &mut Context<Self>,
1681 ) -> Task<Result<Option<EditPredictionResult>>> {
1682 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1683
1684 self.get_or_init_project(&project, cx);
1685 let project_state = self.projects.get(&project.entity_id()).unwrap();
1686 let stored_events = project_state.events(cx);
1687 let has_events = !stored_events.is_empty();
1688 let events: Vec<Arc<zeta_prompt::Event>> =
1689 stored_events.into_iter().map(|e| e.event).collect();
1690 let debug_tx = project_state.debug_tx.clone();
1691
1692 let snapshot = active_buffer.read(cx).snapshot();
1693 let cursor_point = position.to_point(&snapshot);
1694 let current_offset = position.to_offset(&snapshot);
1695
1696 let mut user_actions: Vec<UserActionRecord> =
1697 project_state.user_actions.iter().cloned().collect();
1698
1699 if let Some(last_action) = user_actions.last() {
1700 if last_action.buffer_id == active_buffer.entity_id()
1701 && current_offset != last_action.offset
1702 {
1703 let timestamp_epoch_ms = SystemTime::now()
1704 .duration_since(UNIX_EPOCH)
1705 .map(|d| d.as_millis() as u64)
1706 .unwrap_or(0);
1707 user_actions.push(UserActionRecord {
1708 action_type: UserActionType::CursorMovement,
1709 buffer_id: active_buffer.entity_id(),
1710 line_number: cursor_point.row,
1711 offset: current_offset,
1712 timestamp_epoch_ms,
1713 });
1714 }
1715 }
1716 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1717 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1718 let diagnostic_search_range =
1719 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1720
1721 let related_files = if self.use_context {
1722 self.context_for_project(&project, cx)
1723 } else {
1724 Vec::new()
1725 };
1726
1727 let inputs = EditPredictionModelInput {
1728 project: project.clone(),
1729 buffer: active_buffer.clone(),
1730 snapshot: snapshot.clone(),
1731 position,
1732 events,
1733 related_files,
1734 recent_paths: project_state.recent_paths.clone(),
1735 trigger,
1736 diagnostic_search_range: diagnostic_search_range.clone(),
1737 debug_tx,
1738 user_actions,
1739 };
1740
1741 let can_collect_example = snapshot
1742 .file()
1743 .is_some_and(|file| self.can_collect_file(&project, file, cx))
1744 && self.can_collect_events(&inputs.events, cx)
1745 && self.can_collect_related_files(&project, cx);
1746
1747 if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1748 let events_for_capture =
1749 self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1750 let related_files_for_capture = inputs.related_files.clone();
1751 if let Some(example_task) = capture_example::capture_example(
1752 project.clone(),
1753 active_buffer.clone(),
1754 position,
1755 events_for_capture,
1756 related_files_for_capture,
1757 false,
1758 cx,
1759 ) {
1760 cx.spawn(async move |_this, _cx| {
1761 let example = example_task.await?;
1762 telemetry::event!("Edit Prediction Example Captured", example = example);
1763 anyhow::Ok(())
1764 })
1765 .detach_and_log_err(cx);
1766 }
1767 }
1768 let task = match self.edit_prediction_model {
1769 EditPredictionModel::Zeta1 => {
1770 if should_send_testing_zeta2_request() {
1771 let mut zeta2_inputs = inputs.clone();
1772 zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1773 zeta2::request_prediction_with_zeta2(
1774 self,
1775 zeta2_inputs,
1776 Default::default(),
1777 cx,
1778 )
1779 .detach();
1780 }
1781 zeta1::request_prediction_with_zeta1(self, inputs, cx)
1782 }
1783 EditPredictionModel::Zeta2 { version } => {
1784 zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1785 }
1786 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1787 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1788 };
1789
1790 cx.spawn(async move |this, cx| {
1791 let prediction = task.await?;
1792
1793 if prediction.is_none() && allow_jump {
1794 let cursor_point = position.to_point(&snapshot);
1795 if has_events
1796 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1797 active_buffer.clone(),
1798 &snapshot,
1799 diagnostic_search_range,
1800 cursor_point,
1801 &project,
1802 cx,
1803 )
1804 .await?
1805 {
1806 return this
1807 .update(cx, |this, cx| {
1808 this.request_prediction_internal(
1809 project,
1810 jump_buffer,
1811 jump_position,
1812 trigger,
1813 false,
1814 cx,
1815 )
1816 })?
1817 .await;
1818 }
1819
1820 return anyhow::Ok(None);
1821 }
1822
1823 Ok(prediction)
1824 })
1825 }
1826
1827 async fn next_diagnostic_location(
1828 active_buffer: Entity<Buffer>,
1829 active_buffer_snapshot: &BufferSnapshot,
1830 active_buffer_diagnostic_search_range: Range<Point>,
1831 active_buffer_cursor_point: Point,
1832 project: &Entity<Project>,
1833 cx: &mut AsyncApp,
1834 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1835 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1836 let mut jump_location = active_buffer_snapshot
1837 .diagnostic_groups(None)
1838 .into_iter()
1839 .filter_map(|(_, group)| {
1840 let range = &group.entries[group.primary_ix]
1841 .range
1842 .to_point(&active_buffer_snapshot);
1843 if range.overlaps(&active_buffer_diagnostic_search_range) {
1844 None
1845 } else {
1846 Some(range.start)
1847 }
1848 })
1849 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1850 .map(|position| {
1851 (
1852 active_buffer.clone(),
1853 active_buffer_snapshot.anchor_before(position),
1854 )
1855 });
1856
1857 if jump_location.is_none() {
1858 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1859 let file = buffer.file()?;
1860
1861 Some(ProjectPath {
1862 worktree_id: file.worktree_id(cx),
1863 path: file.path().clone(),
1864 })
1865 });
1866
1867 let buffer_task = project.update(cx, |project, cx| {
1868 let (path, _, _) = project
1869 .diagnostic_summaries(false, cx)
1870 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1871 .max_by_key(|(path, _, _)| {
1872 // find the buffer with errors that shares most parent directories
1873 path.path
1874 .components()
1875 .zip(
1876 active_buffer_path
1877 .as_ref()
1878 .map(|p| p.path.components())
1879 .unwrap_or_default(),
1880 )
1881 .take_while(|(a, b)| a == b)
1882 .count()
1883 })?;
1884
1885 Some(project.open_buffer(path, cx))
1886 });
1887
1888 if let Some(buffer_task) = buffer_task {
1889 let closest_buffer = buffer_task.await?;
1890
1891 jump_location = closest_buffer
1892 .read_with(cx, |buffer, _cx| {
1893 buffer
1894 .buffer_diagnostics(None)
1895 .into_iter()
1896 .min_by_key(|entry| entry.diagnostic.severity)
1897 .map(|entry| entry.range.start)
1898 })
1899 .map(|position| (closest_buffer, position));
1900 }
1901 }
1902
1903 anyhow::Ok(jump_location)
1904 }
1905
1906 async fn send_raw_llm_request(
1907 request: RawCompletionRequest,
1908 client: Arc<Client>,
1909 custom_url: Option<Arc<Url>>,
1910 llm_token: LlmApiToken,
1911 app_version: Version,
1912 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1913 let url = if let Some(custom_url) = custom_url {
1914 custom_url.as_ref().clone()
1915 } else {
1916 client
1917 .http_client()
1918 .build_zed_llm_url("/predict_edits/raw", &[])?
1919 };
1920
1921 Self::send_api_request(
1922 |builder| {
1923 let req = builder
1924 .uri(url.as_ref())
1925 .body(serde_json::to_string(&request)?.into());
1926 Ok(req?)
1927 },
1928 client,
1929 llm_token,
1930 app_version,
1931 true,
1932 )
1933 .await
1934 }
1935
1936 pub(crate) async fn send_v3_request(
1937 input: ZetaPromptInput,
1938 prompt_version: ZetaVersion,
1939 client: Arc<Client>,
1940 llm_token: LlmApiToken,
1941 app_version: Version,
1942 trigger: PredictEditsRequestTrigger,
1943 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1944 let url = client
1945 .http_client()
1946 .build_zed_llm_url("/predict_edits/v3", &[])?;
1947
1948 let request = PredictEditsV3Request {
1949 input,
1950 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1951 prompt_version,
1952 trigger,
1953 };
1954
1955 Self::send_api_request(
1956 |builder| {
1957 let req = builder
1958 .uri(url.as_ref())
1959 .body(serde_json::to_string(&request)?.into());
1960 Ok(req?)
1961 },
1962 client,
1963 llm_token,
1964 app_version,
1965 true,
1966 )
1967 .await
1968 }
1969
1970 fn handle_api_response<T>(
1971 this: &WeakEntity<Self>,
1972 response: Result<(T, Option<EditPredictionUsage>)>,
1973 cx: &mut gpui::AsyncApp,
1974 ) -> Result<T> {
1975 match response {
1976 Ok((data, usage)) => {
1977 if let Some(usage) = usage {
1978 this.update(cx, |this, cx| {
1979 this.user_store.update(cx, |user_store, cx| {
1980 user_store.update_edit_prediction_usage(usage, cx);
1981 });
1982 })
1983 .ok();
1984 }
1985 Ok(data)
1986 }
1987 Err(err) => {
1988 if err.is::<ZedUpdateRequiredError>() {
1989 cx.update(|cx| {
1990 this.update(cx, |this, _cx| {
1991 this.update_required = true;
1992 })
1993 .ok();
1994
1995 let error_message: SharedString = err.to_string().into();
1996 show_app_notification(
1997 NotificationId::unique::<ZedUpdateRequiredError>(),
1998 cx,
1999 move |cx| {
2000 cx.new(|cx| {
2001 ErrorMessagePrompt::new(error_message.clone(), cx)
2002 .with_link_button("Update Zed", "https://zed.dev/releases")
2003 })
2004 },
2005 );
2006 });
2007 }
2008 Err(err)
2009 }
2010 }
2011 }
2012
2013 async fn send_api_request<Res>(
2014 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2015 client: Arc<Client>,
2016 llm_token: LlmApiToken,
2017 app_version: Version,
2018 require_auth: bool,
2019 ) -> Result<(Res, Option<EditPredictionUsage>)>
2020 where
2021 Res: DeserializeOwned,
2022 {
2023 let http_client = client.http_client();
2024
2025 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2026 Some(custom_token)
2027 } else if require_auth {
2028 Some(llm_token.acquire(&client).await?)
2029 } else {
2030 llm_token.acquire(&client).await.ok()
2031 };
2032 let mut did_retry = false;
2033
2034 loop {
2035 let request_builder = http_client::Request::builder().method(Method::POST);
2036
2037 let mut request_builder = request_builder
2038 .header("Content-Type", "application/json")
2039 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2040
2041 // Only add Authorization header if we have a token
2042 if let Some(ref token_value) = token {
2043 request_builder =
2044 request_builder.header("Authorization", format!("Bearer {}", token_value));
2045 }
2046
2047 let request = build(request_builder)?;
2048
2049 let mut response = http_client.send(request).await?;
2050
2051 if let Some(minimum_required_version) = response
2052 .headers()
2053 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2054 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2055 {
2056 anyhow::ensure!(
2057 app_version >= minimum_required_version,
2058 ZedUpdateRequiredError {
2059 minimum_version: minimum_required_version
2060 }
2061 );
2062 }
2063
2064 if response.status().is_success() {
2065 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2066
2067 let mut body = Vec::new();
2068 response.body_mut().read_to_end(&mut body).await?;
2069 return Ok((serde_json::from_slice(&body)?, usage));
2070 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2071 did_retry = true;
2072 token = Some(llm_token.refresh(&client).await?);
2073 } else {
2074 let mut body = String::new();
2075 response.body_mut().read_to_string(&mut body).await?;
2076 anyhow::bail!(
2077 "Request failed with status: {:?}\nBody: {}",
2078 response.status(),
2079 body
2080 );
2081 }
2082 }
2083 }
2084
2085 pub fn refresh_context(
2086 &mut self,
2087 project: &Entity<Project>,
2088 buffer: &Entity<language::Buffer>,
2089 cursor_position: language::Anchor,
2090 cx: &mut Context<Self>,
2091 ) {
2092 if self.use_context {
2093 self.get_or_init_project(project, cx)
2094 .context
2095 .update(cx, |store, cx| {
2096 store.refresh(buffer.clone(), cursor_position, cx);
2097 });
2098 }
2099 }
2100
/// Replaces the related files stored in the project's context store with
/// `related_files`, initializing the project state if needed.
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    project_state.context.update(cx, |store, cx| {
        store.set_related_files(related_files, cx);
    });
}
2114
/// Replaces the project's recent paths with `paths`, initializing the project
/// state if needed.
#[cfg(feature = "cli-support")]
pub fn set_recent_paths_for_project(
    &mut self,
    project: &Entity<Project>,
    paths: impl IntoIterator<Item = project::ProjectPath>,
    cx: &mut Context<Self>,
) {
    let state = self.get_or_init_project(project, cx);
    state.recent_paths = FromIterator::from_iter(paths);
}
2125
2126 fn is_file_open_source(
2127 &self,
2128 project: &Entity<Project>,
2129 file: &Arc<dyn File>,
2130 cx: &App,
2131 ) -> bool {
2132 if !file.is_local() || file.is_private() {
2133 return false;
2134 }
2135 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2136 return false;
2137 };
2138 project_state
2139 .license_detection_watchers
2140 .get(&file.worktree_id(cx))
2141 .as_ref()
2142 .is_some_and(|watcher| watcher.is_project_open_source())
2143 }
2144
2145 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2146 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2147 }
2148
2149 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2150 if !self.data_collection_choice.is_enabled(cx) {
2151 return false;
2152 }
2153 events.iter().all(|event| {
2154 matches!(
2155 event.as_ref(),
2156 zeta_prompt::Event::BufferChange {
2157 in_open_source_repo: true,
2158 ..
2159 }
2160 )
2161 })
2162 }
2163
2164 fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2165 if !self.data_collection_choice.is_enabled(cx) {
2166 return false;
2167 }
2168
2169 let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2170
2171 related_with_buffers.iter().all(|(_, buffer)| {
2172 buffer
2173 .read(cx)
2174 .file()
2175 .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2176 })
2177 }
2178
2179 fn load_data_collection_choice() -> DataCollectionChoice {
2180 let choice = KEY_VALUE_STORE
2181 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2182 .log_err()
2183 .flatten();
2184
2185 match choice.as_deref() {
2186 Some("true") => DataCollectionChoice::Enabled,
2187 Some("false") => DataCollectionChoice::Disabled,
2188 Some(_) => {
2189 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2190 DataCollectionChoice::NotAnswered
2191 }
2192 None => DataCollectionChoice::NotAnswered,
2193 }
2194 }
2195
2196 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2197 self.data_collection_choice = self.data_collection_choice.toggle();
2198 let new_choice = self.data_collection_choice;
2199 let is_enabled = new_choice.is_enabled(cx);
2200 db::write_and_log(cx, move || {
2201 KEY_VALUE_STORE.write_kvp(
2202 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2203 is_enabled.to_string(),
2204 )
2205 });
2206 }
2207
2208 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2209 self.shown_predictions.iter()
2210 }
2211
2212 pub fn shown_completions_len(&self) -> usize {
2213 self.shown_predictions.len()
2214 }
2215
2216 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2217 self.rated_predictions.contains(id)
2218 }
2219
2220 pub fn rate_prediction(
2221 &mut self,
2222 prediction: &EditPrediction,
2223 rating: EditPredictionRating,
2224 feedback: String,
2225 cx: &mut Context<Self>,
2226 ) {
2227 self.rated_predictions.insert(prediction.id.clone());
2228 telemetry::event!(
2229 "Edit Prediction Rated",
2230 rating,
2231 inputs = prediction.inputs,
2232 output = prediction
2233 .edit_preview
2234 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2235 feedback
2236 );
2237 self.client.telemetry().flush_events().detach();
2238 cx.notify();
2239 }
2240
2241 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2242 if cfg!(feature = "cli-support") {
2243 return;
2244 }
2245 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2246 && all_language_settings(None, cx).edit_predictions.use_context;
2247 }
2248}
2249
2250pub(crate) fn filter_redundant_excerpts(
2251 mut related_files: Vec<RelatedFile>,
2252 cursor_path: &Path,
2253 cursor_row_range: Range<u32>,
2254) -> Vec<RelatedFile> {
2255 for file in &mut related_files {
2256 if file.path.as_ref() == cursor_path {
2257 file.excerpts.retain(|excerpt| {
2258 excerpt.row_range.start < cursor_row_range.start
2259 || excerpt.row_range.end > cursor_row_range.end
2260 });
2261 }
2262 }
2263 related_files.retain(|file| !file.excerpts.is_empty());
2264 related_files
2265}
2266
2267#[derive(Error, Debug)]
2268#[error(
2269 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2270)]
2271pub struct ZedUpdateRequiredError {
2272 minimum_version: Version,
2273}
2274
/// The user's answer to the edit prediction data collection prompt.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been recorded.
    NotAnswered,
    /// Data collection was switched on.
    Enabled,
    /// Data collection was switched off.
    Disabled,
}
2281
2282impl DataCollectionChoice {
2283 pub fn is_enabled(self, cx: &App) -> bool {
2284 if cx.is_staff() {
2285 return true;
2286 }
2287 match self {
2288 Self::Enabled => true,
2289 Self::NotAnswered | Self::Disabled => false,
2290 }
2291 }
2292
2293 #[must_use]
2294 pub fn toggle(&self) -> DataCollectionChoice {
2295 match self {
2296 Self::Enabled => Self::Disabled,
2297 Self::Disabled => Self::Enabled,
2298 Self::NotAnswered => Self::Enabled,
2299 }
2300 }
2301}
2302
2303impl From<bool> for DataCollectionChoice {
2304 fn from(value: bool) -> Self {
2305 match value {
2306 true => DataCollectionChoice::Enabled,
2307 false => DataCollectionChoice::Disabled,
2308 }
2309 }
2310}
2311
2312struct ZedPredictUpsell;
2313
2314impl Dismissable for ZedPredictUpsell {
2315 const KEY: &'static str = "dismissed-edit-predict-upsell";
2316
2317 fn dismissed() -> bool {
2318 // To make this backwards compatible with older versions of Zed, we
2319 // check if the user has seen the previous Edit Prediction Onboarding
2320 // before, by checking the data collection choice which was written to
2321 // the database once the user clicked on "Accept and Enable"
2322 if KEY_VALUE_STORE
2323 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2324 .log_err()
2325 .is_some_and(|s| s.is_some())
2326 {
2327 return true;
2328 }
2329
2330 KEY_VALUE_STORE
2331 .read_kvp(Self::KEY)
2332 .log_err()
2333 .is_some_and(|s| s.is_some())
2334 }
2335}
2336
2337pub fn should_show_upsell_modal() -> bool {
2338 !ZedPredictUpsell::dismissed()
2339}
2340
2341pub fn init(cx: &mut App) {
2342 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2343 workspace.register_action(
2344 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2345 ZedPredictModal::toggle(
2346 workspace,
2347 workspace.user_store().clone(),
2348 workspace.client().clone(),
2349 window,
2350 cx,
2351 )
2352 },
2353 );
2354
2355 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2356 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2357 settings
2358 .project
2359 .all_languages
2360 .features
2361 .get_or_insert_default()
2362 .edit_prediction_provider = Some(EditPredictionProvider::None)
2363 });
2364 });
2365 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2366 EditPredictionStore::try_global(cx).and_then(|store| {
2367 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2368 })
2369 }
2370
2371 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2372 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2373 copilot_ui::initiate_sign_in(copilot, window, cx);
2374 }
2375 });
2376 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2377 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2378 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2379 }
2380 });
2381 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2382 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2383 copilot_ui::initiate_sign_out(copilot, window, cx);
2384 }
2385 });
2386 })
2387 .detach();
2388}