1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41use zeta_prompt::ZetaPromptInput;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59mod onboarding_modal;
60pub mod open_ai_response;
61mod prediction;
62pub mod sweep_ai;
63
64pub mod udiff;
65
66mod capture_example;
67mod zed_edit_prediction_delegate;
68pub mod zeta1;
69pub mod zeta2;
70
71#[cfg(test)]
72mod edit_prediction_tests;
73
74use crate::capture_example::{
75 should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
76};
77use crate::license_detection::LicenseDetectionWatcher;
78use crate::mercury::Mercury;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// Editor actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
///
/// `report_changes_for_buffer` trims the finalized-event queue so that it holds at most
/// `EVENT_COUNT_MAX - 1` entries, leaving room for the in-progress `last_event`.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance between a new edit range and the previous event's edit range
/// for the new edit to be coalesced into that event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Minimum idle time between two coalesced edits for the boundary to be recorded as an
/// editing pause (`LastEvent::snapshot_after_last_editing_pause`).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key-value store key for the persisted data-collection
// choice; its use is not visible in this section.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce for sending rejected predictions to the server;
// its use is not visible in this section.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
/// Feature flag gating the Sweep AI prediction backend (`"sweep-ai"`).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag gating the Mercury prediction backend (`"mercury"`).
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
117
/// Model id override read once, on first access, from the `ZED_ZETA_MODEL` environment
/// variable. `None` when the variable is unset or not valid Unicode.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
120
/// Feature flag for Zeta2-specific behavior (`"zeta2"`); enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// Feature flag for capturing edit prediction examples
/// (`"edit-prediction-example-capture"`); enabled for staff.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}
140
/// App-global handle to the single `EditPredictionStore`
/// (installed by `EditPredictionStore::global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
145
/// Central store for edit predictions.
///
/// Tracks per-project edit history and prediction state, and holds the configured
/// prediction backends.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    /// Backend used for predictions (`new` starts with `Zeta2`).
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender feeding the rejected-prediction handler task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override from `ZED_PREDICT_EDITS_URL`, when set to a parseable URL.
    custom_predict_edits_url: Option<Arc<Url>>,
}
162
/// Which backend produces edit predictions.
///
/// Note: the derived `Default` is `Zeta1`, but `EditPredictionStore::new` explicitly
/// starts with `Zeta2 { version: Default::default() }`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    /// Zeta2, parameterized by a `ZetaVersion`.
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
173
/// Inputs handed to a prediction backend for a single prediction request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    // presumably a snapshot of `buffer` taken when the request was made — TODO confirm
    snapshot: BufferSnapshot,
    // presumably the cursor position the prediction is requested at — TODO confirm
    position: Anchor,
    /// Edit-history events to include in the request.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// When set, debug events for this request are sent on this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
188
/// Diagnostic events emitted to a receiver obtained from `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Empty when emitted from a related-excerpt store refresh.
    pub search_prompt: String,
}

/// Context retrieval finished for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable `(label, value)` pairs, e.g. cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// A prediction request finished for a buffer position.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
224
/// Maximum number of `UserActionRecord`s kept per project; older ones are evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// Row (`Point::row`) of `offset`.
    pub line_number: u32,
    /// For edits recorded by `report_changes_for_buffer`: the end offset of the last edit.
    pub offset: usize,
    /// Milliseconds since the Unix epoch (0 if the system clock is before the epoch).
    pub timestamp_epoch_ms: u64,
}

/// Kind of a recorded user action.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    /// Every edit inserted exactly one unit of text (with or without a deletion).
    InsertChar,
    /// Any other insertion or replacement.
    InsertSelection,
    /// Pure deletion of exactly one unit per edit.
    DeleteChar,
    /// Any other pure deletion.
    DeleteSelection,
    // NOTE(review): not produced by the edit classification in `report_changes_for_buffer`.
    CursorMovement,
}
244
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The finalized change (path(s) and unified diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
}
251
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events, oldest first; trimmed in `report_changes_for_buffer`.
    events: VecDeque<StoredEvent>,
    /// The in-progress event that subsequent nearby edits may be coalesced into.
    last_event: Option<LastEvent>,
    /// Recently activated paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (capacity two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug event sink installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context for this project.
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree; used to compute `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (`USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    _subscription: gpui::Subscription,
    /// Created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
269
270impl ProjectState {
271 fn record_user_action(&mut self, action: UserActionRecord) {
272 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
273 self.user_actions.pop_front();
274 }
275 self.user_actions.push_back(action);
276 }
277
278 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
279 self.events
280 .iter()
281 .cloned()
282 .chain(
283 self.last_event
284 .as_ref()
285 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
286 )
287 .collect()
288 }
289
290 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
291 self.events
292 .iter()
293 .cloned()
294 .chain(self.last_event.as_ref().iter().flat_map(|event| {
295 let (one, two) = event.split_by_pause();
296 let one = one.finalize(&self.license_detection_watchers, cx);
297 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
298 one.into_iter().chain(two)
299 }))
300 .collect()
301 }
302
303 fn cancel_pending_prediction(
304 &mut self,
305 pending_prediction: PendingPrediction,
306 cx: &mut Context<EditPredictionStore>,
307 ) {
308 self.cancelled_predictions.insert(pending_prediction.id);
309
310 cx.spawn(async move |this, cx| {
311 let Some(prediction_id) = pending_prediction.task.await else {
312 return;
313 };
314
315 this.update(cx, |this, _cx| {
316 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
317 })
318 .ok();
319 })
320 .detach()
321 }
322
323 fn active_buffer(
324 &self,
325 project: &Entity<Project>,
326 cx: &App,
327 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
328 let project = project.read(cx);
329 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
330 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
331 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
332 Some((active_buffer, registered_buffer.last_position))
333 }
334}
335
/// The prediction currently offered for a project, with display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
343
344impl CurrentEditPrediction {
345 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
346 let Some(new_edits) = self
347 .prediction
348 .interpolate(&self.prediction.buffer.read(cx))
349 else {
350 return false;
351 };
352
353 if self.prediction.buffer != old_prediction.prediction.buffer {
354 return true;
355 }
356
357 let Some(old_edits) = old_prediction
358 .prediction
359 .interpolate(&old_prediction.prediction.buffer.read(cx))
360 else {
361 return true;
362 };
363
364 let requested_by_buffer_id = self.requested_by.buffer_id();
365
366 // This reduces the occurrence of UI thrash from replacing edits
367 //
368 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
369 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
370 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
371 && old_edits.len() == 1
372 && new_edits.len() == 1
373 {
374 let (old_range, old_text) = &old_edits[0];
375 let (new_range, new_text) = &new_edits[0];
376 new_range == old_range && new_text.starts_with(old_text.as_ref())
377 } else {
378 true
379 }
380 }
381}
382
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A project diagnostics update.
    DiagnosticsUpdate,
    /// Activity in the buffer with this entity id.
    Buffer(EntityId),
}
388
389impl PredictionRequestedBy {
390 pub fn buffer_id(&self) -> Option<EntityId> {
391 match self {
392 PredictionRequestedBy::DiagnosticsUpdate => None,
393 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
394 }
395 }
396}
397
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Per-project id (see `ProjectState::next_pending_prediction_id`).
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
}
403
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` when the prediction edits the queried buffer and
// `Jump` when it lives elsewhere — confirm against `prediction_at`.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
410
/// Test convenience: both variants dereference to the wrapped prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
422
/// A buffer tracked for edit history.
struct RegisteredBuffer {
    /// File as of `snapshot`.
    file: Option<Arc<dyn File>>,
    /// Snapshot at the time of the last reported change (or registration).
    snapshot: TextBufferSnapshot,
    /// Last cursor position passed to `prediction_at`.
    last_position: Option<Anchor>,
    /// Edited-event subscription and release observer.
    _subscriptions: [gpui::Subscription; 2],
}
429
/// The in-progress (not yet finalized) change to a single buffer.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range covered by the most recent edit(s); used to decide coalescing.
    edit_range: Option<Range<Anchor>>,
    /// `new_snapshot` as of the last edit preceding a pause of at least
    /// `LAST_CHANGE_GROUPING_TIME`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
440
441impl LastEvent {
442 pub fn finalize(
443 &self,
444 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
445 cx: &App,
446 ) -> Option<StoredEvent> {
447 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
448 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
449
450 let in_open_source_repo =
451 [self.new_file.as_ref(), self.old_file.as_ref()]
452 .iter()
453 .all(|file| {
454 file.is_some_and(|file| {
455 license_detection_watchers
456 .get(&file.worktree_id(cx))
457 .is_some_and(|watcher| watcher.is_project_open_source())
458 })
459 });
460
461 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
462
463 if path == old_path && diff.is_empty() {
464 None
465 } else {
466 Some(StoredEvent {
467 event: Arc::new(zeta_prompt::Event::BufferChange {
468 old_path,
469 path,
470 diff,
471 in_open_source_repo,
472 // TODO: Actually detect if this edit was predicted or not
473 predicted: false,
474 }),
475 old_snapshot: self.old_snapshot.clone(),
476 })
477 }
478 }
479
480 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
481 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
482 return (self.clone(), None);
483 };
484
485 let before = LastEvent {
486 old_snapshot: self.old_snapshot.clone(),
487 new_snapshot: boundary_snapshot.clone(),
488 old_file: self.old_file.clone(),
489 new_file: self.new_file.clone(),
490 edit_range: None,
491 snapshot_after_last_editing_pause: None,
492 last_edit_time: self.last_edit_time,
493 };
494
495 let after = LastEvent {
496 old_snapshot: boundary_snapshot.clone(),
497 new_snapshot: self.new_snapshot.clone(),
498 old_file: self.old_file.clone(),
499 new_file: self.new_file.clone(),
500 edit_range: None,
501 snapshot_after_last_editing_pause: None,
502 last_edit_time: self.last_edit_time,
503 };
504
505 (before, Some(after))
506 }
507}
508
509pub(crate) fn compute_diff_between_snapshots(
510 old_snapshot: &TextBufferSnapshot,
511 new_snapshot: &TextBufferSnapshot,
512) -> Option<String> {
513 let edits: Vec<Edit<usize>> = new_snapshot
514 .edits_since::<usize>(&old_snapshot.version)
515 .collect();
516
517 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
518
519 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
520 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
521 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
522 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
523
524 const CONTEXT_LINES: u32 = 3;
525
526 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
527 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
528 let old_context_end_row =
529 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
530 let new_context_end_row =
531 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
532
533 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
534 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
535 let old_end_line_offset = old_snapshot
536 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
537 let new_end_line_offset = new_snapshot
538 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
539 let old_edit_range = old_start_line_offset..old_end_line_offset;
540 let new_edit_range = new_start_line_offset..new_end_line_offset;
541
542 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
543 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
544
545 let diff = language::unified_diff_with_offsets(
546 &old_region_text,
547 &new_region_text,
548 old_context_start_row,
549 new_context_start_row,
550 );
551
552 Some(diff)
553}
554
555fn buffer_path_with_id_fallback(
556 file: Option<&Arc<dyn File>>,
557 snapshot: &TextBufferSnapshot,
558 cx: &App,
559) -> Arc<Path> {
560 if let Some(file) = file {
561 file.full_path(cx).into()
562 } else {
563 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
564 }
565}
566
567impl EditPredictionStore {
568 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
569 cx.try_global::<EditPredictionStoreGlobal>()
570 .map(|global| global.0.clone())
571 }
572
573 pub fn global(
574 client: &Arc<Client>,
575 user_store: &Entity<UserStore>,
576 cx: &mut App,
577 ) -> Entity<Self> {
578 cx.try_global::<EditPredictionStoreGlobal>()
579 .map(|global| global.0.clone())
580 .unwrap_or_else(|| {
581 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
582 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
583 ep_store
584 })
585 }
586
587 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
588 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
589 let data_collection_choice = Self::load_data_collection_choice();
590
591 let llm_token = LlmApiToken::default();
592
593 let (reject_tx, reject_rx) = mpsc::unbounded();
594 cx.background_spawn({
595 let client = client.clone();
596 let llm_token = llm_token.clone();
597 let app_version = AppVersion::global(cx);
598 let background_executor = cx.background_executor().clone();
599 async move {
600 Self::handle_rejected_predictions(
601 reject_rx,
602 client,
603 llm_token,
604 app_version,
605 background_executor,
606 )
607 .await
608 }
609 })
610 .detach();
611
612 let this = Self {
613 projects: HashMap::default(),
614 client,
615 user_store,
616 llm_token,
617 _llm_token_subscription: cx.subscribe(
618 &refresh_llm_token_listener,
619 |this, _listener, _event, cx| {
620 let client = this.client.clone();
621 let llm_token = this.llm_token.clone();
622 cx.spawn(async move |_this, _cx| {
623 llm_token.refresh(&client).await?;
624 anyhow::Ok(())
625 })
626 .detach_and_log_err(cx);
627 },
628 ),
629 update_required: false,
630 edit_prediction_model: EditPredictionModel::Zeta2 {
631 version: Default::default(),
632 },
633 sweep_ai: SweepAi::new(cx),
634 mercury: Mercury::new(cx),
635
636 data_collection_choice,
637 reject_predictions_tx: reject_tx,
638 rated_predictions: Default::default(),
639 shown_predictions: Default::default(),
640 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
641 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
642 Err(_) => None,
643 },
644 };
645
646 this
647 }
648
649 #[cfg(test)]
650 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
651 self.custom_predict_edits_url = Some(url.into());
652 }
653
654 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
655 self.edit_prediction_model = model;
656 }
657
658 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
659 self.sweep_ai.api_token.read(cx).has_key()
660 }
661
662 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
663 self.mercury.api_token.read(cx).has_key()
664 }
665
666 pub fn clear_history(&mut self) {
667 for project_state in self.projects.values_mut() {
668 project_state.events.clear();
669 project_state.last_event.take();
670 }
671 }
672
673 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
674 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
675 project_state.events.clear();
676 project_state.last_event.take();
677 }
678 }
679
680 pub fn edit_history_for_project(
681 &self,
682 project: &Entity<Project>,
683 cx: &App,
684 ) -> Vec<StoredEvent> {
685 self.projects
686 .get(&project.entity_id())
687 .map(|project_state| project_state.events(cx))
688 .unwrap_or_default()
689 }
690
691 pub fn edit_history_for_project_with_pause_split_last_event(
692 &self,
693 project: &Entity<Project>,
694 cx: &App,
695 ) -> Vec<StoredEvent> {
696 self.projects
697 .get(&project.entity_id())
698 .map(|project_state| project_state.events_split_by_pause(cx))
699 .unwrap_or_default()
700 }
701
702 pub fn context_for_project<'a>(
703 &'a self,
704 project: &Entity<Project>,
705 cx: &'a mut App,
706 ) -> Vec<RelatedFile> {
707 self.projects
708 .get(&project.entity_id())
709 .map(|project| {
710 project
711 .context
712 .update(cx, |context, cx| context.related_files(cx))
713 })
714 .unwrap_or_default()
715 }
716
717 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
718 self.projects
719 .get(&project.entity_id())
720 .and_then(|project| project.copilot.clone())
721 }
722
723 pub fn start_copilot_for_project(
724 &mut self,
725 project: &Entity<Project>,
726 cx: &mut Context<Self>,
727 ) -> Option<Entity<Copilot>> {
728 let state = self.get_or_init_project(project, cx);
729
730 if state.copilot.is_some() {
731 return state.copilot.clone();
732 }
733 let _project = project.clone();
734 let project = project.read(cx);
735
736 let node = project.node_runtime().cloned();
737 if let Some(node) = node {
738 let next_id = project.languages().next_language_server_id();
739 let fs = project.fs().clone();
740
741 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
742 state.copilot = Some(copilot.clone());
743 Some(copilot)
744 } else {
745 None
746 }
747 }
748
749 pub fn context_for_project_with_buffers<'a>(
750 &'a self,
751 project: &Entity<Project>,
752 cx: &'a mut App,
753 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
754 self.projects
755 .get(&project.entity_id())
756 .map(|project| {
757 project
758 .context
759 .update(cx, |context, cx| context.related_files_with_buffers(cx))
760 })
761 .unwrap_or_default()
762 }
763
764 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
765 if matches!(
766 self.edit_prediction_model,
767 EditPredictionModel::Zeta2 { .. }
768 ) {
769 self.user_store.read(cx).edit_prediction_usage()
770 } else {
771 None
772 }
773 }
774
775 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
776 self.get_or_init_project(project, cx);
777 }
778
779 pub fn register_buffer(
780 &mut self,
781 buffer: &Entity<Buffer>,
782 project: &Entity<Project>,
783 cx: &mut Context<Self>,
784 ) {
785 let project_state = self.get_or_init_project(project, cx);
786 Self::register_buffer_impl(project_state, buffer, project, cx);
787 }
788
789 fn get_or_init_project(
790 &mut self,
791 project: &Entity<Project>,
792 cx: &mut Context<Self>,
793 ) -> &mut ProjectState {
794 let entity_id = project.entity_id();
795 self.projects
796 .entry(entity_id)
797 .or_insert_with(|| ProjectState {
798 context: {
799 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
800 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
801 this.handle_excerpt_store_event(entity_id, event);
802 })
803 .detach();
804 related_excerpt_store
805 },
806 events: VecDeque::new(),
807 last_event: None,
808 recent_paths: VecDeque::new(),
809 debug_tx: None,
810 registered_buffers: HashMap::default(),
811 current_prediction: None,
812 cancelled_predictions: HashSet::default(),
813 pending_predictions: ArrayVec::new(),
814 next_pending_prediction_id: 0,
815 last_prediction_refresh: None,
816 license_detection_watchers: HashMap::default(),
817 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
818 _subscription: cx.subscribe(&project, Self::handle_project_event),
819 copilot: None,
820 })
821 }
822
823 pub fn remove_project(&mut self, project: &Entity<Project>) {
824 self.projects.remove(&project.entity_id());
825 }
826
827 fn handle_excerpt_store_event(
828 &mut self,
829 project_entity_id: EntityId,
830 event: &RelatedExcerptStoreEvent,
831 ) {
832 if let Some(project_state) = self.projects.get(&project_entity_id) {
833 if let Some(debug_tx) = project_state.debug_tx.clone() {
834 match event {
835 RelatedExcerptStoreEvent::StartedRefresh => {
836 debug_tx
837 .unbounded_send(DebugEvent::ContextRetrievalStarted(
838 ContextRetrievalStartedDebugEvent {
839 project_entity_id: project_entity_id,
840 timestamp: Instant::now(),
841 search_prompt: String::new(),
842 },
843 ))
844 .ok();
845 }
846 RelatedExcerptStoreEvent::FinishedRefresh {
847 cache_hit_count,
848 cache_miss_count,
849 mean_definition_latency,
850 max_definition_latency,
851 } => {
852 debug_tx
853 .unbounded_send(DebugEvent::ContextRetrievalFinished(
854 ContextRetrievalFinishedDebugEvent {
855 project_entity_id: project_entity_id,
856 timestamp: Instant::now(),
857 metadata: vec![
858 (
859 "Cache Hits",
860 format!(
861 "{}/{}",
862 cache_hit_count,
863 cache_hit_count + cache_miss_count
864 )
865 .into(),
866 ),
867 (
868 "Max LSP Time",
869 format!("{} ms", max_definition_latency.as_millis())
870 .into(),
871 ),
872 (
873 "Mean LSP Time",
874 format!("{} ms", mean_definition_latency.as_millis())
875 .into(),
876 ),
877 ],
878 },
879 ))
880 .ok();
881 }
882 }
883 }
884 }
885 }
886
887 pub fn debug_info(
888 &mut self,
889 project: &Entity<Project>,
890 cx: &mut Context<Self>,
891 ) -> mpsc::UnboundedReceiver<DebugEvent> {
892 let project_state = self.get_or_init_project(project, cx);
893 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
894 project_state.debug_tx = Some(debug_watch_tx);
895 debug_watch_rx
896 }
897
898 fn handle_project_event(
899 &mut self,
900 project: Entity<Project>,
901 event: &project::Event,
902 cx: &mut Context<Self>,
903 ) {
904 // TODO [zeta2] init with recent paths
905 match event {
906 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
907 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
908 return;
909 };
910 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
911 if let Some(path) = path {
912 if let Some(ix) = project_state
913 .recent_paths
914 .iter()
915 .position(|probe| probe == &path)
916 {
917 project_state.recent_paths.remove(ix);
918 }
919 project_state.recent_paths.push_front(path);
920 }
921 }
922 project::Event::DiagnosticsUpdated { .. } => {
923 if cx.has_flag::<Zeta2FeatureFlag>() {
924 self.refresh_prediction_from_diagnostics(project, cx);
925 }
926 }
927 _ => (),
928 }
929 }
930
    /// Ensures `buffer` is tracked in `project_state` and returns its registration.
    ///
    /// Also ensures a `LicenseDetectionWatcher` exists for the buffer's worktree; it is
    /// removed again when the worktree is released. A newly registered buffer gets
    /// two subscriptions: one that reports `Edited` events via `report_changes_for_buffer`,
    /// and one that drops the registration when the buffer is released. An existing
    /// registration is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree, shared by all of its buffers.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Hold the project weakly so the subscription doesn't keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
996
    /// Records the changes made to `buffer` since its last registered snapshot.
    ///
    /// Registers the project and buffer if needed, then:
    /// - swaps in the new snapshot and file;
    /// - appends a `UserActionRecord` describing the edits;
    /// - either coalesces the change into `last_event` (same buffer, consecutive
    ///   snapshot, edits within `CHANGE_GROUPING_LINE_SPAN` rows), or finalizes the
    ///   previous `last_event` into `events` and starts a new one.
    ///
    /// Returns early when the buffer version is unchanged.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // The previously stored file/snapshot become the "old" side of this change.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Aggregate all edits since the old snapshot: counts, the anchor range from the
        // first edit's start to the last edit's end, and the last edit's end offset.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // "Char" means exactly one unit deleted/inserted per edit.
            // NOTE(review): if these usize offsets are byte offsets, typing a single
            // multi-byte character is classified as a selection — confirm intended.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // A clock before the Unix epoch is recorded as 0.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The new change continues directly from the last event's end state.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Coalesce when the new edit range overlaps, or lies within
            // CHANGE_GROUPING_LINE_SPAN rows of, the previous edit range.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // Remember where the user paused so the event can later be split there.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the previous event. Keep at most EVENT_COUNT_MAX - 1
        // finalized events so that, together with the new `last_event`, the total stays
        // within EVENT_COUNT_MAX.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1120
1121 fn prediction_at(
1122 &mut self,
1123 buffer: &Entity<Buffer>,
1124 position: Option<language::Anchor>,
1125 project: &Entity<Project>,
1126 cx: &App,
1127 ) -> Option<BufferEditPrediction<'_>> {
1128 let project_state = self.projects.get_mut(&project.entity_id())?;
1129 if let Some(position) = position
1130 && let Some(buffer) = project_state
1131 .registered_buffers
1132 .get_mut(&buffer.entity_id())
1133 {
1134 buffer.last_position = Some(position);
1135 }
1136
1137 let CurrentEditPrediction {
1138 requested_by,
1139 prediction,
1140 ..
1141 } = project_state.current_prediction.as_ref()?;
1142
1143 if prediction.targets_buffer(buffer.read(cx)) {
1144 Some(BufferEditPrediction::Local { prediction })
1145 } else {
1146 let show_jump = match requested_by {
1147 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1148 requested_by_buffer_id == &buffer.entity_id()
1149 }
1150 PredictionRequestedBy::DiagnosticsUpdate => true,
1151 };
1152
1153 if show_jump {
1154 Some(BufferEditPrediction::Jump { prediction })
1155 } else {
1156 None
1157 }
1158 }
1159 }
1160
1161 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1162 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1163 return;
1164 };
1165
1166 let Some(current_prediction) = project_state.current_prediction.take() else {
1167 return;
1168 };
1169
1170 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1171 project_state.cancel_pending_prediction(pending_prediction, cx);
1172 }
1173
1174 match self.edit_prediction_model {
1175 EditPredictionModel::Sweep => {
1176 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1177 }
1178 EditPredictionModel::Mercury => {}
1179 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1180 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1181 }
1182 }
1183 }
1184
1185 async fn handle_rejected_predictions(
1186 rx: UnboundedReceiver<EditPredictionRejection>,
1187 client: Arc<Client>,
1188 llm_token: LlmApiToken,
1189 app_version: Version,
1190 background_executor: BackgroundExecutor,
1191 ) {
1192 let mut rx = std::pin::pin!(rx.peekable());
1193 let mut batched = Vec::new();
1194
1195 while let Some(rejection) = rx.next().await {
1196 batched.push(rejection);
1197
1198 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1199 select_biased! {
1200 next = rx.as_mut().peek().fuse() => {
1201 if next.is_some() {
1202 continue;
1203 }
1204 }
1205 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1206 }
1207 }
1208
1209 let url = client
1210 .http_client()
1211 .build_zed_llm_url("/predict_edits/reject", &[])
1212 .unwrap();
1213
1214 let flush_count = batched
1215 .len()
1216 // in case items have accumulated after failure
1217 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1218 let start = batched.len() - flush_count;
1219
1220 let body = RejectEditPredictionsBodyRef {
1221 rejections: &batched[start..],
1222 };
1223
1224 let result = Self::send_api_request::<()>(
1225 |builder| {
1226 let req = builder
1227 .uri(url.as_ref())
1228 .body(serde_json::to_string(&body)?.into());
1229 anyhow::Ok(req?)
1230 },
1231 client.clone(),
1232 llm_token.clone(),
1233 app_version.clone(),
1234 true,
1235 )
1236 .await;
1237
1238 if result.log_err().is_some() {
1239 batched.drain(start..);
1240 }
1241 }
1242 }
1243
1244 fn reject_current_prediction(
1245 &mut self,
1246 reason: EditPredictionRejectReason,
1247 project: &Entity<Project>,
1248 ) {
1249 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1250 project_state.pending_predictions.clear();
1251 if let Some(prediction) = project_state.current_prediction.take() {
1252 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1253 }
1254 };
1255 }
1256
1257 fn did_show_current_prediction(
1258 &mut self,
1259 project: &Entity<Project>,
1260 display_type: edit_prediction_types::SuggestionDisplayType,
1261 cx: &mut Context<Self>,
1262 ) {
1263 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1264 return;
1265 };
1266
1267 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1268 return;
1269 };
1270
1271 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1272 let previous_shown_with = current_prediction.shown_with;
1273
1274 if previous_shown_with.is_none() || !is_jump {
1275 current_prediction.shown_with = Some(display_type);
1276 }
1277
1278 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1279
1280 if is_first_non_jump_show {
1281 current_prediction.was_shown = true;
1282 }
1283
1284 let display_type_changed = previous_shown_with != Some(display_type);
1285
1286 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1287 sweep_ai::edit_prediction_shown(
1288 &self.sweep_ai,
1289 self.client.clone(),
1290 ¤t_prediction.prediction,
1291 display_type,
1292 cx,
1293 );
1294 }
1295
1296 if is_first_non_jump_show {
1297 self.shown_predictions
1298 .push_front(current_prediction.prediction.clone());
1299 if self.shown_predictions.len() > 50 {
1300 let completion = self.shown_predictions.pop_back().unwrap();
1301 self.rated_predictions.remove(&completion.id);
1302 }
1303 }
1304 }
1305
1306 fn reject_prediction(
1307 &mut self,
1308 prediction_id: EditPredictionId,
1309 reason: EditPredictionRejectReason,
1310 was_shown: bool,
1311 ) {
1312 match self.edit_prediction_model {
1313 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1314 if self.custom_predict_edits_url.is_some() {
1315 return;
1316 }
1317 }
1318 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1319 }
1320
1321 self.reject_predictions_tx
1322 .unbounded_send(EditPredictionRejection {
1323 request_id: prediction_id.to_string(),
1324 reason,
1325 was_shown,
1326 })
1327 .log_err();
1328 }
1329
1330 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1331 self.projects
1332 .get(&project.entity_id())
1333 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1334 }
1335
1336 pub fn refresh_prediction_from_buffer(
1337 &mut self,
1338 project: Entity<Project>,
1339 buffer: Entity<Buffer>,
1340 position: language::Anchor,
1341 cx: &mut Context<Self>,
1342 ) {
1343 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1344 let Some(request_task) = this
1345 .update(cx, |this, cx| {
1346 this.request_prediction(
1347 &project,
1348 &buffer,
1349 position,
1350 PredictEditsRequestTrigger::Other,
1351 cx,
1352 )
1353 })
1354 .log_err()
1355 else {
1356 return Task::ready(anyhow::Ok(None));
1357 };
1358
1359 cx.spawn(async move |_cx| {
1360 request_task.await.map(|prediction_result| {
1361 prediction_result.map(|prediction_result| {
1362 (
1363 prediction_result,
1364 PredictionRequestedBy::Buffer(buffer.entity_id()),
1365 )
1366 })
1367 })
1368 })
1369 })
1370 }
1371
    /// Queues a prediction request driven by diagnostics rather than by an edit.
    ///
    /// Does nothing when the project is not tracked or when it already has a current
    /// prediction (buffer-driven predictions take precedence). Otherwise the refresh is
    /// queued through `queue_prediction_refresh`, throttled on the project's entity id.
    /// When it runs, it:
    /// 1. snapshots the project's active buffer and cursor, bailing out if predictions are
    ///    disabled there;
    /// 2. asks `next_diagnostic_location` for a diagnostic to jump to;
    /// 3. requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a prediction became current while the request was in flight, the new result is
    /// downgraded to a `CurrentPreferred` rejection. The result is tagged as
    /// `DiagnosticsUpdate`.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Capture the active buffer, its snapshot and the cursor row/column, or skip the
            // refresh if there is no active buffer or predictions are disabled at the cursor.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // With no known cursor position, fall back to the start of the buffer.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // `Default::default()` is an empty exclusion range; presumably this means no
                // diagnostic in the active buffer is skipped. NOTE(review): confirm
                // `overlaps` treats an empty range as overlapping nothing.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction that became current while this request ran wins; report
                // the new one as rejected instead of replacing it.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1458
1459 fn predictions_enabled_at(
1460 snapshot: &BufferSnapshot,
1461 position: Option<language::Anchor>,
1462 cx: &App,
1463 ) -> bool {
1464 let file = snapshot.file();
1465 let all_settings = all_language_settings(file, cx);
1466 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1467 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1468 {
1469 return false;
1470 }
1471
1472 if let Some(last_position) = position {
1473 let settings = snapshot.settings_at(last_position, cx);
1474
1475 if !settings.edit_predictions_disabled_in.is_empty()
1476 && let Some(scope) = snapshot.language_scope_at(last_position)
1477 && let Some(scope_name) = scope.override_name()
1478 && settings
1479 .edit_predictions_disabled_in
1480 .iter()
1481 .any(|s| s == scope_name)
1482 {
1483 return false;
1484 }
1485 }
1486
1487 true
1488 }
1489
1490 #[cfg(not(test))]
1491 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1492 #[cfg(test)]
1493 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1494
1495 fn queue_prediction_refresh(
1496 &mut self,
1497 project: Entity<Project>,
1498 throttle_entity: EntityId,
1499 cx: &mut Context<Self>,
1500 do_refresh: impl FnOnce(
1501 WeakEntity<Self>,
1502 &mut AsyncApp,
1503 )
1504 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1505 + 'static,
1506 ) {
1507 let project_state = self.get_or_init_project(&project, cx);
1508 let pending_prediction_id = project_state.next_pending_prediction_id;
1509 project_state.next_pending_prediction_id += 1;
1510 let last_request = project_state.last_prediction_refresh;
1511
1512 let task = cx.spawn(async move |this, cx| {
1513 if let Some((last_entity, last_timestamp)) = last_request
1514 && throttle_entity == last_entity
1515 && let Some(timeout) =
1516 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1517 {
1518 cx.background_executor().timer(timeout).await;
1519 }
1520
1521 // If this task was cancelled before the throttle timeout expired,
1522 // do not perform a request.
1523 let mut is_cancelled = true;
1524 this.update(cx, |this, cx| {
1525 let project_state = this.get_or_init_project(&project, cx);
1526 if !project_state
1527 .cancelled_predictions
1528 .remove(&pending_prediction_id)
1529 {
1530 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1531 is_cancelled = false;
1532 }
1533 })
1534 .ok();
1535 if is_cancelled {
1536 return None;
1537 }
1538
1539 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1540 let new_prediction_id = new_prediction_result
1541 .as_ref()
1542 .map(|(prediction, _)| prediction.id.clone());
1543
1544 // When a prediction completes, remove it from the pending list, and cancel
1545 // any pending predictions that were enqueued before it.
1546 this.update(cx, |this, cx| {
1547 let project_state = this.get_or_init_project(&project, cx);
1548
1549 let is_cancelled = project_state
1550 .cancelled_predictions
1551 .remove(&pending_prediction_id);
1552
1553 let new_current_prediction = if !is_cancelled
1554 && let Some((prediction_result, requested_by)) = new_prediction_result
1555 {
1556 match prediction_result.prediction {
1557 Ok(prediction) => {
1558 let new_prediction = CurrentEditPrediction {
1559 requested_by,
1560 prediction,
1561 was_shown: false,
1562 shown_with: None,
1563 };
1564
1565 if let Some(current_prediction) =
1566 project_state.current_prediction.as_ref()
1567 {
1568 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1569 {
1570 this.reject_current_prediction(
1571 EditPredictionRejectReason::Replaced,
1572 &project,
1573 );
1574
1575 Some(new_prediction)
1576 } else {
1577 this.reject_prediction(
1578 new_prediction.prediction.id,
1579 EditPredictionRejectReason::CurrentPreferred,
1580 false,
1581 );
1582 None
1583 }
1584 } else {
1585 Some(new_prediction)
1586 }
1587 }
1588 Err(reject_reason) => {
1589 this.reject_prediction(prediction_result.id, reject_reason, false);
1590 None
1591 }
1592 }
1593 } else {
1594 None
1595 };
1596
1597 let project_state = this.get_or_init_project(&project, cx);
1598
1599 if let Some(new_prediction) = new_current_prediction {
1600 project_state.current_prediction = Some(new_prediction);
1601 }
1602
1603 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1604 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1605 if pending_prediction.id == pending_prediction_id {
1606 pending_predictions.remove(ix);
1607 for pending_prediction in pending_predictions.drain(0..ix) {
1608 project_state.cancel_pending_prediction(pending_prediction, cx)
1609 }
1610 break;
1611 }
1612 }
1613 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1614 cx.notify();
1615 })
1616 .ok();
1617
1618 new_prediction_id
1619 });
1620
1621 if project_state.pending_predictions.len() <= 1 {
1622 project_state.pending_predictions.push(PendingPrediction {
1623 id: pending_prediction_id,
1624 task,
1625 });
1626 } else if project_state.pending_predictions.len() == 2 {
1627 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1628 project_state.pending_predictions.push(PendingPrediction {
1629 id: pending_prediction_id,
1630 task,
1631 });
1632 project_state.cancel_pending_prediction(pending_prediction, cx);
1633 }
1634 }
1635
1636 pub fn request_prediction(
1637 &mut self,
1638 project: &Entity<Project>,
1639 active_buffer: &Entity<Buffer>,
1640 position: language::Anchor,
1641 trigger: PredictEditsRequestTrigger,
1642 cx: &mut Context<Self>,
1643 ) -> Task<Result<Option<EditPredictionResult>>> {
1644 self.request_prediction_internal(
1645 project.clone(),
1646 active_buffer.clone(),
1647 position,
1648 trigger,
1649 cx.has_flag::<Zeta2FeatureFlag>(),
1650 cx,
1651 )
1652 }
1653
1654 fn request_prediction_internal(
1655 &mut self,
1656 project: Entity<Project>,
1657 active_buffer: Entity<Buffer>,
1658 position: language::Anchor,
1659 trigger: PredictEditsRequestTrigger,
1660 allow_jump: bool,
1661 cx: &mut Context<Self>,
1662 ) -> Task<Result<Option<EditPredictionResult>>> {
1663 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1664
1665 self.get_or_init_project(&project, cx);
1666 let project_state = self.projects.get(&project.entity_id()).unwrap();
1667 let stored_events = project_state.events(cx);
1668 let has_events = !stored_events.is_empty();
1669 let events: Vec<Arc<zeta_prompt::Event>> =
1670 stored_events.into_iter().map(|e| e.event).collect();
1671 let debug_tx = project_state.debug_tx.clone();
1672
1673 let snapshot = active_buffer.read(cx).snapshot();
1674 let cursor_point = position.to_point(&snapshot);
1675 let current_offset = position.to_offset(&snapshot);
1676
1677 let mut user_actions: Vec<UserActionRecord> =
1678 project_state.user_actions.iter().cloned().collect();
1679
1680 if let Some(last_action) = user_actions.last() {
1681 if last_action.buffer_id == active_buffer.entity_id()
1682 && current_offset != last_action.offset
1683 {
1684 let timestamp_epoch_ms = SystemTime::now()
1685 .duration_since(UNIX_EPOCH)
1686 .map(|d| d.as_millis() as u64)
1687 .unwrap_or(0);
1688 user_actions.push(UserActionRecord {
1689 action_type: UserActionType::CursorMovement,
1690 buffer_id: active_buffer.entity_id(),
1691 line_number: cursor_point.row,
1692 offset: current_offset,
1693 timestamp_epoch_ms,
1694 });
1695 }
1696 }
1697 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1698 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1699 let diagnostic_search_range =
1700 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1701
1702 let related_files = self.context_for_project(&project, cx);
1703
1704 let inputs = EditPredictionModelInput {
1705 project: project.clone(),
1706 buffer: active_buffer.clone(),
1707 snapshot: snapshot.clone(),
1708 position,
1709 events,
1710 related_files,
1711 recent_paths: project_state.recent_paths.clone(),
1712 trigger,
1713 diagnostic_search_range: diagnostic_search_range.clone(),
1714 debug_tx,
1715 user_actions,
1716 };
1717
1718 let can_collect_example = snapshot
1719 .file()
1720 .is_some_and(|file| self.can_collect_file(&project, file, cx))
1721 && self.can_collect_events(&inputs.events, cx)
1722 && self.can_collect_related_files(&project, cx);
1723
1724 if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1725 let events_for_capture =
1726 self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1727 let related_files_for_capture = inputs.related_files.clone();
1728 if let Some(example_task) = capture_example::capture_example(
1729 project.clone(),
1730 active_buffer.clone(),
1731 position,
1732 events_for_capture,
1733 related_files_for_capture,
1734 false,
1735 cx,
1736 ) {
1737 cx.spawn(async move |_this, _cx| {
1738 let example = example_task.await?;
1739 telemetry::event!("Edit Prediction Example Captured", example = example);
1740 anyhow::Ok(())
1741 })
1742 .detach_and_log_err(cx);
1743 }
1744 }
1745 let task = match self.edit_prediction_model {
1746 EditPredictionModel::Zeta1 => {
1747 if should_send_testing_zeta2_request() {
1748 let mut zeta2_inputs = inputs.clone();
1749 zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1750 zeta2::request_prediction_with_zeta2(
1751 self,
1752 zeta2_inputs,
1753 Default::default(),
1754 cx,
1755 )
1756 .detach();
1757 }
1758 zeta1::request_prediction_with_zeta1(self, inputs, cx)
1759 }
1760 EditPredictionModel::Zeta2 { version } => {
1761 zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1762 }
1763 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1764 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1765 };
1766
1767 cx.spawn(async move |this, cx| {
1768 let prediction = task.await?;
1769
1770 if prediction.is_none() && allow_jump {
1771 let cursor_point = position.to_point(&snapshot);
1772 if has_events
1773 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1774 active_buffer.clone(),
1775 &snapshot,
1776 diagnostic_search_range,
1777 cursor_point,
1778 &project,
1779 cx,
1780 )
1781 .await?
1782 {
1783 return this
1784 .update(cx, |this, cx| {
1785 this.request_prediction_internal(
1786 project,
1787 jump_buffer,
1788 jump_position,
1789 trigger,
1790 false,
1791 cx,
1792 )
1793 })?
1794 .await;
1795 }
1796
1797 return anyhow::Ok(None);
1798 }
1799
1800 Ok(prediction)
1801 })
1802 }
1803
    /// Finds a diagnostic location to jump to when nothing was predicted at the cursor.
    ///
    /// First searches the active buffer. Among diagnostic groups whose primary entry does
    /// not overlap `active_buffer_diagnostic_search_range`, it picks the one whose primary
    /// range starts on the row closest to `active_buffer_cursor_point`.
    ///
    /// If there is none, it considers the other project paths reported by
    /// `diagnostic_summaries` and picks the one sharing the most leading path components
    /// with the active buffer's path. It opens that buffer and returns the start of its
    /// diagnostic with the smallest `severity` value. NOTE(review): in LSP numbering
    /// Error = 1, so this is presumably the most severe diagnostic.
    ///
    /// Returns `Ok(None)` when no candidate exists, and an error if opening the other
    /// buffer fails.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough to be
        // included in the last request (i.e. outside the search range).
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1882
1883 async fn send_raw_llm_request(
1884 request: RawCompletionRequest,
1885 client: Arc<Client>,
1886 custom_url: Option<Arc<Url>>,
1887 llm_token: LlmApiToken,
1888 app_version: Version,
1889 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1890 let url = if let Some(custom_url) = custom_url {
1891 custom_url.as_ref().clone()
1892 } else {
1893 client
1894 .http_client()
1895 .build_zed_llm_url("/predict_edits/raw", &[])?
1896 };
1897
1898 Self::send_api_request(
1899 |builder| {
1900 let req = builder
1901 .uri(url.as_ref())
1902 .body(serde_json::to_string(&request)?.into());
1903 Ok(req?)
1904 },
1905 client,
1906 llm_token,
1907 app_version,
1908 true,
1909 )
1910 .await
1911 }
1912
1913 pub(crate) async fn send_v3_request(
1914 input: ZetaPromptInput,
1915 prompt_version: ZetaVersion,
1916 client: Arc<Client>,
1917 llm_token: LlmApiToken,
1918 app_version: Version,
1919 trigger: PredictEditsRequestTrigger,
1920 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1921 let url = client
1922 .http_client()
1923 .build_zed_llm_url("/predict_edits/v3", &[])?;
1924
1925 let request = PredictEditsV3Request {
1926 input,
1927 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1928 prompt_version,
1929 trigger,
1930 };
1931
1932 Self::send_api_request(
1933 |builder| {
1934 let req = builder
1935 .uri(url.as_ref())
1936 .body(serde_json::to_string(&request)?.into());
1937 Ok(req?)
1938 },
1939 client,
1940 llm_token,
1941 app_version,
1942 true,
1943 )
1944 .await
1945 }
1946
1947 fn handle_api_response<T>(
1948 this: &WeakEntity<Self>,
1949 response: Result<(T, Option<EditPredictionUsage>)>,
1950 cx: &mut gpui::AsyncApp,
1951 ) -> Result<T> {
1952 match response {
1953 Ok((data, usage)) => {
1954 if let Some(usage) = usage {
1955 this.update(cx, |this, cx| {
1956 this.user_store.update(cx, |user_store, cx| {
1957 user_store.update_edit_prediction_usage(usage, cx);
1958 });
1959 })
1960 .ok();
1961 }
1962 Ok(data)
1963 }
1964 Err(err) => {
1965 if err.is::<ZedUpdateRequiredError>() {
1966 cx.update(|cx| {
1967 this.update(cx, |this, _cx| {
1968 this.update_required = true;
1969 })
1970 .ok();
1971
1972 let error_message: SharedString = err.to_string().into();
1973 show_app_notification(
1974 NotificationId::unique::<ZedUpdateRequiredError>(),
1975 cx,
1976 move |cx| {
1977 cx.new(|cx| {
1978 ErrorMessagePrompt::new(error_message.clone(), cx)
1979 .with_link_button("Update Zed", "https://zed.dev/releases")
1980 })
1981 },
1982 );
1983 });
1984 }
1985 Err(err)
1986 }
1987 }
1988 }
1989
1990 async fn send_api_request<Res>(
1991 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1992 client: Arc<Client>,
1993 llm_token: LlmApiToken,
1994 app_version: Version,
1995 require_auth: bool,
1996 ) -> Result<(Res, Option<EditPredictionUsage>)>
1997 where
1998 Res: DeserializeOwned,
1999 {
2000 let http_client = client.http_client();
2001
2002 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2003 Some(custom_token)
2004 } else if require_auth {
2005 Some(llm_token.acquire(&client).await?)
2006 } else {
2007 llm_token.acquire(&client).await.ok()
2008 };
2009 let mut did_retry = false;
2010
2011 loop {
2012 let request_builder = http_client::Request::builder().method(Method::POST);
2013
2014 let mut request_builder = request_builder
2015 .header("Content-Type", "application/json")
2016 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2017
2018 // Only add Authorization header if we have a token
2019 if let Some(ref token_value) = token {
2020 request_builder =
2021 request_builder.header("Authorization", format!("Bearer {}", token_value));
2022 }
2023
2024 let request = build(request_builder)?;
2025
2026 let mut response = http_client.send(request).await?;
2027
2028 if let Some(minimum_required_version) = response
2029 .headers()
2030 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2031 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2032 {
2033 anyhow::ensure!(
2034 app_version >= minimum_required_version,
2035 ZedUpdateRequiredError {
2036 minimum_version: minimum_required_version
2037 }
2038 );
2039 }
2040
2041 if response.status().is_success() {
2042 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2043
2044 let mut body = Vec::new();
2045 response.body_mut().read_to_end(&mut body).await?;
2046 return Ok((serde_json::from_slice(&body)?, usage));
2047 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2048 did_retry = true;
2049 token = Some(llm_token.refresh(&client).await?);
2050 } else {
2051 let mut body = String::new();
2052 response.body_mut().read_to_string(&mut body).await?;
2053 anyhow::bail!(
2054 "Request failed with status: {:?}\nBody: {}",
2055 response.status(),
2056 body
2057 );
2058 }
2059 }
2060 }
2061
2062 pub fn refresh_context(
2063 &mut self,
2064 project: &Entity<Project>,
2065 buffer: &Entity<language::Buffer>,
2066 cursor_position: language::Anchor,
2067 cx: &mut Context<Self>,
2068 ) {
2069 self.get_or_init_project(project, cx)
2070 .context
2071 .update(cx, |store, cx| {
2072 store.refresh(buffer.clone(), cursor_position, cx);
2073 });
2074 }
2075
2076 #[cfg(feature = "cli-support")]
2077 pub fn set_context_for_buffer(
2078 &mut self,
2079 project: &Entity<Project>,
2080 related_files: Vec<RelatedFile>,
2081 cx: &mut Context<Self>,
2082 ) {
2083 self.get_or_init_project(project, cx)
2084 .context
2085 .update(cx, |store, cx| {
2086 store.set_related_files(related_files, cx);
2087 });
2088 }
2089
2090 #[cfg(feature = "cli-support")]
2091 pub fn set_recent_paths_for_project(
2092 &mut self,
2093 project: &Entity<Project>,
2094 paths: impl IntoIterator<Item = project::ProjectPath>,
2095 cx: &mut Context<Self>,
2096 ) {
2097 let project_state = self.get_or_init_project(project, cx);
2098 project_state.recent_paths = paths.into_iter().collect();
2099 }
2100
2101 fn is_file_open_source(
2102 &self,
2103 project: &Entity<Project>,
2104 file: &Arc<dyn File>,
2105 cx: &App,
2106 ) -> bool {
2107 if !file.is_local() || file.is_private() {
2108 return false;
2109 }
2110 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2111 return false;
2112 };
2113 project_state
2114 .license_detection_watchers
2115 .get(&file.worktree_id(cx))
2116 .as_ref()
2117 .is_some_and(|watcher| watcher.is_project_open_source())
2118 }
2119
2120 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2121 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2122 }
2123
2124 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2125 if !self.data_collection_choice.is_enabled(cx) {
2126 return false;
2127 }
2128 events.iter().all(|event| {
2129 matches!(
2130 event.as_ref(),
2131 zeta_prompt::Event::BufferChange {
2132 in_open_source_repo: true,
2133 ..
2134 }
2135 )
2136 })
2137 }
2138
2139 fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2140 if !self.data_collection_choice.is_enabled(cx) {
2141 return false;
2142 }
2143
2144 let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2145
2146 related_with_buffers.iter().all(|(_, buffer)| {
2147 buffer
2148 .read(cx)
2149 .file()
2150 .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2151 })
2152 }
2153
2154 fn load_data_collection_choice() -> DataCollectionChoice {
2155 let choice = KEY_VALUE_STORE
2156 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2157 .log_err()
2158 .flatten();
2159
2160 match choice.as_deref() {
2161 Some("true") => DataCollectionChoice::Enabled,
2162 Some("false") => DataCollectionChoice::Disabled,
2163 Some(_) => {
2164 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2165 DataCollectionChoice::NotAnswered
2166 }
2167 None => DataCollectionChoice::NotAnswered,
2168 }
2169 }
2170
2171 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2172 self.data_collection_choice = self.data_collection_choice.toggle();
2173 let new_choice = self.data_collection_choice;
2174 let is_enabled = new_choice.is_enabled(cx);
2175 db::write_and_log(cx, move || {
2176 KEY_VALUE_STORE.write_kvp(
2177 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2178 is_enabled.to_string(),
2179 )
2180 });
2181 }
2182
2183 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2184 self.shown_predictions.iter()
2185 }
2186
2187 pub fn shown_completions_len(&self) -> usize {
2188 self.shown_predictions.len()
2189 }
2190
2191 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2192 self.rated_predictions.contains(id)
2193 }
2194
2195 pub fn rate_prediction(
2196 &mut self,
2197 prediction: &EditPrediction,
2198 rating: EditPredictionRating,
2199 feedback: String,
2200 cx: &mut Context<Self>,
2201 ) {
2202 self.rated_predictions.insert(prediction.id.clone());
2203 telemetry::event!(
2204 "Edit Prediction Rated",
2205 rating,
2206 inputs = prediction.inputs,
2207 output = prediction
2208 .edit_preview
2209 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2210 feedback
2211 );
2212 self.client.telemetry().flush_events().detach();
2213 cx.notify();
2214 }
2215}
2216
2217pub(crate) fn filter_redundant_excerpts(
2218 mut related_files: Vec<RelatedFile>,
2219 cursor_path: &Path,
2220 cursor_row_range: Range<u32>,
2221) -> Vec<RelatedFile> {
2222 for file in &mut related_files {
2223 if file.path.as_ref() == cursor_path {
2224 file.excerpts.retain(|excerpt| {
2225 excerpt.row_range.start < cursor_row_range.start
2226 || excerpt.row_range.end > cursor_row_range.end
2227 });
2228 }
2229 }
2230 related_files.retain(|file| !file.excerpts.is_empty());
2231 related_files
2232}
2233
/// Error signalling that this Zed build is older than the minimum version required to
/// keep using edit predictions.
///
/// The display message asks the user to update to `minimum_version` or newer.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version that will be accepted; interpolated into the error message.
    /// NOTE(review): presumably sourced from a server-provided minimum-version value;
    /// confirm at the construction site.
    minimum_version: Version,
}
2241
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer is stored, or the stored value was not recognized.
    NotAnswered,
    /// The user opted in (persisted as `"true"`).
    Enabled,
    /// The user opted out (persisted as `"false"`).
    Disabled,
}
2248
2249impl DataCollectionChoice {
2250 pub fn is_enabled(self, cx: &App) -> bool {
2251 if cx.is_staff() {
2252 return true;
2253 }
2254 match self {
2255 Self::Enabled => true,
2256 Self::NotAnswered | Self::Disabled => false,
2257 }
2258 }
2259
2260 #[must_use]
2261 pub fn toggle(&self) -> DataCollectionChoice {
2262 match self {
2263 Self::Enabled => Self::Disabled,
2264 Self::Disabled => Self::Enabled,
2265 Self::NotAnswered => Self::Enabled,
2266 }
2267 }
2268}
2269
2270impl From<bool> for DataCollectionChoice {
2271 fn from(value: bool) -> Self {
2272 match value {
2273 true => DataCollectionChoice::Enabled,
2274 false => DataCollectionChoice::Disabled,
2275 }
2276 }
2277}
2278
2279struct ZedPredictUpsell;
2280
2281impl Dismissable for ZedPredictUpsell {
2282 const KEY: &'static str = "dismissed-edit-predict-upsell";
2283
2284 fn dismissed() -> bool {
2285 // To make this backwards compatible with older versions of Zed, we
2286 // check if the user has seen the previous Edit Prediction Onboarding
2287 // before, by checking the data collection choice which was written to
2288 // the database once the user clicked on "Accept and Enable"
2289 if KEY_VALUE_STORE
2290 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2291 .log_err()
2292 .is_some_and(|s| s.is_some())
2293 {
2294 return true;
2295 }
2296
2297 KEY_VALUE_STORE
2298 .read_kvp(Self::KEY)
2299 .log_err()
2300 .is_some_and(|s| s.is_some())
2301 }
2302}
2303
2304pub fn should_show_upsell_modal() -> bool {
2305 !ZedPredictUpsell::dismissed()
2306}
2307
2308pub fn init(cx: &mut App) {
2309 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2310 workspace.register_action(
2311 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2312 ZedPredictModal::toggle(
2313 workspace,
2314 workspace.user_store().clone(),
2315 workspace.client().clone(),
2316 window,
2317 cx,
2318 )
2319 },
2320 );
2321
2322 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2323 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2324 settings
2325 .project
2326 .all_languages
2327 .features
2328 .get_or_insert_default()
2329 .edit_prediction_provider = Some(EditPredictionProvider::None)
2330 });
2331 });
2332 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2333 EditPredictionStore::try_global(cx).and_then(|store| {
2334 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2335 })
2336 }
2337
2338 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2339 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2340 copilot_ui::initiate_sign_in(copilot, window, cx);
2341 }
2342 });
2343 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2344 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2345 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2346 }
2347 });
2348 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2349 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2350 copilot_ui::initiate_sign_out(copilot, window, cx);
2351 }
2352 });
2353 })
2354 .detach();
2355}