1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::example_spec::ExampleSpec;
79use crate::license_detection::LicenseDetectionWatcher;
80use crate::mercury::Mercury;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Actions registered under the `edit_prediction` namespace via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance (as measured by `lines_between_ranges`) used by `StoredEvent::can_merge`:
/// events within this many rows of the latest edit are kept separate, and events farther
/// apart than this from each other are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not used in this chunk; presumably the time window within which consecutive
// changes are grouped into one event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key under which the data-collection choice is persisted. NOTE(review): presumably read by
// `load_data_collection_choice` (not shown) — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): not used in this chunk; presumably the debounce interval for batching
// rejection reports sent by `handle_rejected_predictions` — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// NOTE(review): presumably the telemetry event name reported when a prediction settles — confirm.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a pending settled prediction is tracked — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably the quiet period after the last edit before a prediction counts
// as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
110
/// Feature flag named `edit_prediction_jumps`.
/// NOTE(review): presumably gates jump-style predictions — confirm where it is checked.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// gpui global holding the shared [`EditPredictionStore`] entity. It is installed by
/// `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// It is built from the `ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT`
/// environment variables by `EditPredictionStore::zeta2_raw_config_from_env`, or supplied
/// through `EditPredictionStore::set_zeta2_raw_config`.
// NOTE(review): an earlier version of this doc said "the version is also used as the Baseten
// environment name (lowercased)", but there is no version field. `environment` presumably
// carries that value now — confirm how it is consumed.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier; read from `ZED_ZETA_MODEL` when loaded from the environment.
    pub model_id: Option<String>,
    /// Environment name; read from `ZED_ZETA_ENVIRONMENT` when loaded from the environment.
    pub environment: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when loaded from the environment.
    pub format: ZetaFormat,
}
131
/// Central store for edit predictions.
///
/// It holds per-project state, the selected prediction backend and experiment selection.
/// It also owns the channel senders that feed the background workers started in
/// [`EditPredictionStore::new`].
pub struct EditPredictionStore {
    /// Client used for authenticated HTTP requests (e.g. fetching experiments).
    client: Arc<Client>,
    /// Read for the current user, the current organization and edit prediction usage.
    user_store: Entity<UserStore>,
    /// Token acquired before calling the Zed LLM service.
    llm_token: LlmApiToken,
    /// Waits for a signed-in user, then refreshes `available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): not read in this chunk; presumably set when the server demands a newer
    // client version — confirm at the use site.
    update_required: bool,
    /// Backend producing predictions; also selects the icon set and whether usage is reported.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initialized from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly selected by the user, if any.
    preferred_experiment: Option<String>,
    /// Experiment names fetched from `/edit_prediction_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// The data-collection choice, loaded when the store is constructed.
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background `handle_rejected_predictions` task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that have been shown; `active_experiment` derives the experiment from these.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): not used in this chunk; presumably ids of predictions the user has rated.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only callback taking a prediction id and a string payload.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}

/// A rejection queued on `reject_predictions_tx`, tagged with the organization it belongs to.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
158
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zeta; the default chosen in `EditPredictionStore::new` and the only model for which
    /// `usage` reports edit prediction usage.
    Zeta,
    /// FIM-style prompting using `format` (see the `fim` module). The configured provider
    /// (e.g. Ollama) determines the icon set.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury (displayed with the Inception icon).
    Mercury,
}

/// Inputs handed to a prediction backend for a single request.
// NOTE(review): the per-field notes below are inferred from names and types — confirm
// against the backends that consume this struct.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Presumably a snapshot of `buffer` taken when the request was made.
    snapshot: BufferSnapshot,
    /// Presumably the cursor position the prediction is requested for.
    position: Anchor,
    /// Edit history events fed to the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Presumably the rows searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Optional sink for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
183
/// Debug notifications delivered on a project's `debug_tx` channel, which is created by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    /// Sent when the project's related-excerpt store starts a refresh.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Sent when the project's related-excerpt store finishes a refresh.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    // NOTE(review): the two prediction events are not emitted in this chunk.
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Left empty by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs such as "Cache Hits", "Max LSP Time" and "Mean LSP Time".
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
219
/// Capacity of `ProjectState::user_actions`; the oldest record is evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Time of the action in milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}

/// Kind of action captured by a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
239
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event (built by `LastEvent::finalize` as a `BufferChange` with a diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors, created in the post-change snapshot, spanning the edited rows.
    pub edit_range: Range<Anchor>,
}
247
248impl StoredEvent {
249 fn can_merge(
250 &self,
251 next_old_event: &&&StoredEvent,
252 new_snapshot: &TextBufferSnapshot,
253 last_edit_range: &Range<Anchor>,
254 ) -> bool {
255 // Events must be for the same buffer
256 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
257 return false;
258 }
259 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
260 return false;
261 }
262
263 let a_is_predicted = matches!(
264 self.event.as_ref(),
265 zeta_prompt::Event::BufferChange {
266 predicted: true,
267 ..
268 }
269 );
270 let b_is_predicted = matches!(
271 next_old_event.event.as_ref(),
272 zeta_prompt::Event::BufferChange {
273 predicted: true,
274 ..
275 }
276 );
277
278 // If events come from the same source (both predicted or both manual) then
279 // we would have coalesced them already.
280 if a_is_predicted == b_is_predicted {
281 return false;
282 }
283
284 let left_range = self.edit_range.to_point(new_snapshot);
285 let right_range = next_old_event.edit_range.to_point(new_snapshot);
286 let latest_range = last_edit_range.to_point(&new_snapshot);
287
288 // Events near to the latest edit are not merged if their sources differ.
289 if lines_between_ranges(&left_range, &latest_range)
290 .min(lines_between_ranges(&right_range, &latest_range))
291 <= CHANGE_GROUPING_LINE_SPAN
292 {
293 return false;
294 }
295
296 // Events that are distant from each other are not merged.
297 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
298 return false;
299 }
300
301 true
302 }
303}
304
305fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
306 if left.start > right.end {
307 return left.start.row - right.end.row;
308 }
309 if right.start > left.end {
310 return right.start.row - left.end.row;
311 }
312 0
313}
314
/// Per-project edit prediction state owned by [`EditPredictionStore`].
struct ProjectState {
    /// Finalized edit events.
    events: VecDeque<StoredEvent>,
    /// Event still being accumulated; `events()` finalizes it on demand without storing it.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with the store, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // NOTE(review): presumably the id handed to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (capacity two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional sink for [`DebugEvent`]s, installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions recorded by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context files for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, consulted to mark events/files as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recent user actions, bounded by `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot started for this project by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
333
334impl ProjectState {
335 fn record_user_action(&mut self, action: UserActionRecord) {
336 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
337 self.user_actions.pop_front();
338 }
339 self.user_actions.push_back(action);
340 }
341
342 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
343 self.events
344 .iter()
345 .cloned()
346 .chain(self.last_event.as_ref().iter().flat_map(|event| {
347 let (one, two) = event.split_by_pause();
348 let one = one.finalize(&self.license_detection_watchers, cx);
349 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
350 one.into_iter().chain(two)
351 }))
352 .collect()
353 }
354
355 fn cancel_pending_prediction(
356 &mut self,
357 pending_prediction: PendingPrediction,
358 cx: &mut Context<EditPredictionStore>,
359 ) {
360 self.cancelled_predictions.insert(pending_prediction.id);
361
362 if pending_prediction.drop_on_cancel {
363 drop(pending_prediction.task);
364 } else {
365 cx.spawn(async move |this, cx| {
366 let Some(prediction_id) = pending_prediction.task.await else {
367 return;
368 };
369
370 this.update(cx, |this, cx| {
371 this.reject_prediction(
372 prediction_id,
373 EditPredictionRejectReason::Canceled,
374 false,
375 None,
376 cx,
377 );
378 })
379 .ok();
380 })
381 .detach()
382 }
383 }
384
385 fn active_buffer(
386 &self,
387 project: &Entity<Project>,
388 cx: &App,
389 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
390 let project = project.read(cx);
391 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
392 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
393 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
394 Some((active_buffer, registered_buffer.last_position))
395 }
396}
397
/// The prediction currently held for a project, with how (and whether) it was displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown.
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
405
406impl CurrentEditPrediction {
407 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
408 let Some(new_edits) = self
409 .prediction
410 .interpolate(&self.prediction.buffer.read(cx))
411 else {
412 return false;
413 };
414
415 if self.prediction.buffer != old_prediction.prediction.buffer {
416 return true;
417 }
418
419 let Some(old_edits) = old_prediction
420 .prediction
421 .interpolate(&old_prediction.prediction.buffer.read(cx))
422 else {
423 return true;
424 };
425
426 let requested_by_buffer_id = self.requested_by.buffer_id();
427
428 // This reduces the occurrence of UI thrash from replacing edits
429 //
430 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
431 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
432 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
433 && old_edits.len() == 1
434 && new_edits.len() == 1
435 {
436 let (old_range, old_text) = &old_edits[0];
437 let (new_range, new_text) = &new_edits[0];
438 new_range == old_range && new_text.starts_with(old_text.as_ref())
439 } else {
440 true
441 }
442 }
443}
444
445#[derive(Debug, Clone)]
446enum PredictionRequestedBy {
447 DiagnosticsUpdate,
448 Buffer(EntityId),
449}
450
451impl PredictionRequestedBy {
452 pub fn buffer_id(&self) -> Option<EntityId> {
453 match self {
454 PredictionRequestedBy::DiagnosticsUpdate => None,
455 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
456 }
457 }
458}
459
// NOTE(review): not used in this chunk; presumably the row radius for diagnostic searches
// around a position — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
// NOTE(review): presumably `Local` searches near the cursor and `Global` the whole project —
// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}

/// An in-flight prediction request tracked by [`ProjectState`].
#[derive(Debug)]
struct PendingPrediction {
    /// Inserted into `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id. A `None` result is ignored on cancel.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
476
477/// A prediction from the perspective of a buffer.
478#[derive(Debug)]
479enum BufferEditPrediction<'a> {
480 Local { prediction: &'a EditPrediction },
481 Jump { prediction: &'a EditPrediction },
482}
483
484#[cfg(test)]
485impl std::ops::Deref for BufferEditPrediction<'_> {
486 type Target = EditPrediction;
487
488 fn deref(&self) -> &Self::Target {
489 match self {
490 BufferEditPrediction::Local { prediction } => prediction,
491 BufferEditPrediction::Jump { prediction } => prediction,
492 }
493 }
494}
495
/// A prediction awaiting its "settled" report.
// NOTE(review): the settle logic is not in this chunk; per-field meanings are inferred from
// names — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}

/// Bookkeeping for a buffer registered with [`ProjectState`].
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known position in the buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// An edit event still being accumulated; converted to a [`StoredEvent`] by `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot at the start of the event; the diff is computed from here.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot at the end of the event; the diff is computed to here.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // NOTE(review): not read by `finalize`, which recomputes the range from the snapshots.
    edit_range: Option<Range<Anchor>>,
    /// Whether the change came from an accepted prediction.
    predicted: bool,
    /// Boundary used by `split_by_pause` to divide the event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
524
impl LastEvent {
    /// Converts this in-progress event into a [`StoredEvent`].
    ///
    /// Returns `None` when the snapshots contain no edits, or when the path is unchanged and
    /// the diff is empty. `in_open_source_repo` is true only if both the old and new file
    /// exist and each belongs to a worktree whose license watcher reports an open-source
    /// project.
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        // Buffers without a file get a synthetic `untitled-<remote id>` path.
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) =
            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;

        // A path change (rename) is kept even when the content diff is empty.
        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
                old_snapshot: self.old_snapshot.clone(),
            })
        }
    }

    /// Splits this event at `snapshot_after_last_editing_pause`, if one is set.
    ///
    /// Without a boundary, this returns a clone of `self` and `None`. Otherwise it returns:
    /// - the part from `old_snapshot` to the boundary;
    /// - the part from the boundary to `new_snapshot`.
    ///
    /// Both parts keep this event's `old_file`/`new_file`, `predicted` and `last_edit_time`,
    /// and have `edit_range` and the pause boundary cleared.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            edit_range: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            edit_range: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
596
/// Computes one unified diff covering every edit between `old_snapshot` and `new_snapshot`.
///
/// All edits are collapsed into a single region that runs from the first edit's start to the
/// last edit's end, widened by context rows. Returns the diff text together with the edited
/// point range in `new_snapshot` (without context). Returns `None` when there are no edits.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits since `old_snapshot` means there is nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // NOTE(review): combined with the `+ 1` used for the end offsets below, this keeps
    // CONTEXT_LINES + 1 rows after the edit but only CONTEXT_LINES rows before it. Confirm
    // whether the asymmetry is intended; the diff helper may trim excess context.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    // The region ends at the start of the row after the context, clamped to the buffer end.
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Presumably the starting rows let hunk headers refer to rows in the full buffers rather
    // than in the extracted regions — confirm against `unified_diff_with_offsets`.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
642
643fn buffer_path_with_id_fallback(
644 file: Option<&Arc<dyn File>>,
645 snapshot: &TextBufferSnapshot,
646 cx: &App,
647) -> Arc<Path> {
648 if let Some(file) = file {
649 file.full_path(cx).into()
650 } else {
651 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
652 }
653}
654
655impl EditPredictionStore {
656 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
657 cx.try_global::<EditPredictionStoreGlobal>()
658 .map(|global| global.0.clone())
659 }
660
661 pub fn global(
662 client: &Arc<Client>,
663 user_store: &Entity<UserStore>,
664 cx: &mut App,
665 ) -> Entity<Self> {
666 cx.try_global::<EditPredictionStoreGlobal>()
667 .map(|global| global.0.clone())
668 .unwrap_or_else(|| {
669 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
670 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
671 ep_store
672 })
673 }
674
675 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
676 let data_collection_choice = Self::load_data_collection_choice();
677
678 let llm_token = LlmApiToken::global(cx);
679
680 let (reject_tx, reject_rx) = mpsc::unbounded();
681 cx.background_spawn({
682 let client = client.clone();
683 let llm_token = llm_token.clone();
684 let app_version = AppVersion::global(cx);
685 let background_executor = cx.background_executor().clone();
686 async move {
687 Self::handle_rejected_predictions(
688 reject_rx,
689 client,
690 llm_token,
691 app_version,
692 background_executor,
693 )
694 .await
695 }
696 })
697 .detach();
698
699 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
700 cx.spawn(async move |this, cx| {
701 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
702 })
703 .detach();
704
705 let mut current_user = user_store.read(cx).watch_current_user();
706 let fetch_experiments_task = cx.spawn(async move |this, cx| {
707 while current_user.borrow().is_none() {
708 current_user.next().await;
709 }
710 this.update(cx, |this, cx| {
711 this.refresh_available_experiments(cx);
712 })
713 .log_err();
714 });
715
716 let this = Self {
717 projects: HashMap::default(),
718 client,
719 user_store,
720 llm_token,
721 _fetch_experiments_task: fetch_experiments_task,
722 update_required: false,
723 edit_prediction_model: EditPredictionModel::Zeta,
724 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
725 preferred_experiment: None,
726 available_experiments: Vec::new(),
727 sweep_ai: SweepAi::new(cx),
728 mercury: Mercury::new(cx),
729
730 data_collection_choice,
731 reject_predictions_tx: reject_tx,
732 settled_predictions_tx,
733 rated_predictions: Default::default(),
734 shown_predictions: Default::default(),
735 #[cfg(test)]
736 settled_event_callback: None,
737 };
738
739 this
740 }
741
742 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
743 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
744 let format = ZetaFormat::parse(&version_str).ok()?;
745 let model_id = env::var("ZED_ZETA_MODEL").ok();
746 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
747 Some(Zeta2RawConfig {
748 model_id,
749 environment,
750 format,
751 })
752 }
753
754 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
755 self.edit_prediction_model = model;
756 }
757
758 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
759 self.zeta2_raw_config = Some(config);
760 }
761
762 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
763 self.zeta2_raw_config.as_ref()
764 }
765
766 pub fn preferred_experiment(&self) -> Option<&str> {
767 self.preferred_experiment.as_deref()
768 }
769
770 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
771 self.preferred_experiment = experiment;
772 }
773
774 pub fn available_experiments(&self) -> &[String] {
775 &self.available_experiments
776 }
777
778 pub fn active_experiment(&self) -> Option<&str> {
779 self.preferred_experiment.as_deref().or_else(|| {
780 self.shown_predictions
781 .iter()
782 .find_map(|p| p.model_version.as_ref())
783 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
784 })
785 }
786
/// Fetches the experiment list from the Zed LLM service's `/edit_prediction_experiments`
/// endpoint.
///
/// On success, the list is stored in `available_experiments` and observers are notified.
/// Failures (token acquisition, request, non-success status, JSON decoding) are logged via
/// `detach_and_log_err` rather than returned.
pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let organization_id = self
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());

    cx.spawn(async move |this, cx| {
        // The HTTP round trip runs on the background executor.
        let experiments = cx
            .background_spawn(async move {
                let http_client = client.http_client();
                let token = llm_token.acquire(&client, organization_id).await?;
                let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                let request = http_client::Request::builder()
                    .method(Method::GET)
                    .uri(url.as_ref())
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                    .body(Default::default())?;
                let mut response = http_client.send(request).await?;
                if response.status().is_success() {
                    let mut body = Vec::new();
                    response.body_mut().read_to_end(&mut body).await?;
                    let experiments: Vec<String> = serde_json::from_slice(&body)?;
                    Ok(experiments)
                } else {
                    // Include the body in the error to make server-side failures diagnosable.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "Failed to fetch experiments: {:?}\nBody: {}",
                        response.status(),
                        body
                    );
                }
            })
            .await?;
        this.update(cx, |this, cx| {
            this.available_experiments = experiments;
            cx.notify();
        })?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
834
835 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
836 use ui::IconName;
837 match self.edit_prediction_model {
838 EditPredictionModel::Sweep => {
839 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
840 .with_disabled(IconName::SweepAiDisabled)
841 .with_up(IconName::SweepAiUp)
842 .with_down(IconName::SweepAiDown)
843 .with_error(IconName::SweepAiError)
844 }
845 EditPredictionModel::Mercury => {
846 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
847 }
848 EditPredictionModel::Zeta => {
849 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
850 .with_disabled(IconName::ZedPredictDisabled)
851 .with_up(IconName::ZedPredictUp)
852 .with_down(IconName::ZedPredictDown)
853 .with_error(IconName::ZedPredictError)
854 }
855 EditPredictionModel::Fim { .. } => {
856 let settings = &all_language_settings(None, cx).edit_predictions;
857 match settings.provider {
858 EditPredictionProvider::Ollama => {
859 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
860 }
861 _ => {
862 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
863 }
864 }
865 }
866 }
867 }
868
869 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
870 self.sweep_ai.api_token.read(cx).has_key()
871 }
872
873 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
874 self.mercury.api_token.read(cx).has_key()
875 }
876
877 pub fn clear_history(&mut self) {
878 for project_state in self.projects.values_mut() {
879 project_state.events.clear();
880 project_state.last_event.take();
881 }
882 }
883
884 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
885 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
886 project_state.events.clear();
887 project_state.last_event.take();
888 }
889 }
890
891 pub fn edit_history_for_project(
892 &self,
893 project: &Entity<Project>,
894 cx: &App,
895 ) -> Vec<StoredEvent> {
896 self.projects
897 .get(&project.entity_id())
898 .map(|project_state| project_state.events(cx))
899 .unwrap_or_default()
900 }
901
902 pub fn context_for_project<'a>(
903 &'a self,
904 project: &Entity<Project>,
905 cx: &'a mut App,
906 ) -> Vec<RelatedFile> {
907 self.projects
908 .get(&project.entity_id())
909 .map(|project_state| {
910 project_state.context.update(cx, |context, cx| {
911 context
912 .related_files_with_buffers(cx)
913 .map(|(mut related_file, buffer)| {
914 related_file.in_open_source_repo = buffer
915 .read(cx)
916 .file()
917 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
918 related_file
919 })
920 .collect()
921 })
922 })
923 .unwrap_or_default()
924 }
925
926 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
927 self.projects
928 .get(&project.entity_id())
929 .and_then(|project| project.copilot.clone())
930 }
931
932 pub fn start_copilot_for_project(
933 &mut self,
934 project: &Entity<Project>,
935 cx: &mut Context<Self>,
936 ) -> Option<Entity<Copilot>> {
937 if DisableAiSettings::get(None, cx).disable_ai {
938 return None;
939 }
940 let state = self.get_or_init_project(project, cx);
941
942 if state.copilot.is_some() {
943 return state.copilot.clone();
944 }
945 let _project = project.clone();
946 let project = project.read(cx);
947
948 let node = project.node_runtime().cloned();
949 if let Some(node) = node {
950 let next_id = project.languages().next_language_server_id();
951 let fs = project.fs().clone();
952
953 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
954 state.copilot = Some(copilot.clone());
955 Some(copilot)
956 } else {
957 None
958 }
959 }
960
961 pub fn context_for_project_with_buffers<'a>(
962 &'a self,
963 project: &Entity<Project>,
964 cx: &'a mut App,
965 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
966 self.projects
967 .get(&project.entity_id())
968 .map(|project| {
969 project.context.update(cx, |context, cx| {
970 context.related_files_with_buffers(cx).collect()
971 })
972 })
973 .unwrap_or_default()
974 }
975
976 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
977 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
978 self.user_store.read(cx).edit_prediction_usage()
979 } else {
980 None
981 }
982 }
983
984 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
985 self.get_or_init_project(project, cx);
986 }
987
988 pub fn register_buffer(
989 &mut self,
990 buffer: &Entity<Buffer>,
991 project: &Entity<Project>,
992 cx: &mut Context<Self>,
993 ) {
994 let project_state = self.get_or_init_project(project, cx);
995 Self::register_buffer_impl(project_state, buffer, project, cx);
996 }
997
/// Returns the state for `project`, creating it on first access.
///
/// Creating the state also:
/// - creates the project's [`RelatedExcerptStore`], whose events are forwarded to
///   `handle_excerpt_store_event`;
/// - subscribes to the project's events;
/// - arranges for the state to be removed when the project entity is released.
fn get_or_init_project(
    &mut self,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &mut ProjectState {
    let entity_id = project.entity_id();
    self.projects
        .entry(entity_id)
        .or_insert_with(|| ProjectState {
            context: {
                let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                // Detached rather than kept in `_subscriptions`.
                cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                    this.handle_excerpt_store_event(entity_id, event);
                })
                .detach();
                related_excerpt_store
            },
            events: VecDeque::new(),
            last_event: None,
            recent_paths: VecDeque::new(),
            debug_tx: None,
            registered_buffers: HashMap::default(),
            current_prediction: None,
            cancelled_predictions: HashSet::default(),
            pending_predictions: ArrayVec::new(),
            next_pending_prediction_id: 0,
            last_edit_prediction_refresh: None,
            last_jump_prediction_refresh: None,
            license_detection_watchers: HashMap::default(),
            user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
            _subscriptions: [
                cx.subscribe(&project, Self::handle_project_event),
                // Drop this project's state once the project itself is released.
                cx.observe_release(&project, move |this, _, cx| {
                    this.projects.remove(&entity_id);
                    cx.notify();
                }),
            ],
            copilot: None,
        })
}
1038
1039 pub fn remove_project(&mut self, project: &Entity<Project>) {
1040 self.projects.remove(&project.entity_id());
1041 }
1042
1043 fn handle_excerpt_store_event(
1044 &mut self,
1045 project_entity_id: EntityId,
1046 event: &RelatedExcerptStoreEvent,
1047 ) {
1048 if let Some(project_state) = self.projects.get(&project_entity_id) {
1049 if let Some(debug_tx) = project_state.debug_tx.clone() {
1050 match event {
1051 RelatedExcerptStoreEvent::StartedRefresh => {
1052 debug_tx
1053 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1054 ContextRetrievalStartedDebugEvent {
1055 project_entity_id: project_entity_id,
1056 timestamp: Instant::now(),
1057 search_prompt: String::new(),
1058 },
1059 ))
1060 .ok();
1061 }
1062 RelatedExcerptStoreEvent::FinishedRefresh {
1063 cache_hit_count,
1064 cache_miss_count,
1065 mean_definition_latency,
1066 max_definition_latency,
1067 } => {
1068 debug_tx
1069 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1070 ContextRetrievalFinishedDebugEvent {
1071 project_entity_id: project_entity_id,
1072 timestamp: Instant::now(),
1073 metadata: vec![
1074 (
1075 "Cache Hits",
1076 format!(
1077 "{}/{}",
1078 cache_hit_count,
1079 cache_hit_count + cache_miss_count
1080 )
1081 .into(),
1082 ),
1083 (
1084 "Max LSP Time",
1085 format!("{} ms", max_definition_latency.as_millis())
1086 .into(),
1087 ),
1088 (
1089 "Mean LSP Time",
1090 format!("{} ms", mean_definition_latency.as_millis())
1091 .into(),
1092 ),
1093 ],
1094 },
1095 ))
1096 .ok();
1097 }
1098 }
1099 }
1100 }
1101 }
1102
1103 pub fn debug_info(
1104 &mut self,
1105 project: &Entity<Project>,
1106 cx: &mut Context<Self>,
1107 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1108 let project_state = self.get_or_init_project(project, cx);
1109 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1110 project_state.debug_tx = Some(debug_watch_tx);
1111 debug_watch_rx
1112 }
1113
1114 fn handle_project_event(
1115 &mut self,
1116 project: Entity<Project>,
1117 event: &project::Event,
1118 cx: &mut Context<Self>,
1119 ) {
1120 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1121 return;
1122 }
1123 // TODO [zeta2] init with recent paths
1124 match event {
1125 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1126 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1127 return;
1128 };
1129 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1130 if let Some(path) = path {
1131 if let Some(ix) = project_state
1132 .recent_paths
1133 .iter()
1134 .position(|probe| probe == &path)
1135 {
1136 project_state.recent_paths.remove(ix);
1137 }
1138 project_state.recent_paths.push_front(path);
1139 }
1140 }
1141 project::Event::DiagnosticsUpdated { .. } => {
1142 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1143 self.refresh_prediction_from_diagnostics(
1144 project,
1145 DiagnosticSearchScope::Global,
1146 cx,
1147 );
1148 }
1149 }
1150 _ => (),
1151 }
1152 }
1153
    /// Registers `buffer` in `project_state` (if not already) and returns its entry.
    ///
    /// On every call, if the buffer has a file whose worktree is still part of the
    /// project, a `LicenseDetectionWatcher` is created for that worktree unless one
    /// already exists. The watcher entry is removed when the worktree is released.
    ///
    /// On first registration, the buffer's current text snapshot and file are
    /// recorded, and two subscriptions are stored with the entry:
    /// - buffer edits are reported through `report_changes_for_buffer`;
    /// - the entry is removed when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: the subscription must not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { .. } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    // User edit, not an accepted prediction.
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1220
    /// Diffs `buffer` against the snapshot last recorded for it and records the change.
    ///
    /// `is_predicted` marks whether the change came from accepting an edit prediction.
    /// `accept_current_prediction` passes `true`; the buffer-edit subscription passes
    /// `false`. Changes with different `is_predicted` values are never coalesced into
    /// the same event.
    ///
    /// Steps, in order:
    /// 1. The buffer is registered if needed and its stored file/snapshot replaced.
    /// 2. Pending settled predictions whose editable range overlaps the edited range
    ///    get `last_edit_at` bumped, which postpones their settlement.
    /// 3. A `UserActionRecord` describing the edit is appended to the project's
    ///    user-action history.
    /// 4. The edit is either merged into `last_event`, or it starts a new `last_event`
    ///    after the previous one is finalized into `events`.
    ///
    /// Returns immediately if the buffer version is unchanged. If no edits are found
    /// since the previous snapshot, it returns after step 1.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Aggregate every edit since the previous snapshot into: an edit count, byte
        // totals, one anchor range from the first edit's start to the last edit's end,
        // and the (new-text) end offset of the last edit.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit inside a pending prediction's editable region restarts its
        // quiescence timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify by byte counts. Exactly one byte inserted (or deleted) per edit
        // counts as a single-character action, which also covers multi-cursor typing.
        // NOTE(review): these are byte lengths, so typing a multi-byte UTF-8 character
        // is classified as `InsertSelection`; confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock time; falls back to 0 if the clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The in-progress event is extended when this change continues it directly
            // (same buffer, starting from the version it ended at), has the same
            // predicted/user origin, and lies within `CHANGE_GROUPING_LINE_SPAN` lines of
            // its last edit.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least `LAST_CHANGE_GROUPING_TIME`, remember the
                // snapshot from just before this edit as the post-pause checkpoint.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Close out the previous event. When `finalize` yields an event, it is appended.
        // The eviction check keeps at most `EVENT_COUNT_MAX - 1` finalized events.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // NOTE(review): helper defined elsewhere; presumably folds trailing finalized
        // events that this edit continues. See its definition.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1354
1355 fn prediction_at(
1356 &mut self,
1357 buffer: &Entity<Buffer>,
1358 position: Option<language::Anchor>,
1359 project: &Entity<Project>,
1360 cx: &App,
1361 ) -> Option<BufferEditPrediction<'_>> {
1362 let project_state = self.projects.get_mut(&project.entity_id())?;
1363 if let Some(position) = position
1364 && let Some(buffer) = project_state
1365 .registered_buffers
1366 .get_mut(&buffer.entity_id())
1367 {
1368 buffer.last_position = Some(position);
1369 }
1370
1371 let CurrentEditPrediction {
1372 requested_by,
1373 prediction,
1374 ..
1375 } = project_state.current_prediction.as_ref()?;
1376
1377 if prediction.targets_buffer(buffer.read(cx)) {
1378 Some(BufferEditPrediction::Local { prediction })
1379 } else {
1380 let show_jump = match requested_by {
1381 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1382 requested_by_buffer_id == &buffer.entity_id()
1383 }
1384 PredictionRequestedBy::DiagnosticsUpdate => true,
1385 };
1386
1387 if show_jump {
1388 Some(BufferEditPrediction::Jump { prediction })
1389 } else {
1390 None
1391 }
1392 }
1393 }
1394
1395 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1396 let Some(current_prediction) = self
1397 .projects
1398 .get_mut(&project.entity_id())
1399 .and_then(|project_state| project_state.current_prediction.take())
1400 else {
1401 return;
1402 };
1403
1404 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1405
1406 // can't hold &mut project_state ref across report_changes_for_buffer_call
1407 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1408 return;
1409 };
1410
1411 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1412 project_state.cancel_pending_prediction(pending_prediction, cx);
1413 }
1414
1415 match self.edit_prediction_model {
1416 EditPredictionModel::Sweep => {
1417 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1418 }
1419 EditPredictionModel::Mercury => {
1420 mercury::edit_prediction_accepted(
1421 current_prediction.prediction.id,
1422 self.client.http_client(),
1423 cx,
1424 );
1425 }
1426 EditPredictionModel::Zeta => {
1427 let is_cloud = !matches!(
1428 all_language_settings(None, cx).edit_predictions.provider,
1429 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1430 );
1431 if is_cloud {
1432 zeta::edit_prediction_accepted(self, current_prediction, cx)
1433 }
1434 }
1435 EditPredictionModel::Fim { .. } => {}
1436 }
1437 }
1438
    /// Background loop that batches prediction rejections and uploads them.
    ///
    /// Each received payload's rejection is added to a local batch. While the batch
    /// holds fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries,
    /// the loop waits for either another incoming rejection (which it then pulls in)
    /// or the `REJECT_REQUEST_DEBOUNCE` timer before flushing.
    ///
    /// A flush sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    /// recently batched rejections to `/predict_edits/reject`. They are removed from
    /// the batch only if the request succeeds, so failed ones are retried on a later
    /// flush. The loop ends once the channel is closed and drained.
    ///
    /// NOTE(review): two points to confirm. The organization id of the most recently
    /// received payload is used for the whole batch. The batch grows without bound if
    /// requests keep failing.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: debounce so nearby rejections share a single request.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    // `Some`: another rejection arrived first; loop to pull it in.
                    // `None`: the channel closed; fall through and flush what we have.
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // Panics if the LLM endpoint URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Only the newest `flush_count` entries are sent; older ones stay queued.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // On failure the entries stay in `batched` and are retried on a later flush.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1502
    /// Background loop that reports pending predictions once their editable region
    /// has settled.
    ///
    /// The loop is idle until `rx` delivers an enqueue time. It then sleeps until
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` after that time and scans every registered
    /// buffer of every project. Each pending settled prediction is handled as follows:
    /// - Older than `EDIT_PREDICTION_SETTLED_TTL` (measured from enqueue): dropped
    ///   without reporting.
    /// - Not edited for at least `EDIT_PREDICTION_SETTLED_QUIESCENCE`: the text of its
    ///   editable range, read from the buffer's last recorded snapshot, is sent as an
    ///   `EDIT_PREDICTION_SETTLED_EVENT` telemetry event, and then it is dropped.
    /// - Otherwise: kept.
    ///
    /// The next wake-up is scheduled for when the earliest-edited remaining prediction
    /// becomes quiet. If none remain, the loop goes idle again. It exits when `rx`
    /// closes while idle, or when the store entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // NOTE(review): if `wake_time` has already passed, this relies on
                // `duration_since` saturating to zero (as std `Instant` does). Confirm
                // this holds for the executor's clock type.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: wait for the next enqueued prediction.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Collapse a burst of enqueue notifications into this single wake-up.
                // Times sent while the loop sleeps stay buffered until it is idle again.
                // The scan reads pending predictions from the store directly, so those
                // predictions are still handled.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among surviving predictions; schedules the next scan.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1587
1588 pub(crate) fn enqueue_settled_prediction(
1589 &mut self,
1590 request_id: EditPredictionId,
1591 project: &Entity<Project>,
1592 edited_buffer: &Entity<Buffer>,
1593 edited_buffer_snapshot: &BufferSnapshot,
1594 editable_offset_range: Range<usize>,
1595 example: Option<ExampleSpec>,
1596 cx: &mut Context<Self>,
1597 ) {
1598 let this = &mut *self;
1599 let project_state = this.get_or_init_project(project, cx);
1600 if let Some(buffer) = project_state
1601 .registered_buffers
1602 .get_mut(&edited_buffer.entity_id())
1603 {
1604 let now = cx.background_executor().now();
1605 buffer.pending_predictions.push(PendingSettledPrediction {
1606 request_id: request_id,
1607 editable_anchor_range: edited_buffer_snapshot
1608 .anchor_range_around(editable_offset_range),
1609 example,
1610 enqueued_at: now,
1611 last_edit_at: now,
1612 });
1613 this.settled_predictions_tx.unbounded_send(now).ok();
1614 }
1615 }
1616
1617 fn reject_current_prediction(
1618 &mut self,
1619 reason: EditPredictionRejectReason,
1620 project: &Entity<Project>,
1621 cx: &App,
1622 ) {
1623 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1624 project_state.pending_predictions.clear();
1625 if let Some(prediction) = project_state.current_prediction.take() {
1626 let model_version = prediction.prediction.model_version.clone();
1627 self.reject_prediction(
1628 prediction.prediction.id,
1629 reason,
1630 prediction.was_shown,
1631 model_version,
1632 cx,
1633 );
1634 }
1635 };
1636 }
1637
1638 fn did_show_current_prediction(
1639 &mut self,
1640 project: &Entity<Project>,
1641 display_type: edit_prediction_types::SuggestionDisplayType,
1642 cx: &mut Context<Self>,
1643 ) {
1644 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1645 return;
1646 };
1647
1648 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1649 return;
1650 };
1651
1652 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1653 let previous_shown_with = current_prediction.shown_with;
1654
1655 if previous_shown_with.is_none() || !is_jump {
1656 current_prediction.shown_with = Some(display_type);
1657 }
1658
1659 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1660
1661 if is_first_non_jump_show {
1662 current_prediction.was_shown = true;
1663 }
1664
1665 let display_type_changed = previous_shown_with != Some(display_type);
1666
1667 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1668 sweep_ai::edit_prediction_shown(
1669 &self.sweep_ai,
1670 self.client.clone(),
1671 ¤t_prediction.prediction,
1672 display_type,
1673 cx,
1674 );
1675 }
1676
1677 if is_first_non_jump_show {
1678 self.shown_predictions
1679 .push_front(current_prediction.prediction.clone());
1680 if self.shown_predictions.len() > 50 {
1681 let completion = self.shown_predictions.pop_back().unwrap();
1682 self.rated_predictions.remove(&completion.id);
1683 }
1684 }
1685 }
1686
1687 fn reject_prediction(
1688 &mut self,
1689 prediction_id: EditPredictionId,
1690 reason: EditPredictionRejectReason,
1691 was_shown: bool,
1692 model_version: Option<String>,
1693 cx: &App,
1694 ) {
1695 match self.edit_prediction_model {
1696 EditPredictionModel::Zeta => {
1697 let is_cloud = !matches!(
1698 all_language_settings(None, cx).edit_predictions.provider,
1699 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1700 );
1701
1702 if is_cloud {
1703 let organization_id = self
1704 .user_store
1705 .read(cx)
1706 .current_organization()
1707 .map(|organization| organization.id.clone());
1708
1709 self.reject_predictions_tx
1710 .unbounded_send(EditPredictionRejectionPayload {
1711 rejection: EditPredictionRejection {
1712 request_id: prediction_id.to_string(),
1713 reason,
1714 was_shown,
1715 model_version,
1716 },
1717 organization_id,
1718 })
1719 .log_err();
1720 }
1721 }
1722 EditPredictionModel::Mercury => {
1723 mercury::edit_prediction_rejected(
1724 prediction_id,
1725 was_shown,
1726 reason,
1727 self.client.http_client(),
1728 cx,
1729 );
1730 }
1731 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1732 }
1733 }
1734
1735 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1736 self.projects
1737 .get(&project.entity_id())
1738 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1739 }
1740
1741 pub fn refresh_prediction_from_buffer(
1742 &mut self,
1743 project: Entity<Project>,
1744 buffer: Entity<Buffer>,
1745 position: language::Anchor,
1746 cx: &mut Context<Self>,
1747 ) {
1748 self.queue_prediction_refresh(
1749 project.clone(),
1750 PredictEditsRequestTrigger::Other,
1751 buffer.entity_id(),
1752 cx,
1753 move |this, cx| {
1754 let Some(request_task) = this
1755 .update(cx, |this, cx| {
1756 this.request_prediction(
1757 &project,
1758 &buffer,
1759 position,
1760 PredictEditsRequestTrigger::Other,
1761 cx,
1762 )
1763 })
1764 .log_err()
1765 else {
1766 return Task::ready(anyhow::Ok(None));
1767 };
1768
1769 cx.spawn(async move |_cx| {
1770 request_task.await.map(|prediction_result| {
1771 prediction_result.map(|prediction_result| {
1772 (
1773 prediction_result,
1774 PredictionRequestedBy::Buffer(buffer.entity_id()),
1775 )
1776 })
1777 })
1778 })
1779 },
1780 )
1781 }
1782
    /// Queues a prediction request at the next diagnostic location in the project.
    ///
    /// Does nothing if the configured provider is not served by this store, or if the
    /// project has no state yet or already has a current prediction (buffer-driven
    /// predictions are preferred). The request goes through `queue_prediction_refresh`
    /// with the `Diagnostics` trigger, throttled on the project's entity id.
    ///
    /// When it runs, it locates the project's active buffer and cursor. It bails out if
    /// predictions are disabled there. It then asks `next_diagnostic_location` for a
    /// target, searching `DIAGNOSTIC_LINES_RANGE` rows around the cursor for
    /// `DiagnosticSearchScope::Local`, or passing an empty default range for `Global`,
    /// and requests a prediction at that location. If a current prediction appeared in
    /// the meantime, the new result is turned into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer/cursor up front; `None` aborts the refresh.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        // Empty 0..0 range; presumably interpreted as "no restriction" by
                        // `next_diagnostic_location`. TODO confirm.
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A buffer-driven prediction may have landed while this request was in
                    // flight. If so, report this one as `CurrentPreferred` instead.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1894
1895 fn predictions_enabled_at(
1896 snapshot: &BufferSnapshot,
1897 position: Option<language::Anchor>,
1898 cx: &App,
1899 ) -> bool {
1900 let file = snapshot.file();
1901 let all_settings = all_language_settings(file, cx);
1902 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1903 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1904 {
1905 return false;
1906 }
1907
1908 if let Some(last_position) = position {
1909 let settings = snapshot.settings_at(last_position, cx);
1910
1911 if !settings.edit_predictions_disabled_in.is_empty()
1912 && let Some(scope) = snapshot.language_scope_at(last_position)
1913 && let Some(scope_name) = scope.override_name()
1914 && settings
1915 .edit_predictions_disabled_in
1916 .iter()
1917 .any(|s| s == scope_name)
1918 {
1919 return false;
1920 }
1921 }
1922
1923 true
1924 }
1925
    /// Minimum spacing between two prediction refreshes that share a throttle slot
    /// (edit-triggered vs. diagnostics-triggered) and come from the same requesting
    /// entity. Enforced by `queue_prediction_refresh`.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1927}
1928
1929fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1930 match provider {
1931 EditPredictionProvider::Zed
1932 | EditPredictionProvider::Sweep
1933 | EditPredictionProvider::Mercury
1934 | EditPredictionProvider::Ollama
1935 | EditPredictionProvider::OpenAiCompatibleApi
1936 | EditPredictionProvider::Experimental(_) => true,
1937 EditPredictionProvider::None
1938 | EditPredictionProvider::Copilot
1939 | EditPredictionProvider::Codestral => false,
1940 }
1941}
1942
1943impl EditPredictionStore {
1944 fn queue_prediction_refresh(
1945 &mut self,
1946 project: Entity<Project>,
1947 request_trigger: PredictEditsRequestTrigger,
1948 throttle_entity: EntityId,
1949 cx: &mut Context<Self>,
1950 do_refresh: impl FnOnce(
1951 WeakEntity<Self>,
1952 &mut AsyncApp,
1953 )
1954 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1955 + 'static,
1956 ) {
1957 fn select_throttle(
1958 project_state: &mut ProjectState,
1959 request_trigger: PredictEditsRequestTrigger,
1960 ) -> &mut Option<(EntityId, Instant)> {
1961 match request_trigger {
1962 PredictEditsRequestTrigger::Diagnostics => {
1963 &mut project_state.last_jump_prediction_refresh
1964 }
1965 _ => &mut project_state.last_edit_prediction_refresh,
1966 }
1967 }
1968
1969 let (needs_acceptance_tracking, max_pending_predictions) =
1970 match all_language_settings(None, cx).edit_predictions.provider {
1971 EditPredictionProvider::Zed
1972 | EditPredictionProvider::Sweep
1973 | EditPredictionProvider::Mercury
1974 | EditPredictionProvider::Experimental(_) => (true, 2),
1975 EditPredictionProvider::Ollama => (false, 1),
1976 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1977 EditPredictionProvider::None
1978 | EditPredictionProvider::Copilot
1979 | EditPredictionProvider::Codestral => {
1980 log::error!("queue_prediction_refresh called with non-store provider");
1981 return;
1982 }
1983 };
1984
1985 let drop_on_cancel = !needs_acceptance_tracking;
1986 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1987 let project_state = self.get_or_init_project(&project, cx);
1988 let pending_prediction_id = project_state.next_pending_prediction_id;
1989 project_state.next_pending_prediction_id += 1;
1990 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
1991
1992 let task = cx.spawn(async move |this, cx| {
1993 let throttle_wait = this
1994 .update(cx, |this, cx| {
1995 let project_state = this.get_or_init_project(&project, cx);
1996 let throttle = *select_throttle(project_state, request_trigger);
1997
1998 throttle.and_then(|(last_entity, last_timestamp)| {
1999 if throttle_entity != last_entity {
2000 return None;
2001 }
2002 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2003 })
2004 })
2005 .ok()
2006 .flatten();
2007
2008 if let Some(timeout) = throttle_wait {
2009 cx.background_executor().timer(timeout).await;
2010 }
2011
2012 // If this task was cancelled before the throttle timeout expired,
2013 // do not perform a request. Also skip if another task already
2014 // proceeded since we were enqueued (duplicate).
2015 let mut is_cancelled = true;
2016 this.update(cx, |this, cx| {
2017 let project_state = this.get_or_init_project(&project, cx);
2018 let was_cancelled = project_state
2019 .cancelled_predictions
2020 .remove(&pending_prediction_id);
2021 if was_cancelled {
2022 return;
2023 }
2024
2025 // Another request has been already sent since this was enqueued
2026 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2027 return;
2028 }
2029
2030 let new_refresh = (throttle_entity, Instant::now());
2031 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2032 is_cancelled = false;
2033 })
2034 .ok();
2035 if is_cancelled {
2036 return None;
2037 }
2038
2039 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2040 let new_prediction_id = new_prediction_result
2041 .as_ref()
2042 .map(|(prediction, _)| prediction.id.clone());
2043
2044 // When a prediction completes, remove it from the pending list, and cancel
2045 // any pending predictions that were enqueued before it.
2046 this.update(cx, |this, cx| {
2047 let project_state = this.get_or_init_project(&project, cx);
2048
2049 let is_cancelled = project_state
2050 .cancelled_predictions
2051 .remove(&pending_prediction_id);
2052
2053 let new_current_prediction = if !is_cancelled
2054 && let Some((prediction_result, requested_by)) = new_prediction_result
2055 {
2056 match prediction_result.prediction {
2057 Ok(prediction) => {
2058 let new_prediction = CurrentEditPrediction {
2059 requested_by,
2060 prediction,
2061 was_shown: false,
2062 shown_with: None,
2063 };
2064
2065 if let Some(current_prediction) =
2066 project_state.current_prediction.as_ref()
2067 {
2068 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2069 {
2070 this.reject_current_prediction(
2071 EditPredictionRejectReason::Replaced,
2072 &project,
2073 cx,
2074 );
2075
2076 Some(new_prediction)
2077 } else {
2078 this.reject_prediction(
2079 new_prediction.prediction.id,
2080 EditPredictionRejectReason::CurrentPreferred,
2081 false,
2082 new_prediction.prediction.model_version,
2083 cx,
2084 );
2085 None
2086 }
2087 } else {
2088 Some(new_prediction)
2089 }
2090 }
2091 Err(reject_reason) => {
2092 this.reject_prediction(
2093 prediction_result.id,
2094 reject_reason,
2095 false,
2096 None,
2097 cx,
2098 );
2099 None
2100 }
2101 }
2102 } else {
2103 None
2104 };
2105
2106 let project_state = this.get_or_init_project(&project, cx);
2107
2108 if let Some(new_prediction) = new_current_prediction {
2109 project_state.current_prediction = Some(new_prediction);
2110 }
2111
2112 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2113 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2114 if pending_prediction.id == pending_prediction_id {
2115 pending_predictions.remove(ix);
2116 for pending_prediction in pending_predictions.drain(0..ix) {
2117 project_state.cancel_pending_prediction(pending_prediction, cx)
2118 }
2119 break;
2120 }
2121 }
2122 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2123 cx.notify();
2124 })
2125 .ok();
2126
2127 new_prediction_id
2128 });
2129
2130 if project_state.pending_predictions.len() < max_pending_predictions {
2131 project_state.pending_predictions.push(PendingPrediction {
2132 id: pending_prediction_id,
2133 task,
2134 drop_on_cancel,
2135 });
2136 } else {
2137 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2138 project_state.pending_predictions.push(PendingPrediction {
2139 id: pending_prediction_id,
2140 task,
2141 drop_on_cancel,
2142 });
2143 project_state.cancel_pending_prediction(pending_prediction, cx);
2144 }
2145 }
2146
2147 pub fn request_prediction(
2148 &mut self,
2149 project: &Entity<Project>,
2150 active_buffer: &Entity<Buffer>,
2151 position: language::Anchor,
2152 trigger: PredictEditsRequestTrigger,
2153 cx: &mut Context<Self>,
2154 ) -> Task<Result<Option<EditPredictionResult>>> {
2155 self.request_prediction_internal(
2156 project.clone(),
2157 active_buffer.clone(),
2158 position,
2159 trigger,
2160 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2161 cx,
2162 )
2163 }
2164
2165 fn request_prediction_internal(
2166 &mut self,
2167 project: Entity<Project>,
2168 active_buffer: Entity<Buffer>,
2169 position: language::Anchor,
2170 trigger: PredictEditsRequestTrigger,
2171 allow_jump: bool,
2172 cx: &mut Context<Self>,
2173 ) -> Task<Result<Option<EditPredictionResult>>> {
2174 self.get_or_init_project(&project, cx);
2175 let project_state = self.projects.get(&project.entity_id()).unwrap();
2176 let stored_events = project_state.events(cx);
2177 let has_events = !stored_events.is_empty();
2178 let events: Vec<Arc<zeta_prompt::Event>> =
2179 stored_events.iter().map(|e| e.event.clone()).collect();
2180 let debug_tx = project_state.debug_tx.clone();
2181
2182 let snapshot = active_buffer.read(cx).snapshot();
2183 let cursor_point = position.to_point(&snapshot);
2184 let current_offset = position.to_offset(&snapshot);
2185
2186 let mut user_actions: Vec<UserActionRecord> =
2187 project_state.user_actions.iter().cloned().collect();
2188
2189 if let Some(last_action) = user_actions.last() {
2190 if last_action.buffer_id == active_buffer.entity_id()
2191 && current_offset != last_action.offset
2192 {
2193 let timestamp_epoch_ms = SystemTime::now()
2194 .duration_since(UNIX_EPOCH)
2195 .map(|d| d.as_millis() as u64)
2196 .unwrap_or(0);
2197 user_actions.push(UserActionRecord {
2198 action_type: UserActionType::CursorMovement,
2199 buffer_id: active_buffer.entity_id(),
2200 line_number: cursor_point.row,
2201 offset: current_offset,
2202 timestamp_epoch_ms,
2203 });
2204 }
2205 }
2206 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2207 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2208 let diagnostic_search_range =
2209 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2210
2211 let related_files = self.context_for_project(&project, cx);
2212
2213 let is_open_source = snapshot
2214 .file()
2215 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2216 && events.iter().all(|event| event.in_open_source_repo())
2217 && related_files.iter().all(|file| file.in_open_source_repo);
2218
2219 let can_collect_data = !cfg!(test)
2220 && is_open_source
2221 && self.is_data_collection_enabled(cx)
2222 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2223
2224 let recent_paths = project_state.recent_paths.clone();
2225
2226 let inputs = EditPredictionModelInput {
2227 project: project.clone(),
2228 buffer: active_buffer,
2229 snapshot,
2230 position,
2231 events,
2232 related_files,
2233 recent_paths,
2234 trigger,
2235 diagnostic_search_range: diagnostic_search_range,
2236 debug_tx,
2237 user_actions,
2238 can_collect_data,
2239 is_open_source,
2240 };
2241
2242 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2243
2244 let task = match self.edit_prediction_model {
2245 EditPredictionModel::Zeta => {
2246 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2247 }
2248 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2249 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2250 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2251 };
2252
2253 cx.spawn(async move |this, cx| {
2254 let prediction = task.await?;
2255
2256 // Only fall back to diagnostics-based prediction if we got a
2257 // the model had nothing to suggest for the buffer
2258 if prediction.is_none()
2259 && allow_jump
2260 && has_events
2261 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2262 {
2263 this.update(cx, |this, cx| {
2264 this.refresh_prediction_from_diagnostics(
2265 project,
2266 DiagnosticSearchScope::Local,
2267 cx,
2268 );
2269 })?;
2270 return anyhow::Ok(None);
2271 }
2272
2273 Ok(prediction)
2274 })
2275 }
2276
2277 pub(crate) async fn next_diagnostic_location(
2278 active_buffer: Entity<Buffer>,
2279 active_buffer_snapshot: &BufferSnapshot,
2280 active_buffer_diagnostic_search_range: Range<Point>,
2281 active_buffer_cursor_point: Point,
2282 project: &Entity<Project>,
2283 cx: &mut AsyncApp,
2284 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2285 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2286 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2287 .flat_map(|(_, _, _, selections)| {
2288 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2289 })
2290 .collect();
2291
2292 let mut jump_location = active_buffer_snapshot
2293 .diagnostic_groups(None)
2294 .into_iter()
2295 .filter_map(|(_, group)| {
2296 let range = &group.entries[group.primary_ix]
2297 .range
2298 .to_point(&active_buffer_snapshot);
2299 if range.overlaps(&active_buffer_diagnostic_search_range) {
2300 return None;
2301 }
2302 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2303 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2304 });
2305 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2306 <= DIAGNOSTIC_LINES_RANGE;
2307 if near_collaborator && !near_local {
2308 return None;
2309 }
2310 Some(range.start)
2311 })
2312 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2313 .map(|position| {
2314 (
2315 active_buffer.clone(),
2316 active_buffer_snapshot.anchor_before(position),
2317 )
2318 });
2319
2320 if jump_location.is_none() {
2321 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2322 let file = buffer.file()?;
2323
2324 Some(ProjectPath {
2325 worktree_id: file.worktree_id(cx),
2326 path: file.path().clone(),
2327 })
2328 });
2329
2330 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2331 project
2332 .diagnostic_summaries(false, cx)
2333 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2334 .map(|(path, _, _)| {
2335 let shared_prefix = path
2336 .path
2337 .components()
2338 .zip(
2339 active_buffer_path
2340 .as_ref()
2341 .map(|p| p.path.components())
2342 .unwrap_or_default(),
2343 )
2344 .take_while(|(a, b)| a == b)
2345 .count();
2346 (path, shared_prefix)
2347 })
2348 .collect()
2349 });
2350
2351 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2352
2353 for (path, _) in candidates {
2354 let candidate_buffer = project
2355 .update(cx, |project, cx| project.open_buffer(path, cx))
2356 .await?;
2357
2358 let (has_collaborators, diagnostic_position) =
2359 candidate_buffer.read_with(cx, |buffer, _cx| {
2360 let snapshot = buffer.snapshot();
2361 let has_collaborators = snapshot
2362 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2363 .next()
2364 .is_some();
2365 let position = buffer
2366 .buffer_diagnostics(None)
2367 .into_iter()
2368 .min_by_key(|entry| entry.diagnostic.severity)
2369 .map(|entry| entry.range.start);
2370 (has_collaborators, position)
2371 });
2372
2373 if has_collaborators {
2374 continue;
2375 }
2376
2377 if let Some(position) = diagnostic_position {
2378 jump_location = Some((candidate_buffer, position));
2379 break;
2380 }
2381 }
2382 }
2383
2384 anyhow::Ok(jump_location)
2385 }
2386
2387 async fn send_raw_llm_request(
2388 request: RawCompletionRequest,
2389 client: Arc<Client>,
2390 custom_url: Option<Arc<Url>>,
2391 llm_token: LlmApiToken,
2392 organization_id: Option<OrganizationId>,
2393 app_version: Version,
2394 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2395 let url = if let Some(custom_url) = custom_url {
2396 custom_url.as_ref().clone()
2397 } else {
2398 client
2399 .http_client()
2400 .build_zed_llm_url("/predict_edits/raw", &[])?
2401 };
2402
2403 Self::send_api_request(
2404 |builder| {
2405 let req = builder
2406 .uri(url.as_ref())
2407 .body(serde_json::to_string(&request)?.into());
2408 Ok(req?)
2409 },
2410 client,
2411 llm_token,
2412 organization_id,
2413 app_version,
2414 true,
2415 )
2416 .await
2417 }
2418
2419 pub(crate) async fn send_v3_request(
2420 input: ZetaPromptInput,
2421 client: Arc<Client>,
2422 llm_token: LlmApiToken,
2423 organization_id: Option<OrganizationId>,
2424 app_version: Version,
2425 trigger: PredictEditsRequestTrigger,
2426 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2427 let url = client
2428 .http_client()
2429 .build_zed_llm_url("/predict_edits/v3", &[])?;
2430
2431 let request = PredictEditsV3Request { input, trigger };
2432
2433 let json_bytes = serde_json::to_vec(&request)?;
2434 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2435
2436 Self::send_api_request(
2437 |builder| {
2438 let req = builder
2439 .uri(url.as_ref())
2440 .header("Content-Encoding", "zstd")
2441 .body(compressed.clone().into());
2442 Ok(req?)
2443 },
2444 client,
2445 llm_token,
2446 organization_id,
2447 app_version,
2448 true,
2449 )
2450 .await
2451 }
2452
2453 async fn send_api_request<Res>(
2454 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2455 client: Arc<Client>,
2456 llm_token: LlmApiToken,
2457 organization_id: Option<OrganizationId>,
2458 app_version: Version,
2459 require_auth: bool,
2460 ) -> Result<(Res, Option<EditPredictionUsage>)>
2461 where
2462 Res: DeserializeOwned,
2463 {
2464 let http_client = client.http_client();
2465
2466 let mut token = if require_auth {
2467 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2468 } else {
2469 llm_token
2470 .acquire(&client, organization_id.clone())
2471 .await
2472 .ok()
2473 };
2474 let mut did_retry = false;
2475
2476 loop {
2477 let request_builder = http_client::Request::builder().method(Method::POST);
2478
2479 let mut request_builder = request_builder
2480 .header("Content-Type", "application/json")
2481 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2482
2483 // Only add Authorization header if we have a token
2484 if let Some(ref token_value) = token {
2485 request_builder =
2486 request_builder.header("Authorization", format!("Bearer {}", token_value));
2487 }
2488
2489 let request = build(request_builder)?;
2490
2491 let mut response = http_client.send(request).await?;
2492
2493 if let Some(minimum_required_version) = response
2494 .headers()
2495 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2496 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2497 {
2498 anyhow::ensure!(
2499 app_version >= minimum_required_version,
2500 ZedUpdateRequiredError {
2501 minimum_version: minimum_required_version
2502 }
2503 );
2504 }
2505
2506 if response.status().is_success() {
2507 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2508
2509 let mut body = Vec::new();
2510 response.body_mut().read_to_end(&mut body).await?;
2511 return Ok((serde_json::from_slice(&body)?, usage));
2512 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2513 did_retry = true;
2514 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2515 } else {
2516 let mut body = String::new();
2517 response.body_mut().read_to_string(&mut body).await?;
2518 anyhow::bail!(
2519 "Request failed with status: {:?}\nBody: {}",
2520 response.status(),
2521 body
2522 );
2523 }
2524 }
2525 }
2526
2527 pub fn refresh_context(
2528 &mut self,
2529 project: &Entity<Project>,
2530 buffer: &Entity<language::Buffer>,
2531 cursor_position: language::Anchor,
2532 cx: &mut Context<Self>,
2533 ) {
2534 self.get_or_init_project(project, cx)
2535 .context
2536 .update(cx, |store, cx| {
2537 store.refresh(buffer.clone(), cursor_position, cx);
2538 });
2539 }
2540
2541 #[cfg(feature = "cli-support")]
2542 pub fn set_context_for_buffer(
2543 &mut self,
2544 project: &Entity<Project>,
2545 related_files: Vec<RelatedFile>,
2546 cx: &mut Context<Self>,
2547 ) {
2548 self.get_or_init_project(project, cx)
2549 .context
2550 .update(cx, |store, cx| {
2551 store.set_related_files(related_files, cx);
2552 });
2553 }
2554
2555 #[cfg(feature = "cli-support")]
2556 pub fn set_recent_paths_for_project(
2557 &mut self,
2558 project: &Entity<Project>,
2559 paths: impl IntoIterator<Item = project::ProjectPath>,
2560 cx: &mut Context<Self>,
2561 ) {
2562 let project_state = self.get_or_init_project(project, cx);
2563 project_state.recent_paths = paths.into_iter().collect();
2564 }
2565
2566 fn is_file_open_source(
2567 &self,
2568 project: &Entity<Project>,
2569 file: &Arc<dyn File>,
2570 cx: &App,
2571 ) -> bool {
2572 if !file.is_local() || file.is_private() {
2573 return false;
2574 }
2575 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2576 return false;
2577 };
2578 project_state
2579 .license_detection_watchers
2580 .get(&file.worktree_id(cx))
2581 .as_ref()
2582 .is_some_and(|watcher| watcher.is_project_open_source())
2583 }
2584
2585 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2586 self.data_collection_choice.is_enabled(cx)
2587 }
2588
2589 fn load_data_collection_choice() -> DataCollectionChoice {
2590 let choice = KEY_VALUE_STORE
2591 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2592 .log_err()
2593 .flatten();
2594
2595 match choice.as_deref() {
2596 Some("true") => DataCollectionChoice::Enabled,
2597 Some("false") => DataCollectionChoice::Disabled,
2598 Some(_) => {
2599 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2600 DataCollectionChoice::NotAnswered
2601 }
2602 None => DataCollectionChoice::NotAnswered,
2603 }
2604 }
2605
2606 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2607 self.data_collection_choice = self.data_collection_choice.toggle();
2608 let new_choice = self.data_collection_choice;
2609 let is_enabled = new_choice.is_enabled(cx);
2610 db::write_and_log(cx, move || {
2611 KEY_VALUE_STORE.write_kvp(
2612 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2613 is_enabled.to_string(),
2614 )
2615 });
2616 }
2617
2618 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2619 self.shown_predictions.iter()
2620 }
2621
2622 pub fn shown_completions_len(&self) -> usize {
2623 self.shown_predictions.len()
2624 }
2625
2626 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2627 self.rated_predictions.contains(id)
2628 }
2629
2630 pub fn rate_prediction(
2631 &mut self,
2632 prediction: &EditPrediction,
2633 rating: EditPredictionRating,
2634 feedback: String,
2635 cx: &mut Context<Self>,
2636 ) {
2637 let organization = self.user_store.read(cx).current_organization();
2638
2639 self.rated_predictions.insert(prediction.id.clone());
2640
2641 cx.background_spawn({
2642 let client = self.client.clone();
2643 let prediction_id = prediction.id.to_string();
2644 let inputs = serde_json::to_value(&prediction.inputs);
2645 let output = prediction
2646 .edit_preview
2647 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2648 async move {
2649 client
2650 .cloud_client()
2651 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2652 organization_id: organization.map(|organization| organization.id.clone()),
2653 request_id: prediction_id,
2654 rating: match rating {
2655 EditPredictionRating::Positive => "positive".to_string(),
2656 EditPredictionRating::Negative => "negative".to_string(),
2657 },
2658 inputs: inputs?,
2659 output,
2660 feedback,
2661 })
2662 .await?;
2663
2664 anyhow::Ok(())
2665 }
2666 })
2667 .detach_and_log_err(cx);
2668
2669 cx.notify();
2670 }
2671}
2672
/// Collapses the trailing run of mergeable events at the back of `events` into a single
/// `BufferChange` event. The merged diff spans from the oldest merged event's `old_snapshot`
/// to `end_snapshot`.
///
/// `events` is left unchanged when any of these hold:
/// - the newest event's snapshot has a different `remote_id` than `latest_snapshot`;
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
///
/// The merged event keeps the oldest event's paths and open-source flag. It is marked
/// `predicted` only if every merged event was, and its `edit_range` is the diff's edited
/// range anchored in `end_snapshot`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest stored event is for the same buffer as the latest edit.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk from newest to oldest, counting consecutive events that `can_merge` with the
    // (newer) event that follows them; the newest event always counts.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // The mergeable run is the last `mergeable_count` events. `peek` does not advance the
    // iterator, so the `all` check below still inspects the oldest event as well.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every part of it was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2739
2740#[derive(Error, Debug)]
2741#[error(
2742 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2743)]
2744pub struct ZedUpdateRequiredError {
2745 minimum_version: Version,
2746}
2747
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet, or the stored value was unrecognized.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2754
2755impl DataCollectionChoice {
2756 pub fn is_enabled(self, cx: &App) -> bool {
2757 if cx.is_staff() {
2758 return true;
2759 }
2760 match self {
2761 Self::Enabled => true,
2762 Self::NotAnswered | Self::Disabled => false,
2763 }
2764 }
2765
2766 #[must_use]
2767 pub fn toggle(&self) -> DataCollectionChoice {
2768 match self {
2769 Self::Enabled => Self::Disabled,
2770 Self::Disabled => Self::Enabled,
2771 Self::NotAnswered => Self::Enabled,
2772 }
2773 }
2774}
2775
2776impl From<bool> for DataCollectionChoice {
2777 fn from(value: bool) -> Self {
2778 match value {
2779 true => DataCollectionChoice::Enabled,
2780 false => DataCollectionChoice::Disabled,
2781 }
2782 }
2783}
2784
2785struct ZedPredictUpsell;
2786
2787impl Dismissable for ZedPredictUpsell {
2788 const KEY: &'static str = "dismissed-edit-predict-upsell";
2789
2790 fn dismissed() -> bool {
2791 // To make this backwards compatible with older versions of Zed, we
2792 // check if the user has seen the previous Edit Prediction Onboarding
2793 // before, by checking the data collection choice which was written to
2794 // the database once the user clicked on "Accept and Enable"
2795 if KEY_VALUE_STORE
2796 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2797 .log_err()
2798 .is_some_and(|s| s.is_some())
2799 {
2800 return true;
2801 }
2802
2803 KEY_VALUE_STORE
2804 .read_kvp(Self::KEY)
2805 .log_err()
2806 .is_some_and(|s| s.is_some())
2807 }
2808}
2809
2810pub fn should_show_upsell_modal() -> bool {
2811 !ZedPredictUpsell::dismissed()
2812}
2813
2814pub fn init(cx: &mut App) {
2815 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2816 workspace.register_action(
2817 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2818 ZedPredictModal::toggle(
2819 workspace,
2820 workspace.user_store().clone(),
2821 workspace.client().clone(),
2822 window,
2823 cx,
2824 )
2825 },
2826 );
2827
2828 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2829 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2830 settings
2831 .project
2832 .all_languages
2833 .edit_predictions
2834 .get_or_insert_default()
2835 .provider = Some(EditPredictionProvider::None)
2836 });
2837 });
2838 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2839 EditPredictionStore::try_global(cx).and_then(|store| {
2840 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2841 })
2842 }
2843
2844 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2845 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2846 copilot_ui::initiate_sign_in(copilot, window, cx);
2847 }
2848 });
2849 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2850 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2851 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2852 }
2853 });
2854 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2855 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2856 copilot_ui::initiate_sign_out(copilot, window, cx);
2857 }
2858 });
2859 })
2860 .detach();
2861}