1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
// Declares the gpui actions in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line distance within which two edits count as part of the same change group
/// (used by `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// NOTE(review): presumably the time window for grouping a new change with the
/// previous one; the consumer is outside this view.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Storage key for the user's data-collection choice.
/// NOTE(review): presumably a key-value-store key; confirm at its use site.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// NOTE(review): presumably the debounce interval for batching rejection reports.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name reported when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes. NOTE(review): presumably how long a prediction is tracked
/// while waiting to settle.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(5 * 60);
/// NOTE(review): presumably the edit-free interval after which a prediction
/// counts as settled.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
109
/// Feature flag named `"zeta2"`.
pub struct Zeta2FeatureFlag;
/// Feature flag named `"edit_prediction_jumps"`; presumably gates jump-style
/// predictions (cf. `BufferEditPrediction::Jump`) — confirm at the use site.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
120
/// gpui global holding the app-wide `EditPredictionStore` entity. It is installed
/// by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds it from
/// `ZED_ZETA_FORMAT` (required) and `ZED_ZETA_MODEL` (optional).
/// `EditPredictionStore::set_zeta2_raw_config` can also set it directly.
///
/// NOTE(review): a previous comment said "the version is also used as the Baseten
/// environment name (lowercased)". No `version` field exists; confirm whether this
/// now applies to `format`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (from `ZED_ZETA_MODEL` when read from the environment).
    pub model_id: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when read from the environment).
    pub format: ZetaFormat,
}
134
/// Central edit-prediction store. It holds per-project edit history and context,
/// the selected prediction backend, and channels for reporting rejected and
/// settled predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Waits until a current user is available, then fetches the available experiments.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// NOTE(review): not written in this view; presumably set when the server
    /// requires a newer client (cf. `MINIMUM_REQUIRED_VERSION_HEADER_NAME`).
    update_required: bool,
    /// Backend used for predictions; `new` initializes it to `Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Experiment names returned by the `/edit_prediction_experiments` endpoint.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Feeds the settled-predictions worker task.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; presumably invoked when a settled event is reported.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
157
/// Backend that produces edit predictions for the store.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model (the initial choice in `EditPredictionStore::new`).
    Zeta,
    /// Fill-in-the-middle model using `format`. Its icon depends on the configured
    /// provider: Ollama, or otherwise an OpenAI-compatible endpoint.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI backend.
    Sweep,
    /// Mercury backend (shown with the Inception icon).
    Mercury,
}
165
/// Inputs for a single edit-prediction request.
///
/// NOTE(review): constructed and consumed outside this view. The field notes below
/// are inferred from names and types; confirm against the model backends.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is for, plus a snapshot of its contents.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Position the prediction is requested at.
    position: Anchor,
    /// Recent edit events.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related context files.
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional sink for debug events about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
182
/// Diagnostic events sent to a project's debug channel (see
/// `EditPredictionStore::debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Retrieval prompt; `handle_excerpt_store_event` currently sends it empty.
    pub search_prompt: String,
}

/// Sent when a project's related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display, e.g. cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Presumably sent when a prediction request starts (emit site not in this view).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, if available.
    pub prompt: Option<String>,
}

/// Presumably sent when a prediction request finishes (emit site not in this view).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, if available.
    pub model_output: Option<String>,
}
218
/// Capacity of `ProjectState::user_actions`. Recording a new action at capacity
/// evicts the oldest one.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// NOTE(review): presumably the cursor row and offset at the time of the
    /// action; confirm 0- vs 1-based at the recording site.
    pub line_number: u32,
    pub offset: usize,
    /// Time of the action in milliseconds since the Unix epoch (per the field name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
238
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The change itself. When built by `LastEvent::finalize` this is a
    /// `BufferChange` diff.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Changed region, anchored in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
246
247impl StoredEvent {
248 fn can_merge(
249 &self,
250 next_old_event: &&&StoredEvent,
251 new_snapshot: &TextBufferSnapshot,
252 last_edit_range: &Range<Anchor>,
253 ) -> bool {
254 // Events must be for the same buffer
255 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
256 return false;
257 }
258 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
259 return false;
260 }
261
262 let a_is_predicted = matches!(
263 self.event.as_ref(),
264 zeta_prompt::Event::BufferChange {
265 predicted: true,
266 ..
267 }
268 );
269 let b_is_predicted = matches!(
270 next_old_event.event.as_ref(),
271 zeta_prompt::Event::BufferChange {
272 predicted: true,
273 ..
274 }
275 );
276
277 // If events come from the same source (both predicted or both manual) then
278 // we would have coalesced them already.
279 if a_is_predicted == b_is_predicted {
280 return false;
281 }
282
283 let left_range = self.edit_range.to_point(new_snapshot);
284 let right_range = next_old_event.edit_range.to_point(new_snapshot);
285 let latest_range = last_edit_range.to_point(&new_snapshot);
286
287 // Events near to the latest edit are not merged if their sources differ.
288 if lines_between_ranges(&left_range, &latest_range)
289 .min(lines_between_ranges(&right_range, &latest_range))
290 <= CHANGE_GROUPING_LINE_SPAN
291 {
292 return false;
293 }
294
295 // Events that are distant from each other are not merged.
296 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
297 return false;
298 }
299
300 true
301 }
302}
303
304fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
305 if left.start > right.end {
306 return left.start.row - right.end.row;
307 }
308 if right.start > left.end {
309 return right.start.row - left.end.row;
310 }
311 0
312}
313
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events. `events()` appends the in-progress `last_event` after them.
    events: VecDeque<StoredEvent>,
    /// The in-progress edit event, not yet finalized.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// NOTE(review): presumably the next id to assign to a `PendingPrediction`.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug-event sink installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the entity and time of the last refresh of each kind.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that have been cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store that supplies context files for this project.
    context: Entity<RelatedExcerptStore>,
    /// License detection per worktree; decides whether files count as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent user actions, capped at `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    /// Subscriptions to the project's events and to its release.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started for this project, if any.
    copilot: Option<Entity<Copilot>>,
}
332
333impl ProjectState {
334 fn record_user_action(&mut self, action: UserActionRecord) {
335 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
336 self.user_actions.pop_front();
337 }
338 self.user_actions.push_back(action);
339 }
340
341 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
342 self.events
343 .iter()
344 .cloned()
345 .chain(self.last_event.as_ref().iter().flat_map(|event| {
346 let (one, two) = event.split_by_pause();
347 let one = one.finalize(&self.license_detection_watchers, cx);
348 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
349 one.into_iter().chain(two)
350 }))
351 .collect()
352 }
353
354 fn cancel_pending_prediction(
355 &mut self,
356 pending_prediction: PendingPrediction,
357 cx: &mut Context<EditPredictionStore>,
358 ) {
359 self.cancelled_predictions.insert(pending_prediction.id);
360
361 if pending_prediction.drop_on_cancel {
362 drop(pending_prediction.task);
363 } else {
364 cx.spawn(async move |this, cx| {
365 let Some(prediction_id) = pending_prediction.task.await else {
366 return;
367 };
368
369 this.update(cx, |this, cx| {
370 this.reject_prediction(
371 prediction_id,
372 EditPredictionRejectReason::Canceled,
373 false,
374 None,
375 cx,
376 );
377 })
378 .ok();
379 })
380 .detach()
381 }
382 }
383
384 fn active_buffer(
385 &self,
386 project: &Entity<Project>,
387 cx: &App,
388 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
389 let project = project.read(cx);
390 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
391 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
392 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
393 Some((active_buffer, registered_buffer.last_position))
394 }
395}
396
/// The prediction currently offered for a project (`ProjectState::current_prediction`),
/// with what requested it and how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request: a buffer or a diagnostics update.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed.
    pub was_shown: bool,
    /// How it was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
404
405impl CurrentEditPrediction {
406 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
407 let Some(new_edits) = self
408 .prediction
409 .interpolate(&self.prediction.buffer.read(cx))
410 else {
411 return false;
412 };
413
414 if self.prediction.buffer != old_prediction.prediction.buffer {
415 return true;
416 }
417
418 let Some(old_edits) = old_prediction
419 .prediction
420 .interpolate(&old_prediction.prediction.buffer.read(cx))
421 else {
422 return true;
423 };
424
425 let requested_by_buffer_id = self.requested_by.buffer_id();
426
427 // This reduces the occurrence of UI thrash from replacing edits
428 //
429 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
430 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
431 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
432 && old_edits.len() == 1
433 && new_edits.len() == 1
434 {
435 let (old_range, old_text) = &old_edits[0];
436 let (new_range, new_text) = &new_edits[0];
437 new_range == old_range && new_text.starts_with(old_text.as_ref())
438 } else {
439 true
440 }
441 }
442}
443
/// What caused a prediction to be requested.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update; not tied to a specific buffer.
    DiagnosticsUpdate,
    /// The buffer with this entity id.
    Buffer(EntityId),
}
449
450impl PredictionRequestedBy {
451 pub fn buffer_id(&self) -> Option<EntityId> {
452 match self {
453 PredictionRequestedBy::DiagnosticsUpdate => None,
454 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
455 }
456 }
457}
458
/// NOTE(review): presumably the number of lines around the cursor searched for
/// diagnostics; the consumer is outside this view.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
/// NOTE(review): from the names, `Local` looks nearby-only and `Global` project-wide;
/// confirm at the use site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
466
/// An in-flight prediction request held in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier; added to `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the produced prediction, if one was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
475
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// Presumably a prediction that edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// Presumably a prediction elsewhere that this buffer can jump to.
    Jump { prediction: &'a EditPrediction },
}
482
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Test-only convenience: both variants wrap a prediction, so deref to it.
    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
494
/// A prediction waiting to settle in a registered buffer.
/// NOTE(review): processed by the settled-predictions worker outside this view;
/// the field notes are inferred from names.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region of the buffer the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    /// When it was queued.
    enqueued_at: Instant,
    /// Time of the most recent edit seen for it.
    last_edit_at: Instant,
}

/// Per-buffer state for a buffer registered with the store.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Most recently observed snapshot of the buffer.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in the buffer (presumably the cursor); returned by
    /// `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// The in-progress (not yet finalized) edit event for a buffer. It spans from
/// `old_snapshot` to `new_snapshot`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Whether the change is predicted rather than manual (`StoredEvent::can_merge`
    /// relies on this distinction).
    predicted: bool,
    /// Snapshot at the last editing pause. `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
522
523impl LastEvent {
524 pub fn finalize(
525 &self,
526 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
527 cx: &App,
528 ) -> Option<StoredEvent> {
529 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
530 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
531
532 let in_open_source_repo =
533 [self.new_file.as_ref(), self.old_file.as_ref()]
534 .iter()
535 .all(|file| {
536 file.is_some_and(|file| {
537 license_detection_watchers
538 .get(&file.worktree_id(cx))
539 .is_some_and(|watcher| watcher.is_project_open_source())
540 })
541 });
542
543 let (diff, edit_range) =
544 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
545
546 if path == old_path && diff.is_empty() {
547 None
548 } else {
549 Some(StoredEvent {
550 event: Arc::new(zeta_prompt::Event::BufferChange {
551 old_path,
552 path,
553 diff,
554 in_open_source_repo,
555 predicted: self.predicted,
556 }),
557 edit_range: self.new_snapshot.anchor_before(edit_range.start)
558 ..self.new_snapshot.anchor_before(edit_range.end),
559 old_snapshot: self.old_snapshot.clone(),
560 })
561 }
562 }
563
564 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
565 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
566 return (self.clone(), None);
567 };
568
569 let before = LastEvent {
570 old_snapshot: self.old_snapshot.clone(),
571 new_snapshot: boundary_snapshot.clone(),
572 old_file: self.old_file.clone(),
573 new_file: self.new_file.clone(),
574 edit_range: None,
575 predicted: self.predicted,
576 snapshot_after_last_editing_pause: None,
577 last_edit_time: self.last_edit_time,
578 };
579
580 let after = LastEvent {
581 old_snapshot: boundary_snapshot.clone(),
582 new_snapshot: self.new_snapshot.clone(),
583 old_file: self.old_file.clone(),
584 new_file: self.new_file.clone(),
585 edit_range: None,
586 predicted: self.predicted,
587 snapshot_after_last_editing_pause: None,
588 last_edit_time: self.last_edit_time,
589 };
590
591 (before, Some(after))
592 }
593}
594
/// Renders a unified diff between `old_snapshot` and `new_snapshot`. The diff is
/// restricted to the row region that covers every edit made since
/// `old_snapshot.version`, plus surrounding context rows.
///
/// Returns `None` when there are no edits between the snapshots. Otherwise returns
/// the diff text and the edited span as points in `new_snapshot`. That span runs
/// from the start of the first edit to the end of the last, without context.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    // Edits between the two versions, as offset ranges in the old and new text.
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits: nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Leading context: start CONTEXT_LINES rows before the first edited row
    // (clamped at row 0).
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Trailing context: last included row is `end_row + 1 + CONTEXT_LINES`,
    // clamped to the buffer's last row.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // The end offset is the start of the row after `*_context_end_row` (clamped to
    // the buffer end), so that row is included in full.
    // NOTE(review): combined with the `+ 1` above, this yields CONTEXT_LINES + 1
    // trailing rows but only CONTEXT_LINES leading rows. Confirm the asymmetry is
    // intended.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Diff only the extracted regions. The starting rows are passed along, presumably
    // so hunk headers report buffer-relative line numbers.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
640
641fn buffer_path_with_id_fallback(
642 file: Option<&Arc<dyn File>>,
643 snapshot: &TextBufferSnapshot,
644 cx: &App,
645) -> Arc<Path> {
646 if let Some(file) = file {
647 file.full_path(cx).into()
648 } else {
649 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
650 }
651}
652
653impl EditPredictionStore {
654 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
655 cx.try_global::<EditPredictionStoreGlobal>()
656 .map(|global| global.0.clone())
657 }
658
659 pub fn global(
660 client: &Arc<Client>,
661 user_store: &Entity<UserStore>,
662 cx: &mut App,
663 ) -> Entity<Self> {
664 cx.try_global::<EditPredictionStoreGlobal>()
665 .map(|global| global.0.clone())
666 .unwrap_or_else(|| {
667 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
668 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
669 ep_store
670 })
671 }
672
673 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
674 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
675 let data_collection_choice = Self::load_data_collection_choice();
676
677 let llm_token = LlmApiToken::default();
678
679 let (reject_tx, reject_rx) = mpsc::unbounded();
680 cx.background_spawn({
681 let client = client.clone();
682 let llm_token = llm_token.clone();
683 let app_version = AppVersion::global(cx);
684 let background_executor = cx.background_executor().clone();
685 async move {
686 Self::handle_rejected_predictions(
687 reject_rx,
688 client,
689 llm_token,
690 app_version,
691 background_executor,
692 )
693 .await
694 }
695 })
696 .detach();
697
698 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
699 cx.spawn(async move |this, cx| {
700 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
701 })
702 .detach();
703
704 let mut current_user = user_store.read(cx).watch_current_user();
705 let fetch_experiments_task = cx.spawn(async move |this, cx| {
706 while current_user.borrow().is_none() {
707 current_user.next().await;
708 }
709 this.update(cx, |this, cx| {
710 this.refresh_available_experiments(cx);
711 })
712 .log_err();
713 });
714
715 let this = Self {
716 projects: HashMap::default(),
717 client,
718 user_store,
719 llm_token,
720 _fetch_experiments_task: fetch_experiments_task,
721 _llm_token_subscription: cx.subscribe(
722 &refresh_llm_token_listener,
723 |this, _listener, _event, cx| {
724 let client = this.client.clone();
725 let llm_token = this.llm_token.clone();
726 cx.spawn(async move |_this, _cx| {
727 llm_token.refresh(&client).await?;
728 anyhow::Ok(())
729 })
730 .detach_and_log_err(cx);
731 },
732 ),
733 update_required: false,
734 edit_prediction_model: EditPredictionModel::Zeta,
735 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
736 preferred_experiment: None,
737 available_experiments: Vec::new(),
738 sweep_ai: SweepAi::new(cx),
739 mercury: Mercury::new(cx),
740
741 data_collection_choice,
742 reject_predictions_tx: reject_tx,
743 settled_predictions_tx,
744 rated_predictions: Default::default(),
745 shown_predictions: Default::default(),
746 #[cfg(test)]
747 settled_event_callback: None,
748 };
749
750 this
751 }
752
753 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
754 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
755 let format = ZetaFormat::parse(&version_str).ok()?;
756 let model_id = env::var("ZED_ZETA_MODEL").ok();
757 Some(Zeta2RawConfig { model_id, format })
758 }
759
760 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
761 self.edit_prediction_model = model;
762 }
763
764 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
765 self.zeta2_raw_config = Some(config);
766 }
767
768 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
769 self.zeta2_raw_config.as_ref()
770 }
771
772 pub fn preferred_experiment(&self) -> Option<&str> {
773 self.preferred_experiment.as_deref()
774 }
775
776 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
777 self.preferred_experiment = experiment;
778 }
779
780 pub fn available_experiments(&self) -> &[String] {
781 &self.available_experiments
782 }
783
784 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
785 let client = self.client.clone();
786 let llm_token = self.llm_token.clone();
787 let app_version = AppVersion::global(cx);
788 cx.spawn(async move |this, cx| {
789 let experiments = cx
790 .background_spawn(async move {
791 let http_client = client.http_client();
792 let token = llm_token.acquire(&client).await?;
793 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
794 let request = http_client::Request::builder()
795 .method(Method::GET)
796 .uri(url.as_ref())
797 .header("Authorization", format!("Bearer {}", token))
798 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
799 .body(Default::default())?;
800 let mut response = http_client.send(request).await?;
801 if response.status().is_success() {
802 let mut body = Vec::new();
803 response.body_mut().read_to_end(&mut body).await?;
804 let experiments: Vec<String> = serde_json::from_slice(&body)?;
805 Ok(experiments)
806 } else {
807 let mut body = String::new();
808 response.body_mut().read_to_string(&mut body).await?;
809 anyhow::bail!(
810 "Failed to fetch experiments: {:?}\nBody: {}",
811 response.status(),
812 body
813 );
814 }
815 })
816 .await?;
817 this.update(cx, |this, cx| {
818 this.available_experiments = experiments;
819 cx.notify();
820 })?;
821 anyhow::Ok(())
822 })
823 .detach_and_log_err(cx);
824 }
825
826 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
827 use ui::IconName;
828 match self.edit_prediction_model {
829 EditPredictionModel::Sweep => {
830 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
831 .with_disabled(IconName::SweepAiDisabled)
832 .with_up(IconName::SweepAiUp)
833 .with_down(IconName::SweepAiDown)
834 .with_error(IconName::SweepAiError)
835 }
836 EditPredictionModel::Mercury => {
837 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
838 }
839 EditPredictionModel::Zeta => {
840 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
841 .with_disabled(IconName::ZedPredictDisabled)
842 .with_up(IconName::ZedPredictUp)
843 .with_down(IconName::ZedPredictDown)
844 .with_error(IconName::ZedPredictError)
845 }
846 EditPredictionModel::Fim { .. } => {
847 let settings = &all_language_settings(None, cx).edit_predictions;
848 match settings.provider {
849 EditPredictionProvider::Ollama => {
850 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
851 }
852 _ => {
853 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
854 }
855 }
856 }
857 }
858 }
859
860 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
861 self.sweep_ai.api_token.read(cx).has_key()
862 }
863
864 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
865 self.mercury.api_token.read(cx).has_key()
866 }
867
868 pub fn clear_history(&mut self) {
869 for project_state in self.projects.values_mut() {
870 project_state.events.clear();
871 project_state.last_event.take();
872 }
873 }
874
875 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
876 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
877 project_state.events.clear();
878 project_state.last_event.take();
879 }
880 }
881
882 pub fn edit_history_for_project(
883 &self,
884 project: &Entity<Project>,
885 cx: &App,
886 ) -> Vec<StoredEvent> {
887 self.projects
888 .get(&project.entity_id())
889 .map(|project_state| project_state.events(cx))
890 .unwrap_or_default()
891 }
892
893 pub fn context_for_project<'a>(
894 &'a self,
895 project: &Entity<Project>,
896 cx: &'a mut App,
897 ) -> Vec<RelatedFile> {
898 self.projects
899 .get(&project.entity_id())
900 .map(|project_state| {
901 project_state.context.update(cx, |context, cx| {
902 context
903 .related_files_with_buffers(cx)
904 .map(|(mut related_file, buffer)| {
905 related_file.in_open_source_repo = buffer
906 .read(cx)
907 .file()
908 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
909 related_file
910 })
911 .collect()
912 })
913 })
914 .unwrap_or_default()
915 }
916
917 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
918 self.projects
919 .get(&project.entity_id())
920 .and_then(|project| project.copilot.clone())
921 }
922
923 pub fn start_copilot_for_project(
924 &mut self,
925 project: &Entity<Project>,
926 cx: &mut Context<Self>,
927 ) -> Option<Entity<Copilot>> {
928 if DisableAiSettings::get(None, cx).disable_ai {
929 return None;
930 }
931 let state = self.get_or_init_project(project, cx);
932
933 if state.copilot.is_some() {
934 return state.copilot.clone();
935 }
936 let _project = project.clone();
937 let project = project.read(cx);
938
939 let node = project.node_runtime().cloned();
940 if let Some(node) = node {
941 let next_id = project.languages().next_language_server_id();
942 let fs = project.fs().clone();
943
944 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
945 state.copilot = Some(copilot.clone());
946 Some(copilot)
947 } else {
948 None
949 }
950 }
951
952 pub fn context_for_project_with_buffers<'a>(
953 &'a self,
954 project: &Entity<Project>,
955 cx: &'a mut App,
956 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
957 self.projects
958 .get(&project.entity_id())
959 .map(|project| {
960 project.context.update(cx, |context, cx| {
961 context.related_files_with_buffers(cx).collect()
962 })
963 })
964 .unwrap_or_default()
965 }
966
967 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
968 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
969 self.user_store.read(cx).edit_prediction_usage()
970 } else {
971 None
972 }
973 }
974
975 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
976 self.get_or_init_project(project, cx);
977 }
978
979 pub fn register_buffer(
980 &mut self,
981 buffer: &Entity<Buffer>,
982 project: &Entity<Project>,
983 cx: &mut Context<Self>,
984 ) {
985 let project_state = self.get_or_init_project(project, cx);
986 Self::register_buffer_impl(project_state, buffer, project, cx);
987 }
988
    /// Returns the per-project state for `project`, creating it on first use.
    ///
    /// Creating the state does three things:
    /// - builds a `RelatedExcerptStore` for the project and forwards its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to the project's events (`handle_project_event`);
    /// - removes the state again when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Relay refresh events (used for the debug channel).
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1029
1030 pub fn remove_project(&mut self, project: &Entity<Project>) {
1031 self.projects.remove(&project.entity_id());
1032 }
1033
1034 fn handle_excerpt_store_event(
1035 &mut self,
1036 project_entity_id: EntityId,
1037 event: &RelatedExcerptStoreEvent,
1038 ) {
1039 if let Some(project_state) = self.projects.get(&project_entity_id) {
1040 if let Some(debug_tx) = project_state.debug_tx.clone() {
1041 match event {
1042 RelatedExcerptStoreEvent::StartedRefresh => {
1043 debug_tx
1044 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1045 ContextRetrievalStartedDebugEvent {
1046 project_entity_id: project_entity_id,
1047 timestamp: Instant::now(),
1048 search_prompt: String::new(),
1049 },
1050 ))
1051 .ok();
1052 }
1053 RelatedExcerptStoreEvent::FinishedRefresh {
1054 cache_hit_count,
1055 cache_miss_count,
1056 mean_definition_latency,
1057 max_definition_latency,
1058 } => {
1059 debug_tx
1060 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1061 ContextRetrievalFinishedDebugEvent {
1062 project_entity_id: project_entity_id,
1063 timestamp: Instant::now(),
1064 metadata: vec![
1065 (
1066 "Cache Hits",
1067 format!(
1068 "{}/{}",
1069 cache_hit_count,
1070 cache_hit_count + cache_miss_count
1071 )
1072 .into(),
1073 ),
1074 (
1075 "Max LSP Time",
1076 format!("{} ms", max_definition_latency.as_millis())
1077 .into(),
1078 ),
1079 (
1080 "Mean LSP Time",
1081 format!("{} ms", mean_definition_latency.as_millis())
1082 .into(),
1083 ),
1084 ],
1085 },
1086 ))
1087 .ok();
1088 }
1089 }
1090 }
1091 }
1092 }
1093
1094 pub fn debug_info(
1095 &mut self,
1096 project: &Entity<Project>,
1097 cx: &mut Context<Self>,
1098 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1099 let project_state = self.get_or_init_project(project, cx);
1100 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1101 project_state.debug_tx = Some(debug_watch_tx);
1102 debug_watch_rx
1103 }
1104
1105 fn handle_project_event(
1106 &mut self,
1107 project: Entity<Project>,
1108 event: &project::Event,
1109 cx: &mut Context<Self>,
1110 ) {
1111 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1112 return;
1113 }
1114 // TODO [zeta2] init with recent paths
1115 match event {
1116 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1117 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1118 return;
1119 };
1120 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1121 if let Some(path) = path {
1122 if let Some(ix) = project_state
1123 .recent_paths
1124 .iter()
1125 .position(|probe| probe == &path)
1126 {
1127 project_state.recent_paths.remove(ix);
1128 }
1129 project_state.recent_paths.push_front(path);
1130 }
1131 }
1132 project::Event::DiagnosticsUpdated { .. } => {
1133 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1134 self.refresh_prediction_from_diagnostics(
1135 project,
1136 DiagnosticSearchScope::Global,
1137 cx,
1138 );
1139 }
1140 }
1141 _ => (),
1142 }
1143 }
1144
    /// Ensures `buffer` is tracked in `project_state` and returns its entry.
    ///
    /// If the buffer is backed by a file in one of the project's worktrees, a
    /// `LicenseDetectionWatcher` is created once per worktree. That watcher is
    /// removed again when the worktree entity is released.
    ///
    /// On first registration, the entry records the buffer's current text
    /// snapshot and file. The store also subscribes to two things:
    /// - the buffer's `Edited` events, reported through
    ///   `report_changes_for_buffer` with `is_predicted = false`;
    /// - the buffer's release, which removes the entry.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget this worktree's watcher once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Buffer edits are reported as non-predicted changes. The
                        // project is held weakly, so edits are skipped once it is gone.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1211
    /// Records the changes made to `buffer` since its last recorded snapshot.
    ///
    /// The buffer is registered first if needed. If its version is unchanged,
    /// nothing happens. Otherwise the stored file and snapshot are replaced,
    /// and then:
    /// - pending settled predictions whose editable range overlaps the edited
    ///   range have `last_edit_at` bumped to now;
    /// - a `UserActionRecord` classifying the edit is recorded;
    /// - the edit is either coalesced into the project's `last_event` or
    ///   starts a new `last_event`, after finalizing the previous one into
    ///   `events`.
    ///
    /// Coalescing requires all of: the same buffer, a consecutive snapshot,
    /// the same `is_predicted` source, and a distance within
    /// `CHANGE_GROUPING_LINE_SPAN` lines.
    ///
    /// `is_predicted` is `true` when the change comes from accepting a
    /// prediction (`accept_current_prediction`). It is `false` for ordinary
    /// buffer edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing new since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Tally the edits and build one anchor range spanning from the first
        // edit's start to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits between the two versions: nothing further to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart its
        // quiescence timer (see `run_settled_predictions_worker`).
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify by offset counts. One inserted/deleted offset unit per edit
        // (presumably bytes) counts as a single-char action; anything larger is
        // treated as a selection. Mixed edits are classified as insertions.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock timestamp; falls back to 0 if the clock is before the epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change continues exactly where the last event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // snapshot as it was before this burst of editing resumed.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event into the history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Drop the oldest entry so `events` holds at most
                // EVENT_COUNT_MAX - 1 entries after the push.
                // NOTE(review): presumably leaves room for `last_event`; confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1345
1346 fn prediction_at(
1347 &mut self,
1348 buffer: &Entity<Buffer>,
1349 position: Option<language::Anchor>,
1350 project: &Entity<Project>,
1351 cx: &App,
1352 ) -> Option<BufferEditPrediction<'_>> {
1353 let project_state = self.projects.get_mut(&project.entity_id())?;
1354 if let Some(position) = position
1355 && let Some(buffer) = project_state
1356 .registered_buffers
1357 .get_mut(&buffer.entity_id())
1358 {
1359 buffer.last_position = Some(position);
1360 }
1361
1362 let CurrentEditPrediction {
1363 requested_by,
1364 prediction,
1365 ..
1366 } = project_state.current_prediction.as_ref()?;
1367
1368 if prediction.targets_buffer(buffer.read(cx)) {
1369 Some(BufferEditPrediction::Local { prediction })
1370 } else {
1371 let show_jump = match requested_by {
1372 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1373 requested_by_buffer_id == &buffer.entity_id()
1374 }
1375 PredictionRequestedBy::DiagnosticsUpdate => true,
1376 };
1377
1378 if show_jump {
1379 Some(BufferEditPrediction::Jump { prediction })
1380 } else {
1381 None
1382 }
1383 }
1384 }
1385
1386 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1387 let Some(current_prediction) = self
1388 .projects
1389 .get_mut(&project.entity_id())
1390 .and_then(|project_state| project_state.current_prediction.take())
1391 else {
1392 return;
1393 };
1394
1395 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1396
1397 // can't hold &mut project_state ref across report_changes_for_buffer_call
1398 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1399 return;
1400 };
1401
1402 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1403 project_state.cancel_pending_prediction(pending_prediction, cx);
1404 }
1405
1406 match self.edit_prediction_model {
1407 EditPredictionModel::Sweep => {
1408 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1409 }
1410 EditPredictionModel::Mercury => {
1411 mercury::edit_prediction_accepted(
1412 current_prediction.prediction.id,
1413 self.client.http_client(),
1414 cx,
1415 );
1416 }
1417 EditPredictionModel::Zeta => {
1418 let is_cloud = !matches!(
1419 all_language_settings(None, cx).edit_predictions.provider,
1420 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1421 );
1422 if is_cloud {
1423 zeta::edit_prediction_accepted(self, current_prediction, cx)
1424 }
1425 }
1426 EditPredictionModel::Fim { .. } => {}
1427 }
1428 }
1429
    /// Background loop that batches rejected predictions and reports them to
    /// the `/predict_edits/reject` endpoint.
    ///
    /// Each received rejection is appended to `batched`. While the batch holds
    /// fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, the
    /// loop keeps collecting as long as another rejection arrives before
    /// `REJECT_REQUEST_DEBOUNCE` elapses.
    ///
    /// It then sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of
    /// the most recent entries:
    /// - on success, those entries are removed from the batch;
    /// - on failure, the error is logged and the entries stay in the batch, so
    ///   they are retried on a later flush. That flush only happens after
    ///   another rejection is received.
    ///
    /// The loop ends when the channel closes.
    ///
    /// NOTE(review): the `.unwrap()` on `build_zed_llm_url` panics this task if
    /// the URL cannot be built.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Debounce small batches. If another rejection shows up (peek yields
            // Some) before the timer fires, go back and collect it. If the channel
            // closed (peek yields None) or the timer fired, flush now.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` entries (the tail of the batch).
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the sent entries once the request succeeded.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1488
    /// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry
    /// for pending settled predictions.
    ///
    /// When idle, it waits for an enqueue time on `rx` and schedules a wake-up
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` after it. Enqueue notifications
    /// that are already queued are drained at that point.
    ///
    /// On each wake-up it scans every registered buffer of every project:
    /// - predictions older than `EDIT_PREDICTION_SETTLED_TTL` are dropped
    ///   without an event;
    /// - predictions whose editable range has gone unedited for at least
    ///   `EDIT_PREDICTION_SETTLED_QUIESCENCE` emit the settled event, carrying
    ///   the range's current text, and are dropped;
    /// - the remaining predictions stay pending. The next wake-up is scheduled
    ///   one quiescence period after the oldest `last_edit_at` among them.
    ///
    /// The loop stops when the channel closes or the store has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Discard notifications that are already queued; the scan below
                // covers every pending prediction regardless.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop silently.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled text and drop.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                // Still being edited: keep it and track the earliest
                                // edit time to schedule the next wake-up.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1572
1573 pub(crate) fn enqueue_settled_prediction(
1574 &mut self,
1575 request_id: EditPredictionId,
1576 project: &Entity<Project>,
1577 edited_buffer: &Entity<Buffer>,
1578 edited_buffer_snapshot: &BufferSnapshot,
1579 editable_offset_range: Range<usize>,
1580 cx: &mut Context<Self>,
1581 ) {
1582 let project_state = self.get_or_init_project(project, cx);
1583 if let Some(buffer) = project_state
1584 .registered_buffers
1585 .get_mut(&edited_buffer.entity_id())
1586 {
1587 let now = cx.background_executor().now();
1588 buffer.pending_predictions.push(PendingSettledPrediction {
1589 request_id,
1590 editable_anchor_range: edited_buffer_snapshot
1591 .anchor_range_around(editable_offset_range),
1592 enqueued_at: now,
1593 last_edit_at: now,
1594 });
1595 self.settled_predictions_tx.unbounded_send(now).ok();
1596 }
1597 }
1598
1599 fn reject_current_prediction(
1600 &mut self,
1601 reason: EditPredictionRejectReason,
1602 project: &Entity<Project>,
1603 cx: &App,
1604 ) {
1605 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1606 project_state.pending_predictions.clear();
1607 if let Some(prediction) = project_state.current_prediction.take() {
1608 let model_version = prediction.prediction.model_version.clone();
1609 self.reject_prediction(
1610 prediction.prediction.id,
1611 reason,
1612 prediction.was_shown,
1613 model_version,
1614 cx,
1615 );
1616 }
1617 };
1618 }
1619
1620 fn did_show_current_prediction(
1621 &mut self,
1622 project: &Entity<Project>,
1623 display_type: edit_prediction_types::SuggestionDisplayType,
1624 cx: &mut Context<Self>,
1625 ) {
1626 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1627 return;
1628 };
1629
1630 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1631 return;
1632 };
1633
1634 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1635 let previous_shown_with = current_prediction.shown_with;
1636
1637 if previous_shown_with.is_none() || !is_jump {
1638 current_prediction.shown_with = Some(display_type);
1639 }
1640
1641 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1642
1643 if is_first_non_jump_show {
1644 current_prediction.was_shown = true;
1645 }
1646
1647 let display_type_changed = previous_shown_with != Some(display_type);
1648
1649 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1650 sweep_ai::edit_prediction_shown(
1651 &self.sweep_ai,
1652 self.client.clone(),
1653 ¤t_prediction.prediction,
1654 display_type,
1655 cx,
1656 );
1657 }
1658
1659 if is_first_non_jump_show {
1660 self.shown_predictions
1661 .push_front(current_prediction.prediction.clone());
1662 if self.shown_predictions.len() > 50 {
1663 let completion = self.shown_predictions.pop_back().unwrap();
1664 self.rated_predictions.remove(&completion.id);
1665 }
1666 }
1667 }
1668
1669 fn reject_prediction(
1670 &mut self,
1671 prediction_id: EditPredictionId,
1672 reason: EditPredictionRejectReason,
1673 was_shown: bool,
1674 model_version: Option<String>,
1675 cx: &App,
1676 ) {
1677 match self.edit_prediction_model {
1678 EditPredictionModel::Zeta => {
1679 let is_cloud = !matches!(
1680 all_language_settings(None, cx).edit_predictions.provider,
1681 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1682 );
1683 if is_cloud {
1684 self.reject_predictions_tx
1685 .unbounded_send(EditPredictionRejection {
1686 request_id: prediction_id.to_string(),
1687 reason,
1688 was_shown,
1689 model_version,
1690 })
1691 .log_err();
1692 }
1693 }
1694 EditPredictionModel::Mercury => {
1695 mercury::edit_prediction_rejected(
1696 prediction_id,
1697 was_shown,
1698 reason,
1699 self.client.http_client(),
1700 cx,
1701 );
1702 }
1703 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1704 }
1705 }
1706
1707 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1708 self.projects
1709 .get(&project.entity_id())
1710 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1711 }
1712
1713 pub fn refresh_prediction_from_buffer(
1714 &mut self,
1715 project: Entity<Project>,
1716 buffer: Entity<Buffer>,
1717 position: language::Anchor,
1718 cx: &mut Context<Self>,
1719 ) {
1720 self.queue_prediction_refresh(
1721 project.clone(),
1722 PredictEditsRequestTrigger::Other,
1723 buffer.entity_id(),
1724 cx,
1725 move |this, cx| {
1726 let Some(request_task) = this
1727 .update(cx, |this, cx| {
1728 this.request_prediction(
1729 &project,
1730 &buffer,
1731 position,
1732 PredictEditsRequestTrigger::Other,
1733 cx,
1734 )
1735 })
1736 .log_err()
1737 else {
1738 return Task::ready(anyhow::Ok(None));
1739 };
1740
1741 cx.spawn(async move |_cx| {
1742 request_task.await.map(|prediction_result| {
1743 prediction_result.map(|prediction_result| {
1744 (
1745 prediction_result,
1746 PredictionRequestedBy::Buffer(buffer.entity_id()),
1747 )
1748 })
1749 })
1750 })
1751 },
1752 )
1753 }
1754
1755 pub fn refresh_prediction_from_diagnostics(
1756 &mut self,
1757 project: Entity<Project>,
1758 scope: DiagnosticSearchScope,
1759 cx: &mut Context<Self>,
1760 ) {
1761 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1762 return;
1763 }
1764
1765 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1766 return;
1767 };
1768
1769 // Prefer predictions from buffer
1770 if project_state.current_prediction.is_some() {
1771 return;
1772 }
1773
1774 self.queue_prediction_refresh(
1775 project.clone(),
1776 PredictEditsRequestTrigger::Diagnostics,
1777 project.entity_id(),
1778 cx,
1779 move |this, cx| {
1780 let Some((active_buffer, snapshot, cursor_point)) = this
1781 .read_with(cx, |this, cx| {
1782 let project_state = this.projects.get(&project.entity_id())?;
1783 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1784 let snapshot = buffer.read(cx).snapshot();
1785
1786 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1787 return None;
1788 }
1789
1790 let cursor_point = position
1791 .map(|pos| pos.to_point(&snapshot))
1792 .unwrap_or_default();
1793
1794 Some((buffer, snapshot, cursor_point))
1795 })
1796 .log_err()
1797 .flatten()
1798 else {
1799 return Task::ready(anyhow::Ok(None));
1800 };
1801
1802 cx.spawn(async move |cx| {
1803 let diagnostic_search_range = match scope {
1804 DiagnosticSearchScope::Local => {
1805 let diagnostic_search_start =
1806 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1807 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1808 Point::new(diagnostic_search_start, 0)
1809 ..Point::new(diagnostic_search_end, 0)
1810 }
1811 DiagnosticSearchScope::Global => Default::default(),
1812 };
1813
1814 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1815 active_buffer,
1816 &snapshot,
1817 diagnostic_search_range,
1818 cursor_point,
1819 &project,
1820 cx,
1821 )
1822 .await?
1823 else {
1824 return anyhow::Ok(None);
1825 };
1826
1827 let Some(prediction_result) = this
1828 .update(cx, |this, cx| {
1829 this.request_prediction(
1830 &project,
1831 &jump_buffer,
1832 jump_position,
1833 PredictEditsRequestTrigger::Diagnostics,
1834 cx,
1835 )
1836 })?
1837 .await?
1838 else {
1839 return anyhow::Ok(None);
1840 };
1841
1842 this.update(cx, |this, cx| {
1843 Some((
1844 if this
1845 .get_or_init_project(&project, cx)
1846 .current_prediction
1847 .is_none()
1848 {
1849 prediction_result
1850 } else {
1851 EditPredictionResult {
1852 id: prediction_result.id,
1853 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1854 }
1855 },
1856 PredictionRequestedBy::DiagnosticsUpdate,
1857 ))
1858 })
1859 })
1860 },
1861 );
1862 }
1863
1864 fn predictions_enabled_at(
1865 snapshot: &BufferSnapshot,
1866 position: Option<language::Anchor>,
1867 cx: &App,
1868 ) -> bool {
1869 let file = snapshot.file();
1870 let all_settings = all_language_settings(file, cx);
1871 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1872 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1873 {
1874 return false;
1875 }
1876
1877 if let Some(last_position) = position {
1878 let settings = snapshot.settings_at(last_position, cx);
1879
1880 if !settings.edit_predictions_disabled_in.is_empty()
1881 && let Some(scope) = snapshot.language_scope_at(last_position)
1882 && let Some(scope_name) = scope.override_name()
1883 && settings
1884 .edit_predictions_disabled_in
1885 .iter()
1886 .any(|s| s == scope_name)
1887 {
1888 return false;
1889 }
1890 }
1891
1892 true
1893 }
1894
    /// Minimum spacing between two consecutive prediction refreshes of the same
    /// kind for the same throttle entity (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1896}
1897
1898fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1899 match provider {
1900 EditPredictionProvider::Zed
1901 | EditPredictionProvider::Sweep
1902 | EditPredictionProvider::Mercury
1903 | EditPredictionProvider::Ollama
1904 | EditPredictionProvider::OpenAiCompatibleApi
1905 | EditPredictionProvider::Experimental(_) => true,
1906 EditPredictionProvider::None
1907 | EditPredictionProvider::Copilot
1908 | EditPredictionProvider::Codestral => false,
1909 }
1910}
1911
1912impl EditPredictionStore {
1913 fn queue_prediction_refresh(
1914 &mut self,
1915 project: Entity<Project>,
1916 request_trigger: PredictEditsRequestTrigger,
1917 throttle_entity: EntityId,
1918 cx: &mut Context<Self>,
1919 do_refresh: impl FnOnce(
1920 WeakEntity<Self>,
1921 &mut AsyncApp,
1922 )
1923 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1924 + 'static,
1925 ) {
1926 fn select_throttle(
1927 project_state: &mut ProjectState,
1928 request_trigger: PredictEditsRequestTrigger,
1929 ) -> &mut Option<(EntityId, Instant)> {
1930 match request_trigger {
1931 PredictEditsRequestTrigger::Diagnostics => {
1932 &mut project_state.last_jump_prediction_refresh
1933 }
1934 _ => &mut project_state.last_edit_prediction_refresh,
1935 }
1936 }
1937
1938 let (needs_acceptance_tracking, max_pending_predictions) =
1939 match all_language_settings(None, cx).edit_predictions.provider {
1940 EditPredictionProvider::Zed
1941 | EditPredictionProvider::Sweep
1942 | EditPredictionProvider::Mercury
1943 | EditPredictionProvider::Experimental(_) => (true, 2),
1944 EditPredictionProvider::Ollama => (false, 1),
1945 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1946 EditPredictionProvider::None
1947 | EditPredictionProvider::Copilot
1948 | EditPredictionProvider::Codestral => {
1949 log::error!("queue_prediction_refresh called with non-store provider");
1950 return;
1951 }
1952 };
1953
1954 let drop_on_cancel = !needs_acceptance_tracking;
1955 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1956 let project_state = self.get_or_init_project(&project, cx);
1957 let pending_prediction_id = project_state.next_pending_prediction_id;
1958 project_state.next_pending_prediction_id += 1;
1959 let last_request = *select_throttle(project_state, request_trigger);
1960
1961 let task = cx.spawn(async move |this, cx| {
1962 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1963 if throttle_entity != last_entity {
1964 return None;
1965 }
1966 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1967 }) {
1968 cx.background_executor().timer(timeout).await;
1969 }
1970
1971 // If this task was cancelled before the throttle timeout expired,
1972 // do not perform a request.
1973 let mut is_cancelled = true;
1974 this.update(cx, |this, cx| {
1975 let project_state = this.get_or_init_project(&project, cx);
1976 let was_cancelled = project_state
1977 .cancelled_predictions
1978 .remove(&pending_prediction_id);
1979 if !was_cancelled {
1980 let new_refresh = (throttle_entity, Instant::now());
1981 *select_throttle(project_state, request_trigger) = Some(new_refresh);
1982 is_cancelled = false;
1983 }
1984 })
1985 .ok();
1986 if is_cancelled {
1987 return None;
1988 }
1989
1990 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1991 let new_prediction_id = new_prediction_result
1992 .as_ref()
1993 .map(|(prediction, _)| prediction.id.clone());
1994
1995 // When a prediction completes, remove it from the pending list, and cancel
1996 // any pending predictions that were enqueued before it.
1997 this.update(cx, |this, cx| {
1998 let project_state = this.get_or_init_project(&project, cx);
1999
2000 let is_cancelled = project_state
2001 .cancelled_predictions
2002 .remove(&pending_prediction_id);
2003
2004 let new_current_prediction = if !is_cancelled
2005 && let Some((prediction_result, requested_by)) = new_prediction_result
2006 {
2007 match prediction_result.prediction {
2008 Ok(prediction) => {
2009 let new_prediction = CurrentEditPrediction {
2010 requested_by,
2011 prediction,
2012 was_shown: false,
2013 shown_with: None,
2014 };
2015
2016 if let Some(current_prediction) =
2017 project_state.current_prediction.as_ref()
2018 {
2019 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2020 {
2021 this.reject_current_prediction(
2022 EditPredictionRejectReason::Replaced,
2023 &project,
2024 cx,
2025 );
2026
2027 Some(new_prediction)
2028 } else {
2029 this.reject_prediction(
2030 new_prediction.prediction.id,
2031 EditPredictionRejectReason::CurrentPreferred,
2032 false,
2033 new_prediction.prediction.model_version,
2034 cx,
2035 );
2036 None
2037 }
2038 } else {
2039 Some(new_prediction)
2040 }
2041 }
2042 Err(reject_reason) => {
2043 this.reject_prediction(
2044 prediction_result.id,
2045 reject_reason,
2046 false,
2047 None,
2048 cx,
2049 );
2050 None
2051 }
2052 }
2053 } else {
2054 None
2055 };
2056
2057 let project_state = this.get_or_init_project(&project, cx);
2058
2059 if let Some(new_prediction) = new_current_prediction {
2060 project_state.current_prediction = Some(new_prediction);
2061 }
2062
2063 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2064 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2065 if pending_prediction.id == pending_prediction_id {
2066 pending_predictions.remove(ix);
2067 for pending_prediction in pending_predictions.drain(0..ix) {
2068 project_state.cancel_pending_prediction(pending_prediction, cx)
2069 }
2070 break;
2071 }
2072 }
2073 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2074 cx.notify();
2075 })
2076 .ok();
2077
2078 new_prediction_id
2079 });
2080
2081 if project_state.pending_predictions.len() < max_pending_predictions {
2082 project_state.pending_predictions.push(PendingPrediction {
2083 id: pending_prediction_id,
2084 task,
2085 drop_on_cancel,
2086 });
2087 } else {
2088 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2089 project_state.pending_predictions.push(PendingPrediction {
2090 id: pending_prediction_id,
2091 task,
2092 drop_on_cancel,
2093 });
2094 project_state.cancel_pending_prediction(pending_prediction, cx);
2095 }
2096 }
2097
2098 pub fn request_prediction(
2099 &mut self,
2100 project: &Entity<Project>,
2101 active_buffer: &Entity<Buffer>,
2102 position: language::Anchor,
2103 trigger: PredictEditsRequestTrigger,
2104 cx: &mut Context<Self>,
2105 ) -> Task<Result<Option<EditPredictionResult>>> {
2106 self.request_prediction_internal(
2107 project.clone(),
2108 active_buffer.clone(),
2109 position,
2110 trigger,
2111 cx.has_flag::<Zeta2FeatureFlag>(),
2112 cx,
2113 )
2114 }
2115
2116 fn request_prediction_internal(
2117 &mut self,
2118 project: Entity<Project>,
2119 active_buffer: Entity<Buffer>,
2120 position: language::Anchor,
2121 trigger: PredictEditsRequestTrigger,
2122 allow_jump: bool,
2123 cx: &mut Context<Self>,
2124 ) -> Task<Result<Option<EditPredictionResult>>> {
2125 self.get_or_init_project(&project, cx);
2126 let project_state = self.projects.get(&project.entity_id()).unwrap();
2127 let stored_events = project_state.events(cx);
2128 let has_events = !stored_events.is_empty();
2129 let events: Vec<Arc<zeta_prompt::Event>> =
2130 stored_events.iter().map(|e| e.event.clone()).collect();
2131 let debug_tx = project_state.debug_tx.clone();
2132
2133 let snapshot = active_buffer.read(cx).snapshot();
2134 let cursor_point = position.to_point(&snapshot);
2135 let current_offset = position.to_offset(&snapshot);
2136
2137 let mut user_actions: Vec<UserActionRecord> =
2138 project_state.user_actions.iter().cloned().collect();
2139
2140 if let Some(last_action) = user_actions.last() {
2141 if last_action.buffer_id == active_buffer.entity_id()
2142 && current_offset != last_action.offset
2143 {
2144 let timestamp_epoch_ms = SystemTime::now()
2145 .duration_since(UNIX_EPOCH)
2146 .map(|d| d.as_millis() as u64)
2147 .unwrap_or(0);
2148 user_actions.push(UserActionRecord {
2149 action_type: UserActionType::CursorMovement,
2150 buffer_id: active_buffer.entity_id(),
2151 line_number: cursor_point.row,
2152 offset: current_offset,
2153 timestamp_epoch_ms,
2154 });
2155 }
2156 }
2157 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2158 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2159 let diagnostic_search_range =
2160 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2161
2162 let related_files = self.context_for_project(&project, cx);
2163
2164 let is_open_source = snapshot
2165 .file()
2166 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2167 && events.iter().all(|event| event.in_open_source_repo())
2168 && related_files.iter().all(|file| file.in_open_source_repo);
2169
2170 let can_collect_data = !cfg!(test)
2171 && is_open_source
2172 && self.is_data_collection_enabled(cx)
2173 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2174
2175 let inputs = EditPredictionModelInput {
2176 project: project.clone(),
2177 buffer: active_buffer.clone(),
2178 snapshot: snapshot,
2179 position,
2180 events,
2181 related_files,
2182 recent_paths: project_state.recent_paths.clone(),
2183 trigger,
2184 diagnostic_search_range: diagnostic_search_range,
2185 debug_tx,
2186 user_actions,
2187 can_collect_data,
2188 is_open_source,
2189 };
2190
2191 if can_collect_data && rand::random_ratio(1, 1000) {
2192 if let Some(task) = capture_example(
2193 project.clone(),
2194 active_buffer,
2195 position,
2196 stored_events,
2197 false,
2198 cx,
2199 ) {
2200 task.detach();
2201 }
2202 }
2203
2204 let task = match self.edit_prediction_model {
2205 EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx),
2206 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2207 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2208 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2209 };
2210
2211 cx.spawn(async move |this, cx| {
2212 let prediction = task.await?;
2213
2214 if prediction.is_none() && allow_jump && has_events {
2215 this.update(cx, |this, cx| {
2216 this.refresh_prediction_from_diagnostics(
2217 project,
2218 DiagnosticSearchScope::Local,
2219 cx,
2220 );
2221 })?;
2222 return anyhow::Ok(None);
2223 }
2224
2225 Ok(prediction)
2226 })
2227 }
2228
    /// Finds a diagnostic location to jump to, away from the area around the cursor.
    ///
    /// The search runs in two stages:
    ///
    /// 1. Diagnostic groups in `active_buffer` whose primary range does not
    ///    overlap `active_buffer_diagnostic_search_range`. A group whose start row
    ///    is within `DIAGNOSTIC_LINES_RANGE` rows of a selection head returned by
    ///    `selections_in_range(.., false)` is skipped, unless it is also within
    ///    that distance of `active_buffer_cursor_point`. Among the remaining
    ///    groups, the one whose start row is closest to the cursor row wins.
    /// 2. Only if stage 1 finds nothing: every other path with a diagnostic
    ///    summary in `project`. Paths are ordered by how many leading path
    ///    components they share with the active buffer's path, most shared first;
    ///    the sort is stable, so ties keep summary order. Each candidate is opened
    ///    in turn. Buffers with any `selections_in_range(.., false)` entry are
    ///    skipped. The first buffer that has a diagnostic yields the start of its
    ///    entry with the smallest `severity` value.
    ///
    /// NOTE(review): `selections_in_range(.., false)` is assumed to return only
    /// remote collaborators' selections (local replica excluded). Confirm against
    /// `BufferSnapshot::selections_in_range`.
    ///
    /// Returns `Ok(None)` when neither stage finds a location.
    ///
    /// # Errors
    ///
    /// Propagates the first error from opening a candidate buffer, which ends the
    /// search.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Head rows of the (presumably collaborators') selections in this buffer.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        // Stage 1: the closest eligible diagnostic in the active buffer.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the caller's search range are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Skip diagnostics near another cursor unless they are also near ours.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Stage 2: fall back to other files that have diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other diagnosed path with its shared-prefix length.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Most path components in common with the active buffer first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip buffers that have any other selections in them.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2338
2339 async fn send_raw_llm_request(
2340 request: RawCompletionRequest,
2341 client: Arc<Client>,
2342 custom_url: Option<Arc<Url>>,
2343 llm_token: LlmApiToken,
2344 app_version: Version,
2345 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2346 let url = if let Some(custom_url) = custom_url {
2347 custom_url.as_ref().clone()
2348 } else {
2349 client
2350 .http_client()
2351 .build_zed_llm_url("/predict_edits/raw", &[])?
2352 };
2353
2354 Self::send_api_request(
2355 |builder| {
2356 let req = builder
2357 .uri(url.as_ref())
2358 .body(serde_json::to_string(&request)?.into());
2359 Ok(req?)
2360 },
2361 client,
2362 llm_token,
2363 app_version,
2364 true,
2365 )
2366 .await
2367 }
2368
2369 pub(crate) async fn send_v3_request(
2370 input: ZetaPromptInput,
2371 client: Arc<Client>,
2372 llm_token: LlmApiToken,
2373 app_version: Version,
2374 trigger: PredictEditsRequestTrigger,
2375 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2376 let url = client
2377 .http_client()
2378 .build_zed_llm_url("/predict_edits/v3", &[])?;
2379
2380 let request = PredictEditsV3Request { input, trigger };
2381
2382 let json_bytes = serde_json::to_vec(&request)?;
2383 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2384
2385 Self::send_api_request(
2386 |builder| {
2387 let req = builder
2388 .uri(url.as_ref())
2389 .header("Content-Encoding", "zstd")
2390 .body(compressed.clone().into());
2391 Ok(req?)
2392 },
2393 client,
2394 llm_token,
2395 app_version,
2396 true,
2397 )
2398 .await
2399 }
2400
2401 fn handle_api_response<T>(
2402 this: &WeakEntity<Self>,
2403 response: Result<(T, Option<EditPredictionUsage>)>,
2404 cx: &mut gpui::AsyncApp,
2405 ) -> Result<T> {
2406 match response {
2407 Ok((data, usage)) => {
2408 if let Some(usage) = usage {
2409 this.update(cx, |this, cx| {
2410 this.user_store.update(cx, |user_store, cx| {
2411 user_store.update_edit_prediction_usage(usage, cx);
2412 });
2413 })
2414 .ok();
2415 }
2416 Ok(data)
2417 }
2418 Err(err) => {
2419 if err.is::<ZedUpdateRequiredError>() {
2420 cx.update(|cx| {
2421 this.update(cx, |this, _cx| {
2422 this.update_required = true;
2423 })
2424 .ok();
2425
2426 let error_message: SharedString = err.to_string().into();
2427 show_app_notification(
2428 NotificationId::unique::<ZedUpdateRequiredError>(),
2429 cx,
2430 move |cx| {
2431 cx.new(|cx| {
2432 ErrorMessagePrompt::new(error_message.clone(), cx)
2433 .with_link_button("Update Zed", "https://zed.dev/releases")
2434 })
2435 },
2436 );
2437 });
2438 }
2439 Err(err)
2440 }
2441 }
2442 }
2443
    /// POSTs a request to the LLM service and deserializes a successful JSON body.
    ///
    /// `build` receives a builder that already has the POST method, a
    /// `Content-Type: application/json` header and the Zed version header. When a
    /// token is available it also has a bearer `Authorization` header. `build`
    /// must set the URI and body. It may be called twice: a response for which
    /// `needs_llm_token_refresh()` is true triggers a single token refresh and
    /// retry, but only when a token was sent.
    ///
    /// When `require_auth` is `true`, failing to acquire a token is an error.
    /// Otherwise such a failure just means the request is sent without
    /// `Authorization`.
    ///
    /// Returns the deserialized body together with the usage parsed from the
    /// response headers, or `None` if those headers could not be parsed.
    ///
    /// # Errors
    ///
    /// - [`ZedUpdateRequiredError`] if the response carries a minimum required
    ///   version greater than `app_version`. This is checked before the status.
    /// - An error containing the status and body text for any other
    ///   non-success response.
    /// - Token acquisition/refresh, transport, body-read and JSON
    ///   deserialization errors.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        // Guards the single refresh-and-retry attempt.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server may demand a newer client regardless of the response status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Missing or malformed usage headers are tolerated.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Refresh the token once and loop to resend the rebuilt request.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2513
2514 pub fn refresh_context(
2515 &mut self,
2516 project: &Entity<Project>,
2517 buffer: &Entity<language::Buffer>,
2518 cursor_position: language::Anchor,
2519 cx: &mut Context<Self>,
2520 ) {
2521 self.get_or_init_project(project, cx)
2522 .context
2523 .update(cx, |store, cx| {
2524 store.refresh(buffer.clone(), cursor_position, cx);
2525 });
2526 }
2527
2528 #[cfg(feature = "cli-support")]
2529 pub fn set_context_for_buffer(
2530 &mut self,
2531 project: &Entity<Project>,
2532 related_files: Vec<RelatedFile>,
2533 cx: &mut Context<Self>,
2534 ) {
2535 self.get_or_init_project(project, cx)
2536 .context
2537 .update(cx, |store, cx| {
2538 store.set_related_files(related_files, cx);
2539 });
2540 }
2541
2542 #[cfg(feature = "cli-support")]
2543 pub fn set_recent_paths_for_project(
2544 &mut self,
2545 project: &Entity<Project>,
2546 paths: impl IntoIterator<Item = project::ProjectPath>,
2547 cx: &mut Context<Self>,
2548 ) {
2549 let project_state = self.get_or_init_project(project, cx);
2550 project_state.recent_paths = paths.into_iter().collect();
2551 }
2552
2553 fn is_file_open_source(
2554 &self,
2555 project: &Entity<Project>,
2556 file: &Arc<dyn File>,
2557 cx: &App,
2558 ) -> bool {
2559 if !file.is_local() || file.is_private() {
2560 return false;
2561 }
2562 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2563 return false;
2564 };
2565 project_state
2566 .license_detection_watchers
2567 .get(&file.worktree_id(cx))
2568 .as_ref()
2569 .is_some_and(|watcher| watcher.is_project_open_source())
2570 }
2571
2572 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2573 self.data_collection_choice.is_enabled(cx)
2574 }
2575
2576 fn load_data_collection_choice() -> DataCollectionChoice {
2577 let choice = KEY_VALUE_STORE
2578 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2579 .log_err()
2580 .flatten();
2581
2582 match choice.as_deref() {
2583 Some("true") => DataCollectionChoice::Enabled,
2584 Some("false") => DataCollectionChoice::Disabled,
2585 Some(_) => {
2586 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2587 DataCollectionChoice::NotAnswered
2588 }
2589 None => DataCollectionChoice::NotAnswered,
2590 }
2591 }
2592
2593 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2594 self.data_collection_choice = self.data_collection_choice.toggle();
2595 let new_choice = self.data_collection_choice;
2596 let is_enabled = new_choice.is_enabled(cx);
2597 db::write_and_log(cx, move || {
2598 KEY_VALUE_STORE.write_kvp(
2599 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2600 is_enabled.to_string(),
2601 )
2602 });
2603 }
2604
2605 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2606 self.shown_predictions.iter()
2607 }
2608
2609 pub fn shown_completions_len(&self) -> usize {
2610 self.shown_predictions.len()
2611 }
2612
2613 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2614 self.rated_predictions.contains(id)
2615 }
2616
2617 pub fn rate_prediction(
2618 &mut self,
2619 prediction: &EditPrediction,
2620 rating: EditPredictionRating,
2621 feedback: String,
2622 cx: &mut Context<Self>,
2623 ) {
2624 let organization = self.user_store.read(cx).current_organization();
2625
2626 self.rated_predictions.insert(prediction.id.clone());
2627
2628 cx.background_spawn({
2629 let client = self.client.clone();
2630 let prediction_id = prediction.id.to_string();
2631 let inputs = serde_json::to_value(&prediction.inputs);
2632 let output = prediction
2633 .edit_preview
2634 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2635 async move {
2636 client
2637 .cloud_client()
2638 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2639 organization_id: organization.map(|organization| organization.id.clone()),
2640 request_id: prediction_id,
2641 rating: match rating {
2642 EditPredictionRating::Positive => "positive".to_string(),
2643 EditPredictionRating::Negative => "negative".to_string(),
2644 },
2645 inputs: inputs?,
2646 output,
2647 feedback,
2648 })
2649 .await?;
2650
2651 anyhow::Ok(())
2652 }
2653 })
2654 .detach_and_log_err(cx);
2655
2656 cx.notify();
2657 }
2658}
2659
/// Collapses the run of most recent mergeable events in `events` into one event.
///
/// Walking from newest to oldest, an event joins the run while its `can_merge`
/// accepts the next-newer event already in the run, given `latest_snapshot` and
/// `latest_edit_range`. When the run has at least two events, they are replaced
/// by a single `BufferChange` event:
/// - its paths and `in_open_source_repo` come from the oldest event in the run;
/// - its diff spans from that event's `old_snapshot` to `end_snapshot`;
/// - `predicted` is `true` only if every merged event was predicted.
///
/// Nothing changes if any of these hold:
/// - the newest event's `old_snapshot.remote_id()` differs from
///   `latest_snapshot`'s;
/// - the run has fewer than two events;
/// - `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the trailing run; `next_old_event` is the next-newer event in the run.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek` did not consume anything, so `all` sees every merged
                    // event, including the oldest.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the whole run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2726
2727pub(crate) fn filter_redundant_excerpts(
2728 mut related_files: Vec<RelatedFile>,
2729 cursor_path: &Path,
2730 cursor_row_range: Range<u32>,
2731) -> Vec<RelatedFile> {
2732 for file in &mut related_files {
2733 if file.path.as_ref() == cursor_path {
2734 file.excerpts.retain(|excerpt| {
2735 excerpt.row_range.start < cursor_row_range.start
2736 || excerpt.row_range.end > cursor_row_range.end
2737 });
2738 }
2739 }
2740 related_files.retain(|file| !file.excerpts.is_empty());
2741 related_files
2742}
2743
/// Error returned when the server requires a newer Zed version than the one
/// making the request.
///
/// Produced when a response's minimum-required-version header is greater than
/// the app version sent with the request. Its message is shown to the user in
/// an app notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: Version,
}
2751
/// The user's choice about edit-prediction data collection, persisted in the
/// key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice stored yet, or the stored value was unrecognized.
    NotAnswered,
    /// The user opted in (stored as `"true"`).
    Enabled,
    /// The user opted out (stored as `"false"`).
    Disabled,
}
2758
2759impl DataCollectionChoice {
2760 pub fn is_enabled(self, cx: &App) -> bool {
2761 if cx.is_staff() {
2762 return true;
2763 }
2764 match self {
2765 Self::Enabled => true,
2766 Self::NotAnswered | Self::Disabled => false,
2767 }
2768 }
2769
2770 #[must_use]
2771 pub fn toggle(&self) -> DataCollectionChoice {
2772 match self {
2773 Self::Enabled => Self::Disabled,
2774 Self::Disabled => Self::Enabled,
2775 Self::NotAnswered => Self::Enabled,
2776 }
2777 }
2778}
2779
2780impl From<bool> for DataCollectionChoice {
2781 fn from(value: bool) -> Self {
2782 match value {
2783 true => DataCollectionChoice::Enabled,
2784 false => DataCollectionChoice::Disabled,
2785 }
2786 }
2787}
2788
2789struct ZedPredictUpsell;
2790
2791impl Dismissable for ZedPredictUpsell {
2792 const KEY: &'static str = "dismissed-edit-predict-upsell";
2793
2794 fn dismissed() -> bool {
2795 // To make this backwards compatible with older versions of Zed, we
2796 // check if the user has seen the previous Edit Prediction Onboarding
2797 // before, by checking the data collection choice which was written to
2798 // the database once the user clicked on "Accept and Enable"
2799 if KEY_VALUE_STORE
2800 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2801 .log_err()
2802 .is_some_and(|s| s.is_some())
2803 {
2804 return true;
2805 }
2806
2807 KEY_VALUE_STORE
2808 .read_kvp(Self::KEY)
2809 .log_err()
2810 .is_some_and(|s| s.is_some())
2811 }
2812}
2813
/// Returns `true` while the edit-prediction upsell has not been dismissed, i.e.
/// while `ZedPredictUpsell::dismissed()` returns `false`.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2817
/// Registers the edit-prediction workspace actions on every new [`Workspace`].
///
/// Registered actions:
/// - `zed_actions::OpenZedPredictOnboarding`: toggles the `ZedPredictModal`.
/// - `ResetOnboarding`: writes `EditPredictionProvider::None` as the edit
///   prediction provider in the settings file.
/// - Copilot `SignIn` / `Reinstall` / `SignOut`: forwarded to `copilot_ui`. This
///   only happens when a global `EditPredictionStore` exists and yields a Copilot
///   instance for the workspace's project.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding switches the provider back to `None`.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Resolves the Copilot instance for `project` via the global store, if any.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}