1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::example_spec::ExampleSpec;
79use crate::license_detection::LicenseDetectionWatcher;
80use crate::mercury::Mercury;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Declares the actions that live in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row-distance threshold used by `StoredEvent::can_merge` when deciding
/// whether two edits are close enough to be grouped.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the editing pause that separates two change groups;
// the code that reads it is outside this section — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key under which the data collection choice is
// persisted — confirm against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied before rejected predictions
// are reported — confirm against `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// NOTE(review): the three constants below appear to configure the "settled
// prediction" worker (event name, time-to-live, quiescence window) — confirm.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
110
/// Marker type for the `edit_prediction_jumps` feature flag.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    // Name that identifies this flag.
    const NAME: &'static str = "edit_prediction_jumps";
}
116
/// Global wrapper around the shared [`EditPredictionStore`] entity.
///
/// Installed by `EditPredictionStore::global` and read back by
/// `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): the sentence above mentions a "version", but the struct has no
// field of that name (only `format` and a separate `environment`) — confirm
// whether the doc is stale.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; read from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Optional environment name; read from `ZED_ZETA_ENVIRONMENT` when built from the environment.
    pub environment: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when built from the environment.
    pub format: ZetaFormat,
}
131
/// Central store for edit predictions: holds the client handles, the selected
/// prediction model and one [`ProjectState`] per registered project.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Refreshes `llm_token` whenever the refresh listener emits an event.
    _llm_token_subscription: Subscription,
    // Waits for a current user, then refreshes `available_experiments`.
    _fetch_experiments_task: Task<()>,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    // Initialised from environment variables by `zeta2_raw_config_from_env`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Feeds the background task running `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    // Feeds the task running `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
154
/// Message sent over `EditPredictionStore::reject_predictions_tx`: a rejection
/// together with the organization (if any) it is associated with.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
159
/// The prediction backend selected on the store.
///
/// `EditPredictionStore::new` starts with [`EditPredictionModel::Zeta`].
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle model using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
167
/// Bundle of inputs describing a single prediction request: the project and
/// buffer, a buffer snapshot and cursor anchor, plus edit history and context.
// NOTE(review): how each backend consumes these fields is outside this section — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel on which debug events are reported.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
184
/// Events delivered on a project's debug channel (`ProjectState::debug_tx`).
#[derive(Debug)]
pub enum DebugEvent {
    /// Sent when the related-excerpt store reports `StartedRefresh`.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Sent when the related-excerpt store reports `FinishedRefresh`.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
192
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project whose context is being refreshed.
    pub project_entity_id: EntityId,
    /// Moment the event was created.
    pub timestamp: Instant,
    pub search_prompt: String,
}
199
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project whose context was refreshed.
    pub project_entity_id: EntityId,
    /// Moment the event was created.
    pub timestamp: Instant,
    /// Label/value pairs describing the refresh (e.g. "Cache Hits").
    pub metadata: Vec<(&'static str, SharedString)>,
}
206
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle to the buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when one is available.
    pub prompt: Option<String>,
}
213
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle to the buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
220
/// Maximum number of [`UserActionRecord`]s kept per project; once reached, the
/// oldest record is evicted before a new one is stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
222
/// One entry in a project's bounded user-action history
/// (see `ProjectState::record_user_action`).
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action applies to.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): by its name, milliseconds since the Unix epoch — confirm at the producer.
    pub timestamp_epoch_ms: u64,
}
231
/// Kind of user action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
240
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event.
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchor range covering the edited region.
    pub edit_range: Range<Anchor>,
}
248
249impl StoredEvent {
250 fn can_merge(
251 &self,
252 next_old_event: &&&StoredEvent,
253 new_snapshot: &TextBufferSnapshot,
254 last_edit_range: &Range<Anchor>,
255 ) -> bool {
256 // Events must be for the same buffer
257 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
258 return false;
259 }
260 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
261 return false;
262 }
263
264 let a_is_predicted = matches!(
265 self.event.as_ref(),
266 zeta_prompt::Event::BufferChange {
267 predicted: true,
268 ..
269 }
270 );
271 let b_is_predicted = matches!(
272 next_old_event.event.as_ref(),
273 zeta_prompt::Event::BufferChange {
274 predicted: true,
275 ..
276 }
277 );
278
279 // If events come from the same source (both predicted or both manual) then
280 // we would have coalesced them already.
281 if a_is_predicted == b_is_predicted {
282 return false;
283 }
284
285 let left_range = self.edit_range.to_point(new_snapshot);
286 let right_range = next_old_event.edit_range.to_point(new_snapshot);
287 let latest_range = last_edit_range.to_point(&new_snapshot);
288
289 // Events near to the latest edit are not merged if their sources differ.
290 if lines_between_ranges(&left_range, &latest_range)
291 .min(lines_between_ranges(&right_range, &latest_range))
292 <= CHANGE_GROUPING_LINE_SPAN
293 {
294 return false;
295 }
296
297 // Events that are distant from each other are not merged.
298 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
299 return false;
300 }
301
302 true
303 }
304}
305
306fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
307 if left.start > right.end {
308 return left.start.row - right.end.row;
309 }
310 if right.start > left.end {
311 return right.start.row - left.end.row;
312 }
313 0
314}
315
/// Everything the store tracks for a single project.
struct ProjectState {
    // Finalized edit events.
    events: VecDeque<StoredEvent>,
    // The in-progress event; finalized lazily when `events()` is called.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered with the store, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two prediction requests are tracked as pending.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Optional channel on which `DebugEvent`s are reported.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    // Related-excerpt store that provides context for this project.
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded history; see `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
334
335impl ProjectState {
336 fn record_user_action(&mut self, action: UserActionRecord) {
337 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
338 self.user_actions.pop_front();
339 }
340 self.user_actions.push_back(action);
341 }
342
343 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
344 self.events
345 .iter()
346 .cloned()
347 .chain(self.last_event.as_ref().iter().flat_map(|event| {
348 let (one, two) = event.split_by_pause();
349 let one = one.finalize(&self.license_detection_watchers, cx);
350 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
351 one.into_iter().chain(two)
352 }))
353 .collect()
354 }
355
356 fn cancel_pending_prediction(
357 &mut self,
358 pending_prediction: PendingPrediction,
359 cx: &mut Context<EditPredictionStore>,
360 ) {
361 self.cancelled_predictions.insert(pending_prediction.id);
362
363 if pending_prediction.drop_on_cancel {
364 drop(pending_prediction.task);
365 } else {
366 cx.spawn(async move |this, cx| {
367 let Some(prediction_id) = pending_prediction.task.await else {
368 return;
369 };
370
371 this.update(cx, |this, cx| {
372 this.reject_prediction(
373 prediction_id,
374 EditPredictionRejectReason::Canceled,
375 false,
376 None,
377 cx,
378 );
379 })
380 .ok();
381 })
382 .detach()
383 }
384 }
385
386 fn active_buffer(
387 &self,
388 project: &Entity<Project>,
389 cx: &App,
390 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
391 let project = project.read(cx);
392 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
393 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
394 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
395 Some((active_buffer, registered_buffer.last_position))
396 }
397}
398
/// The prediction currently held for a project, plus bookkeeping about who
/// requested it and how it was displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    /// Display type the prediction was shown with, if any.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
406
407impl CurrentEditPrediction {
408 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
409 let Some(new_edits) = self
410 .prediction
411 .interpolate(&self.prediction.buffer.read(cx))
412 else {
413 return false;
414 };
415
416 if self.prediction.buffer != old_prediction.prediction.buffer {
417 return true;
418 }
419
420 let Some(old_edits) = old_prediction
421 .prediction
422 .interpolate(&old_prediction.prediction.buffer.read(cx))
423 else {
424 return true;
425 };
426
427 let requested_by_buffer_id = self.requested_by.buffer_id();
428
429 // This reduces the occurrence of UI thrash from replacing edits
430 //
431 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
432 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
433 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
434 && old_edits.len() == 1
435 && new_edits.len() == 1
436 {
437 let (old_range, old_text) = &old_edits[0];
438 let (new_range, new_text) = &new_edits[0];
439 new_range == old_range && new_text.starts_with(old_text.as_ref())
440 } else {
441 true
442 }
443 }
444}
445
/// Origin of a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than by a specific buffer.
    DiagnosticsUpdate,
    /// Triggered from the buffer with this entity id.
    Buffer(EntityId),
}
451
452impl PredictionRequestedBy {
453 pub fn buffer_id(&self) -> Option<EntityId> {
454 match self {
455 PredictionRequestedBy::DiagnosticsUpdate => None,
456 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
457 }
458 }
459}
460
// NOTE(review): presumably the number of rows around the cursor searched for
// diagnostics; the code that reads it is outside this section — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
462
/// Scope of a diagnostic search.
// NOTE(review): the exact meaning of `Local` vs `Global` is defined where the
// enum is consumed, outside this section — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
468
/// A prediction request that is still in flight.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the resulting prediction, if there is one.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
477
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` means the prediction edits this buffer and
// `Jump` means it targets another location — confirm where values are built.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
484
/// Test-only convenience: lets tests reach the wrapped [`EditPrediction`]
/// directly, whichever variant holds it.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
496
/// A prediction tracked on a registered buffer until it is considered settled.
// NOTE(review): the settling rules live in `run_settled_predictions_worker`,
// outside this section — confirm how the two timestamps are used.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}
505
/// State kept for each buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    // Returned by `ProjectState::active_buffer` alongside the buffer.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
513
/// The in-progress edit event of a project, described by the snapshots and
/// files before and after the change. Turned into a [`StoredEvent`] by
/// `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // Copied into the `predicted` field of the resulting `BufferChange` event.
    predicted: bool,
    // When set, `split_by_pause` splits the event at this snapshot.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
525
526impl LastEvent {
527 pub fn finalize(
528 &self,
529 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
530 cx: &App,
531 ) -> Option<StoredEvent> {
532 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
533 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
534
535 let in_open_source_repo =
536 [self.new_file.as_ref(), self.old_file.as_ref()]
537 .iter()
538 .all(|file| {
539 file.is_some_and(|file| {
540 license_detection_watchers
541 .get(&file.worktree_id(cx))
542 .is_some_and(|watcher| watcher.is_project_open_source())
543 })
544 });
545
546 let (diff, edit_range) =
547 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
548
549 if path == old_path && diff.is_empty() {
550 None
551 } else {
552 Some(StoredEvent {
553 event: Arc::new(zeta_prompt::Event::BufferChange {
554 old_path,
555 path,
556 diff,
557 in_open_source_repo,
558 predicted: self.predicted,
559 }),
560 edit_range: self.new_snapshot.anchor_before(edit_range.start)
561 ..self.new_snapshot.anchor_before(edit_range.end),
562 old_snapshot: self.old_snapshot.clone(),
563 })
564 }
565 }
566
567 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
568 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
569 return (self.clone(), None);
570 };
571
572 let before = LastEvent {
573 old_snapshot: self.old_snapshot.clone(),
574 new_snapshot: boundary_snapshot.clone(),
575 old_file: self.old_file.clone(),
576 new_file: self.new_file.clone(),
577 edit_range: None,
578 predicted: self.predicted,
579 snapshot_after_last_editing_pause: None,
580 last_edit_time: self.last_edit_time,
581 };
582
583 let after = LastEvent {
584 old_snapshot: boundary_snapshot.clone(),
585 new_snapshot: self.new_snapshot.clone(),
586 old_file: self.old_file.clone(),
587 new_file: self.new_file.clone(),
588 edit_range: None,
589 predicted: self.predicted,
590 snapshot_after_last_editing_pause: None,
591 last_edit_time: self.last_edit_time,
592 };
593
594 (before, Some(after))
595 }
596}
597
598pub(crate) fn compute_diff_between_snapshots(
599 old_snapshot: &TextBufferSnapshot,
600 new_snapshot: &TextBufferSnapshot,
601) -> Option<(String, Range<Point>)> {
602 let edits: Vec<Edit<usize>> = new_snapshot
603 .edits_since::<usize>(&old_snapshot.version)
604 .collect();
605
606 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
607
608 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
609 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
610 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
611 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
612
613 const CONTEXT_LINES: u32 = 3;
614
615 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
616 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
617 let old_context_end_row =
618 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
619 let new_context_end_row =
620 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
621
622 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
623 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
624 let old_end_line_offset = old_snapshot
625 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
626 let new_end_line_offset = new_snapshot
627 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
628 let old_edit_range = old_start_line_offset..old_end_line_offset;
629 let new_edit_range = new_start_line_offset..new_end_line_offset;
630
631 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
632 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
633
634 let diff = language::unified_diff_with_offsets(
635 &old_region_text,
636 &new_region_text,
637 old_context_start_row,
638 new_context_start_row,
639 );
640
641 Some((diff, new_start_point..new_end_point))
642}
643
644fn buffer_path_with_id_fallback(
645 file: Option<&Arc<dyn File>>,
646 snapshot: &TextBufferSnapshot,
647 cx: &App,
648) -> Arc<Path> {
649 if let Some(file) = file {
650 file.full_path(cx).into()
651 } else {
652 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
653 }
654}
655
656impl EditPredictionStore {
657 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
658 cx.try_global::<EditPredictionStoreGlobal>()
659 .map(|global| global.0.clone())
660 }
661
662 pub fn global(
663 client: &Arc<Client>,
664 user_store: &Entity<UserStore>,
665 cx: &mut App,
666 ) -> Entity<Self> {
667 cx.try_global::<EditPredictionStoreGlobal>()
668 .map(|global| global.0.clone())
669 .unwrap_or_else(|| {
670 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
671 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
672 ep_store
673 })
674 }
675
    /// Creates a store and starts its background workers.
    ///
    /// Spawns, in order: a detached background task that processes rejected
    /// predictions, a detached task that runs the settled-predictions worker,
    /// and a task (kept on the store) that waits for a current user and then
    /// refreshes the list of available experiments. The store starts with the
    /// `Zeta` model and a raw Zeta2 configuration read from the environment.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are pushed through this channel and handled off the main thread.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
        })
        .detach();

        // Experiments are fetched only once a user is present.
        let mut current_user = user_store.read(cx).watch_current_user();
        let fetch_experiments_task = cx.spawn(async move |this, cx| {
            while current_user.borrow().is_none() {
                current_user.next().await;
            }
            this.update(cx, |this, cx| {
                this.refresh_available_experiments(cx);
            })
            .log_err();
        });

        let this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            llm_token,
            _fetch_experiments_task: fetch_experiments_task,
            // Refresh the LLM token for the current organization whenever the
            // listener emits an event; failures are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    let organization_id = this
                        .user_store
                        .read(cx)
                        .current_organization()
                        .map(|organization| organization.id.clone());
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client, organization_id).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            edit_prediction_model: EditPredictionModel::Zeta,
            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
            preferred_experiment: None,
            available_experiments: Vec::new(),
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),

            data_collection_choice,
            reject_predictions_tx: reject_tx,
            settled_predictions_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            #[cfg(test)]
            settled_event_callback: None,
        };

        this
    }
760
761 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
762 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
763 let format = ZetaFormat::parse(&version_str).ok()?;
764 let model_id = env::var("ZED_ZETA_MODEL").ok();
765 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
766 Some(Zeta2RawConfig {
767 model_id,
768 environment,
769 format,
770 })
771 }
772
    /// Selects the prediction model used by the store.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
776
    /// Replaces the raw Zeta2 configuration with `config`.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
780
    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
784
    /// Returns the explicitly preferred experiment name, if one is set.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
788
    /// Sets (or, with `None`, clears) the preferred experiment.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
792
    /// Returns the experiment names last stored by `refresh_available_experiments`.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
796
797 pub fn active_experiment(&self) -> Option<&str> {
798 self.preferred_experiment.as_deref().or_else(|| {
799 self.shown_predictions
800 .iter()
801 .find_map(|p| p.model_version.as_ref())
802 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
803 })
804 }
805
806 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
807 let client = self.client.clone();
808 let llm_token = self.llm_token.clone();
809 let app_version = AppVersion::global(cx);
810 let organization_id = self
811 .user_store
812 .read(cx)
813 .current_organization()
814 .map(|organization| organization.id.clone());
815
816 cx.spawn(async move |this, cx| {
817 let experiments = cx
818 .background_spawn(async move {
819 let http_client = client.http_client();
820 let token = llm_token.acquire(&client, organization_id).await?;
821 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
822 let request = http_client::Request::builder()
823 .method(Method::GET)
824 .uri(url.as_ref())
825 .header("Authorization", format!("Bearer {}", token))
826 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
827 .body(Default::default())?;
828 let mut response = http_client.send(request).await?;
829 if response.status().is_success() {
830 let mut body = Vec::new();
831 response.body_mut().read_to_end(&mut body).await?;
832 let experiments: Vec<String> = serde_json::from_slice(&body)?;
833 Ok(experiments)
834 } else {
835 let mut body = String::new();
836 response.body_mut().read_to_string(&mut body).await?;
837 anyhow::bail!(
838 "Failed to fetch experiments: {:?}\nBody: {}",
839 response.status(),
840 body
841 );
842 }
843 })
844 .await?;
845 this.update(cx, |this, cx| {
846 this.available_experiments = experiments;
847 cx.notify();
848 })?;
849 anyhow::Ok(())
850 })
851 .detach_and_log_err(cx);
852 }
853
854 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
855 use ui::IconName;
856 match self.edit_prediction_model {
857 EditPredictionModel::Sweep => {
858 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
859 .with_disabled(IconName::SweepAiDisabled)
860 .with_up(IconName::SweepAiUp)
861 .with_down(IconName::SweepAiDown)
862 .with_error(IconName::SweepAiError)
863 }
864 EditPredictionModel::Mercury => {
865 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
866 }
867 EditPredictionModel::Zeta => {
868 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
869 .with_disabled(IconName::ZedPredictDisabled)
870 .with_up(IconName::ZedPredictUp)
871 .with_down(IconName::ZedPredictDown)
872 .with_error(IconName::ZedPredictError)
873 }
874 EditPredictionModel::Fim { .. } => {
875 let settings = &all_language_settings(None, cx).edit_predictions;
876 match settings.provider {
877 EditPredictionProvider::Ollama => {
878 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
879 }
880 _ => {
881 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
882 }
883 }
884 }
885 }
886 }
887
    /// Reports whether the Sweep API token state currently holds a key.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
891
    /// Reports whether the Mercury API token state currently holds a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
895
896 pub fn clear_history(&mut self) {
897 for project_state in self.projects.values_mut() {
898 project_state.events.clear();
899 project_state.last_event.take();
900 }
901 }
902
903 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
904 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
905 project_state.events.clear();
906 project_state.last_event.take();
907 }
908 }
909
910 pub fn edit_history_for_project(
911 &self,
912 project: &Entity<Project>,
913 cx: &App,
914 ) -> Vec<StoredEvent> {
915 self.projects
916 .get(&project.entity_id())
917 .map(|project_state| project_state.events(cx))
918 .unwrap_or_default()
919 }
920
921 pub fn context_for_project<'a>(
922 &'a self,
923 project: &Entity<Project>,
924 cx: &'a mut App,
925 ) -> Vec<RelatedFile> {
926 self.projects
927 .get(&project.entity_id())
928 .map(|project_state| {
929 project_state.context.update(cx, |context, cx| {
930 context
931 .related_files_with_buffers(cx)
932 .map(|(mut related_file, buffer)| {
933 related_file.in_open_source_repo = buffer
934 .read(cx)
935 .file()
936 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
937 related_file
938 })
939 .collect()
940 })
941 })
942 .unwrap_or_default()
943 }
944
945 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
946 self.projects
947 .get(&project.entity_id())
948 .and_then(|project| project.copilot.clone())
949 }
950
951 pub fn start_copilot_for_project(
952 &mut self,
953 project: &Entity<Project>,
954 cx: &mut Context<Self>,
955 ) -> Option<Entity<Copilot>> {
956 if DisableAiSettings::get(None, cx).disable_ai {
957 return None;
958 }
959 let state = self.get_or_init_project(project, cx);
960
961 if state.copilot.is_some() {
962 return state.copilot.clone();
963 }
964 let _project = project.clone();
965 let project = project.read(cx);
966
967 let node = project.node_runtime().cloned();
968 if let Some(node) = node {
969 let next_id = project.languages().next_language_server_id();
970 let fs = project.fs().clone();
971
972 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
973 state.copilot = Some(copilot.clone());
974 Some(copilot)
975 } else {
976 None
977 }
978 }
979
980 pub fn context_for_project_with_buffers<'a>(
981 &'a self,
982 project: &Entity<Project>,
983 cx: &'a mut App,
984 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
985 self.projects
986 .get(&project.entity_id())
987 .map(|project| {
988 project.context.update(cx, |context, cx| {
989 context.related_files_with_buffers(cx).collect()
990 })
991 })
992 .unwrap_or_default()
993 }
994
995 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
996 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
997 self.user_store.read(cx).edit_prediction_usage()
998 } else {
999 None
1000 }
1001 }
1002
    /// Ensures state exists for `project`, creating it on first call.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1006
    /// Registers `buffer` under `project`, creating the project state first if
    /// needed; the registration itself is done by `register_buffer_impl`.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1016
    /// Returns the state for `project`, creating it on first access.
    ///
    /// A new state gets its own `RelatedExcerptStore` whose events are forwarded
    /// to `handle_excerpt_store_event`, a subscription to project events, and a
    /// release observer that removes the state once the project is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached: this subscription is not stored on the project state.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Forget the project's state as soon as the project is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1057
1058 pub fn remove_project(&mut self, project: &Entity<Project>) {
1059 self.projects.remove(&project.entity_id());
1060 }
1061
1062 fn handle_excerpt_store_event(
1063 &mut self,
1064 project_entity_id: EntityId,
1065 event: &RelatedExcerptStoreEvent,
1066 ) {
1067 if let Some(project_state) = self.projects.get(&project_entity_id) {
1068 if let Some(debug_tx) = project_state.debug_tx.clone() {
1069 match event {
1070 RelatedExcerptStoreEvent::StartedRefresh => {
1071 debug_tx
1072 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1073 ContextRetrievalStartedDebugEvent {
1074 project_entity_id: project_entity_id,
1075 timestamp: Instant::now(),
1076 search_prompt: String::new(),
1077 },
1078 ))
1079 .ok();
1080 }
1081 RelatedExcerptStoreEvent::FinishedRefresh {
1082 cache_hit_count,
1083 cache_miss_count,
1084 mean_definition_latency,
1085 max_definition_latency,
1086 } => {
1087 debug_tx
1088 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1089 ContextRetrievalFinishedDebugEvent {
1090 project_entity_id: project_entity_id,
1091 timestamp: Instant::now(),
1092 metadata: vec![
1093 (
1094 "Cache Hits",
1095 format!(
1096 "{}/{}",
1097 cache_hit_count,
1098 cache_hit_count + cache_miss_count
1099 )
1100 .into(),
1101 ),
1102 (
1103 "Max LSP Time",
1104 format!("{} ms", max_definition_latency.as_millis())
1105 .into(),
1106 ),
1107 (
1108 "Mean LSP Time",
1109 format!("{} ms", mean_definition_latency.as_millis())
1110 .into(),
1111 ),
1112 ],
1113 },
1114 ))
1115 .ok();
1116 }
1117 }
1118 }
1119 }
1120 }
1121
1122 pub fn debug_info(
1123 &mut self,
1124 project: &Entity<Project>,
1125 cx: &mut Context<Self>,
1126 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1127 let project_state = self.get_or_init_project(project, cx);
1128 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1129 project_state.debug_tx = Some(debug_watch_tx);
1130 debug_watch_rx
1131 }
1132
1133 fn handle_project_event(
1134 &mut self,
1135 project: Entity<Project>,
1136 event: &project::Event,
1137 cx: &mut Context<Self>,
1138 ) {
1139 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1140 return;
1141 }
1142 // TODO [zeta2] init with recent paths
1143 match event {
1144 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1145 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1146 return;
1147 };
1148 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1149 if let Some(path) = path {
1150 if let Some(ix) = project_state
1151 .recent_paths
1152 .iter()
1153 .position(|probe| probe == &path)
1154 {
1155 project_state.recent_paths.remove(ix);
1156 }
1157 project_state.recent_paths.push_front(path);
1158 }
1159 }
1160 project::Event::DiagnosticsUpdated { .. } => {
1161 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1162 self.refresh_prediction_from_diagnostics(
1163 project,
1164 DiagnosticSearchScope::Global,
1165 cx,
1166 );
1167 }
1168 }
1169 _ => (),
1170 }
1171 }
1172
    /// Ensures `buffer` is tracked in `project_state` and returns its `RegisteredBuffer` entry.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a license detection watcher
    /// for that worktree is created the first time the worktree is seen. On first registration
    /// the buffer's current text snapshot and file are recorded, and two subscriptions are
    /// stored on the entry: one forwarding `Edited` events to `report_changes_for_buffer`, and
    /// one removing the entry when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Remove the watcher again once the worktree entity is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            // Already registered: hand back the existing entry untouched.
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly; edits are ignored once it is gone.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Forget the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1239
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// Registers the project and buffer if needed, then:
    /// 1. replaces the stored snapshot and file with the current ones,
    /// 2. bumps `last_edit_at` on pending settled predictions whose editable range overlaps
    ///    the edited range,
    /// 3. records a `UserActionRecord` classified from the edit sizes,
    /// 4. either coalesces the change into the project's `last_event` or finalizes that event
    ///    into `events` and starts a new `last_event`.
    ///
    /// `is_predicted` is stored as the event's `predicted` flag; the buffer edit subscription
    /// passes `false` and `accept_current_prediction` passes `true`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we already hold.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Accumulate totals and a single anchor range spanning from the first edit's start to
        // the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits were yielded, so there is nothing to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify by comparing total inserted/deleted length against the number of edits:
        // a length of exactly one per edit counts as a single-character action, anything
        // else as a selection-sized one. Mixed insert+delete is treated as an insertion.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock milliseconds since the Unix epoch; 0 if the clock reads before it.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The previous event must end exactly where this change starts: same buffer and
            // same version.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Only merge changes that are within CHANGE_GROUPING_LINE_SPAN lines of the
            // previous edit range.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused for at least LAST_CHANGE_GROUPING_TIME before this edit,
                // remember the snapshot as it was at the end of the previous editing burst.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the previous event into the history, evicting the oldest
        // entry when the history is about to reach EVENT_COUNT_MAX.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1373
1374 fn prediction_at(
1375 &mut self,
1376 buffer: &Entity<Buffer>,
1377 position: Option<language::Anchor>,
1378 project: &Entity<Project>,
1379 cx: &App,
1380 ) -> Option<BufferEditPrediction<'_>> {
1381 let project_state = self.projects.get_mut(&project.entity_id())?;
1382 if let Some(position) = position
1383 && let Some(buffer) = project_state
1384 .registered_buffers
1385 .get_mut(&buffer.entity_id())
1386 {
1387 buffer.last_position = Some(position);
1388 }
1389
1390 let CurrentEditPrediction {
1391 requested_by,
1392 prediction,
1393 ..
1394 } = project_state.current_prediction.as_ref()?;
1395
1396 if prediction.targets_buffer(buffer.read(cx)) {
1397 Some(BufferEditPrediction::Local { prediction })
1398 } else {
1399 let show_jump = match requested_by {
1400 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1401 requested_by_buffer_id == &buffer.entity_id()
1402 }
1403 PredictionRequestedBy::DiagnosticsUpdate => true,
1404 };
1405
1406 if show_jump {
1407 Some(BufferEditPrediction::Jump { prediction })
1408 } else {
1409 None
1410 }
1411 }
1412 }
1413
1414 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1415 let Some(current_prediction) = self
1416 .projects
1417 .get_mut(&project.entity_id())
1418 .and_then(|project_state| project_state.current_prediction.take())
1419 else {
1420 return;
1421 };
1422
1423 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1424
1425 // can't hold &mut project_state ref across report_changes_for_buffer_call
1426 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1427 return;
1428 };
1429
1430 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1431 project_state.cancel_pending_prediction(pending_prediction, cx);
1432 }
1433
1434 match self.edit_prediction_model {
1435 EditPredictionModel::Sweep => {
1436 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1437 }
1438 EditPredictionModel::Mercury => {
1439 mercury::edit_prediction_accepted(
1440 current_prediction.prediction.id,
1441 self.client.http_client(),
1442 cx,
1443 );
1444 }
1445 EditPredictionModel::Zeta => {
1446 let is_cloud = !matches!(
1447 all_language_settings(None, cx).edit_predictions.provider,
1448 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1449 );
1450 if is_cloud {
1451 zeta::edit_prediction_accepted(self, current_prediction, cx)
1452 }
1453 }
1454 EditPredictionModel::Fim { .. } => {}
1455 }
1456 }
1457
1458 async fn handle_rejected_predictions(
1459 rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1460 client: Arc<Client>,
1461 llm_token: LlmApiToken,
1462 app_version: Version,
1463 background_executor: BackgroundExecutor,
1464 ) {
1465 let mut rx = std::pin::pin!(rx.peekable());
1466 let mut batched = Vec::new();
1467
1468 while let Some(EditPredictionRejectionPayload {
1469 rejection,
1470 organization_id,
1471 }) = rx.next().await
1472 {
1473 batched.push(rejection);
1474
1475 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1476 select_biased! {
1477 next = rx.as_mut().peek().fuse() => {
1478 if next.is_some() {
1479 continue;
1480 }
1481 }
1482 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1483 }
1484 }
1485
1486 let url = client
1487 .http_client()
1488 .build_zed_llm_url("/predict_edits/reject", &[])
1489 .unwrap();
1490
1491 let flush_count = batched
1492 .len()
1493 // in case items have accumulated after failure
1494 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1495 let start = batched.len() - flush_count;
1496
1497 let body = RejectEditPredictionsBodyRef {
1498 rejections: &batched[start..],
1499 };
1500
1501 let result = Self::send_api_request::<()>(
1502 |builder| {
1503 let req = builder
1504 .uri(url.as_ref())
1505 .body(serde_json::to_string(&body)?.into());
1506 anyhow::Ok(req?)
1507 },
1508 client.clone(),
1509 llm_token.clone(),
1510 organization_id,
1511 app_version.clone(),
1512 true,
1513 )
1514 .await;
1515
1516 if result.log_err().is_some() {
1517 batched.drain(start..);
1518 }
1519 }
1520 }
1521
    /// Worker that emits a telemetry event once the text around a prediction has "settled".
    ///
    /// While idle it waits on `rx` for an enqueue time and schedules a wake-up
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` later. On wake-up it scans every pending settled
    /// prediction of every registered buffer:
    /// - entries at least `EDIT_PREDICTION_SETTLED_TTL` old are dropped without reporting,
    /// - entries not edited for at least `EDIT_PREDICTION_SETTLED_QUIESCENCE` are reported via
    ///   `EDIT_PREDICTION_SETTLED_EVENT` with the current text of their editable range, then
    ///   dropped,
    /// - the rest stay queued, and the next wake-up is derived from the oldest `last_edit_at`
    ///   among them.
    ///
    /// The loop ends when `rx` is closed while idle or when the store entity can no longer be
    /// upgraded.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: block until something is enqueued (or the channel closes).
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Discard any further notifications that are already queued; one wake-up
                // covers them all.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            // Re-read the clock: time has passed while waiting on the timer.
            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: give up on this prediction.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    // Settled: capture the editable region as it reads in the
                                    // buffer's stored snapshot.
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                // Still being edited: track the earliest last-edit time so the
                                // next wake-up can be scheduled from it.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` here sends the loop back to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1606
1607 pub(crate) fn enqueue_settled_prediction(
1608 &mut self,
1609 request_id: EditPredictionId,
1610 project: &Entity<Project>,
1611 edited_buffer: &Entity<Buffer>,
1612 edited_buffer_snapshot: &BufferSnapshot,
1613 editable_offset_range: Range<usize>,
1614 example: Option<ExampleSpec>,
1615 cx: &mut Context<Self>,
1616 ) {
1617 let this = &mut *self;
1618 let project_state = this.get_or_init_project(project, cx);
1619 if let Some(buffer) = project_state
1620 .registered_buffers
1621 .get_mut(&edited_buffer.entity_id())
1622 {
1623 let now = cx.background_executor().now();
1624 buffer.pending_predictions.push(PendingSettledPrediction {
1625 request_id: request_id,
1626 editable_anchor_range: edited_buffer_snapshot
1627 .anchor_range_around(editable_offset_range),
1628 example,
1629 enqueued_at: now,
1630 last_edit_at: now,
1631 });
1632 this.settled_predictions_tx.unbounded_send(now).ok();
1633 }
1634 }
1635
1636 fn reject_current_prediction(
1637 &mut self,
1638 reason: EditPredictionRejectReason,
1639 project: &Entity<Project>,
1640 cx: &App,
1641 ) {
1642 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1643 project_state.pending_predictions.clear();
1644 if let Some(prediction) = project_state.current_prediction.take() {
1645 let model_version = prediction.prediction.model_version.clone();
1646 self.reject_prediction(
1647 prediction.prediction.id,
1648 reason,
1649 prediction.was_shown,
1650 model_version,
1651 cx,
1652 );
1653 }
1654 };
1655 }
1656
1657 fn did_show_current_prediction(
1658 &mut self,
1659 project: &Entity<Project>,
1660 display_type: edit_prediction_types::SuggestionDisplayType,
1661 cx: &mut Context<Self>,
1662 ) {
1663 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1664 return;
1665 };
1666
1667 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1668 return;
1669 };
1670
1671 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1672 let previous_shown_with = current_prediction.shown_with;
1673
1674 if previous_shown_with.is_none() || !is_jump {
1675 current_prediction.shown_with = Some(display_type);
1676 }
1677
1678 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1679
1680 if is_first_non_jump_show {
1681 current_prediction.was_shown = true;
1682 }
1683
1684 let display_type_changed = previous_shown_with != Some(display_type);
1685
1686 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1687 sweep_ai::edit_prediction_shown(
1688 &self.sweep_ai,
1689 self.client.clone(),
1690 ¤t_prediction.prediction,
1691 display_type,
1692 cx,
1693 );
1694 }
1695
1696 if is_first_non_jump_show {
1697 self.shown_predictions
1698 .push_front(current_prediction.prediction.clone());
1699 if self.shown_predictions.len() > 50 {
1700 let completion = self.shown_predictions.pop_back().unwrap();
1701 self.rated_predictions.remove(&completion.id);
1702 }
1703 }
1704 }
1705
1706 fn reject_prediction(
1707 &mut self,
1708 prediction_id: EditPredictionId,
1709 reason: EditPredictionRejectReason,
1710 was_shown: bool,
1711 model_version: Option<String>,
1712 cx: &App,
1713 ) {
1714 match self.edit_prediction_model {
1715 EditPredictionModel::Zeta => {
1716 let is_cloud = !matches!(
1717 all_language_settings(None, cx).edit_predictions.provider,
1718 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1719 );
1720
1721 if is_cloud {
1722 let organization_id = self
1723 .user_store
1724 .read(cx)
1725 .current_organization()
1726 .map(|organization| organization.id.clone());
1727
1728 self.reject_predictions_tx
1729 .unbounded_send(EditPredictionRejectionPayload {
1730 rejection: EditPredictionRejection {
1731 request_id: prediction_id.to_string(),
1732 reason,
1733 was_shown,
1734 model_version,
1735 },
1736 organization_id,
1737 })
1738 .log_err();
1739 }
1740 }
1741 EditPredictionModel::Mercury => {
1742 mercury::edit_prediction_rejected(
1743 prediction_id,
1744 was_shown,
1745 reason,
1746 self.client.http_client(),
1747 cx,
1748 );
1749 }
1750 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1751 }
1752 }
1753
1754 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1755 self.projects
1756 .get(&project.entity_id())
1757 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1758 }
1759
1760 pub fn refresh_prediction_from_buffer(
1761 &mut self,
1762 project: Entity<Project>,
1763 buffer: Entity<Buffer>,
1764 position: language::Anchor,
1765 cx: &mut Context<Self>,
1766 ) {
1767 self.queue_prediction_refresh(
1768 project.clone(),
1769 PredictEditsRequestTrigger::Other,
1770 buffer.entity_id(),
1771 cx,
1772 move |this, cx| {
1773 let Some(request_task) = this
1774 .update(cx, |this, cx| {
1775 this.request_prediction(
1776 &project,
1777 &buffer,
1778 position,
1779 PredictEditsRequestTrigger::Other,
1780 cx,
1781 )
1782 })
1783 .log_err()
1784 else {
1785 return Task::ready(anyhow::Ok(None));
1786 };
1787
1788 cx.spawn(async move |_cx| {
1789 request_task.await.map(|prediction_result| {
1790 prediction_result.map(|prediction_result| {
1791 (
1792 prediction_result,
1793 PredictionRequestedBy::Buffer(buffer.entity_id()),
1794 )
1795 })
1796 })
1797 })
1798 },
1799 )
1800 }
1801
    /// Queues a prediction refresh driven by diagnostics rather than by a buffer edit.
    ///
    /// Skipped when the selected provider is not store-backed, when the project is unknown, or
    /// when the project already has a current prediction. Otherwise the queued work:
    /// 1. looks up the project's active buffer and cursor, bailing out if predictions are
    ///    disabled there,
    /// 2. asks `next_diagnostic_location` for a location to jump to, searching either
    ///    `DIAGNOSTIC_LINES_RANGE` rows around the cursor (`Local`) or with a default range
    ///    (`Global`),
    /// 3. requests a prediction at that location with the `Diagnostics` trigger,
    /// 4. downgrades the result to a `CurrentPreferred` rejection if a current prediction
    ///    appeared in the meantime.
    ///
    /// The refresh is throttled per project and tagged as requested by a diagnostics update.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            // Throttle per project rather than per buffer.
            project.entity_id(),
            cx,
            move |this, cx| {
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known position the cursor defaults to the default point.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    this.update(cx, |this, cx| {
                        Some((
                            // A prediction may have arrived while this request was in flight;
                            // in that case report this one as rejected in its favor.
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1913
1914 fn predictions_enabled_at(
1915 snapshot: &BufferSnapshot,
1916 position: Option<language::Anchor>,
1917 cx: &App,
1918 ) -> bool {
1919 let file = snapshot.file();
1920 let all_settings = all_language_settings(file, cx);
1921 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1922 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1923 {
1924 return false;
1925 }
1926
1927 if let Some(last_position) = position {
1928 let settings = snapshot.settings_at(last_position, cx);
1929
1930 if !settings.edit_predictions_disabled_in.is_empty()
1931 && let Some(scope) = snapshot.language_scope_at(last_position)
1932 && let Some(scope_name) = scope.override_name()
1933 && settings
1934 .edit_predictions_disabled_in
1935 .iter()
1936 .any(|s| s == scope_name)
1937 {
1938 return false;
1939 }
1940 }
1941
1942 true
1943 }
1944
    /// Minimum spacing `queue_prediction_refresh` keeps between consecutive refreshes that
    /// share the same throttle entity.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1946}
1947
1948fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1949 match provider {
1950 EditPredictionProvider::Zed
1951 | EditPredictionProvider::Sweep
1952 | EditPredictionProvider::Mercury
1953 | EditPredictionProvider::Ollama
1954 | EditPredictionProvider::OpenAiCompatibleApi
1955 | EditPredictionProvider::Experimental(_) => true,
1956 EditPredictionProvider::None
1957 | EditPredictionProvider::Copilot
1958 | EditPredictionProvider::Codestral => false,
1959 }
1960}
1961
1962impl EditPredictionStore {
1963 fn queue_prediction_refresh(
1964 &mut self,
1965 project: Entity<Project>,
1966 request_trigger: PredictEditsRequestTrigger,
1967 throttle_entity: EntityId,
1968 cx: &mut Context<Self>,
1969 do_refresh: impl FnOnce(
1970 WeakEntity<Self>,
1971 &mut AsyncApp,
1972 )
1973 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1974 + 'static,
1975 ) {
1976 fn select_throttle(
1977 project_state: &mut ProjectState,
1978 request_trigger: PredictEditsRequestTrigger,
1979 ) -> &mut Option<(EntityId, Instant)> {
1980 match request_trigger {
1981 PredictEditsRequestTrigger::Diagnostics => {
1982 &mut project_state.last_jump_prediction_refresh
1983 }
1984 _ => &mut project_state.last_edit_prediction_refresh,
1985 }
1986 }
1987
1988 let (needs_acceptance_tracking, max_pending_predictions) =
1989 match all_language_settings(None, cx).edit_predictions.provider {
1990 EditPredictionProvider::Zed
1991 | EditPredictionProvider::Sweep
1992 | EditPredictionProvider::Mercury
1993 | EditPredictionProvider::Experimental(_) => (true, 2),
1994 EditPredictionProvider::Ollama => (false, 1),
1995 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1996 EditPredictionProvider::None
1997 | EditPredictionProvider::Copilot
1998 | EditPredictionProvider::Codestral => {
1999 log::error!("queue_prediction_refresh called with non-store provider");
2000 return;
2001 }
2002 };
2003
2004 let drop_on_cancel = !needs_acceptance_tracking;
2005 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2006 let project_state = self.get_or_init_project(&project, cx);
2007 let pending_prediction_id = project_state.next_pending_prediction_id;
2008 project_state.next_pending_prediction_id += 1;
2009 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2010
2011 let task = cx.spawn(async move |this, cx| {
2012 let throttle_wait = this
2013 .update(cx, |this, cx| {
2014 let project_state = this.get_or_init_project(&project, cx);
2015 let throttle = *select_throttle(project_state, request_trigger);
2016
2017 throttle.and_then(|(last_entity, last_timestamp)| {
2018 if throttle_entity != last_entity {
2019 return None;
2020 }
2021 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2022 })
2023 })
2024 .ok()
2025 .flatten();
2026
2027 if let Some(timeout) = throttle_wait {
2028 cx.background_executor().timer(timeout).await;
2029 }
2030
2031 // If this task was cancelled before the throttle timeout expired,
2032 // do not perform a request. Also skip if another task already
2033 // proceeded since we were enqueued (duplicate).
2034 let mut is_cancelled = true;
2035 this.update(cx, |this, cx| {
2036 let project_state = this.get_or_init_project(&project, cx);
2037 let was_cancelled = project_state
2038 .cancelled_predictions
2039 .remove(&pending_prediction_id);
2040 if was_cancelled {
2041 return;
2042 }
2043
2044 // Another request has been already sent since this was enqueued
2045 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2046 return;
2047 }
2048
2049 let new_refresh = (throttle_entity, Instant::now());
2050 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2051 is_cancelled = false;
2052 })
2053 .ok();
2054 if is_cancelled {
2055 return None;
2056 }
2057
2058 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2059 let new_prediction_id = new_prediction_result
2060 .as_ref()
2061 .map(|(prediction, _)| prediction.id.clone());
2062
2063 // When a prediction completes, remove it from the pending list, and cancel
2064 // any pending predictions that were enqueued before it.
2065 this.update(cx, |this, cx| {
2066 let project_state = this.get_or_init_project(&project, cx);
2067
2068 let is_cancelled = project_state
2069 .cancelled_predictions
2070 .remove(&pending_prediction_id);
2071
2072 let new_current_prediction = if !is_cancelled
2073 && let Some((prediction_result, requested_by)) = new_prediction_result
2074 {
2075 match prediction_result.prediction {
2076 Ok(prediction) => {
2077 let new_prediction = CurrentEditPrediction {
2078 requested_by,
2079 prediction,
2080 was_shown: false,
2081 shown_with: None,
2082 };
2083
2084 if let Some(current_prediction) =
2085 project_state.current_prediction.as_ref()
2086 {
2087 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2088 {
2089 this.reject_current_prediction(
2090 EditPredictionRejectReason::Replaced,
2091 &project,
2092 cx,
2093 );
2094
2095 Some(new_prediction)
2096 } else {
2097 this.reject_prediction(
2098 new_prediction.prediction.id,
2099 EditPredictionRejectReason::CurrentPreferred,
2100 false,
2101 new_prediction.prediction.model_version,
2102 cx,
2103 );
2104 None
2105 }
2106 } else {
2107 Some(new_prediction)
2108 }
2109 }
2110 Err(reject_reason) => {
2111 this.reject_prediction(
2112 prediction_result.id,
2113 reject_reason,
2114 false,
2115 None,
2116 cx,
2117 );
2118 None
2119 }
2120 }
2121 } else {
2122 None
2123 };
2124
2125 let project_state = this.get_or_init_project(&project, cx);
2126
2127 if let Some(new_prediction) = new_current_prediction {
2128 project_state.current_prediction = Some(new_prediction);
2129 }
2130
2131 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2132 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2133 if pending_prediction.id == pending_prediction_id {
2134 pending_predictions.remove(ix);
2135 for pending_prediction in pending_predictions.drain(0..ix) {
2136 project_state.cancel_pending_prediction(pending_prediction, cx)
2137 }
2138 break;
2139 }
2140 }
2141 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2142 cx.notify();
2143 })
2144 .ok();
2145
2146 new_prediction_id
2147 });
2148
2149 if project_state.pending_predictions.len() < max_pending_predictions {
2150 project_state.pending_predictions.push(PendingPrediction {
2151 id: pending_prediction_id,
2152 task,
2153 drop_on_cancel,
2154 });
2155 } else {
2156 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2157 project_state.pending_predictions.push(PendingPrediction {
2158 id: pending_prediction_id,
2159 task,
2160 drop_on_cancel,
2161 });
2162 project_state.cancel_pending_prediction(pending_prediction, cx);
2163 }
2164 }
2165
2166 pub fn request_prediction(
2167 &mut self,
2168 project: &Entity<Project>,
2169 active_buffer: &Entity<Buffer>,
2170 position: language::Anchor,
2171 trigger: PredictEditsRequestTrigger,
2172 cx: &mut Context<Self>,
2173 ) -> Task<Result<Option<EditPredictionResult>>> {
2174 self.request_prediction_internal(
2175 project.clone(),
2176 active_buffer.clone(),
2177 position,
2178 trigger,
2179 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2180 cx,
2181 )
2182 }
2183
2184 fn request_prediction_internal(
2185 &mut self,
2186 project: Entity<Project>,
2187 active_buffer: Entity<Buffer>,
2188 position: language::Anchor,
2189 trigger: PredictEditsRequestTrigger,
2190 allow_jump: bool,
2191 cx: &mut Context<Self>,
2192 ) -> Task<Result<Option<EditPredictionResult>>> {
2193 self.get_or_init_project(&project, cx);
2194 let project_state = self.projects.get(&project.entity_id()).unwrap();
2195 let stored_events = project_state.events(cx);
2196 let has_events = !stored_events.is_empty();
2197 let events: Vec<Arc<zeta_prompt::Event>> =
2198 stored_events.iter().map(|e| e.event.clone()).collect();
2199 let debug_tx = project_state.debug_tx.clone();
2200
2201 let snapshot = active_buffer.read(cx).snapshot();
2202 let cursor_point = position.to_point(&snapshot);
2203 let current_offset = position.to_offset(&snapshot);
2204
2205 let mut user_actions: Vec<UserActionRecord> =
2206 project_state.user_actions.iter().cloned().collect();
2207
2208 if let Some(last_action) = user_actions.last() {
2209 if last_action.buffer_id == active_buffer.entity_id()
2210 && current_offset != last_action.offset
2211 {
2212 let timestamp_epoch_ms = SystemTime::now()
2213 .duration_since(UNIX_EPOCH)
2214 .map(|d| d.as_millis() as u64)
2215 .unwrap_or(0);
2216 user_actions.push(UserActionRecord {
2217 action_type: UserActionType::CursorMovement,
2218 buffer_id: active_buffer.entity_id(),
2219 line_number: cursor_point.row,
2220 offset: current_offset,
2221 timestamp_epoch_ms,
2222 });
2223 }
2224 }
2225 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2226 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2227 let diagnostic_search_range =
2228 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2229
2230 let related_files = self.context_for_project(&project, cx);
2231
2232 let is_open_source = snapshot
2233 .file()
2234 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2235 && events.iter().all(|event| event.in_open_source_repo())
2236 && related_files.iter().all(|file| file.in_open_source_repo);
2237
2238 let can_collect_data = !cfg!(test)
2239 && is_open_source
2240 && self.is_data_collection_enabled(cx)
2241 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2242
2243 let recent_paths = project_state.recent_paths.clone();
2244
2245 let inputs = EditPredictionModelInput {
2246 project: project.clone(),
2247 buffer: active_buffer,
2248 snapshot,
2249 position,
2250 events,
2251 related_files,
2252 recent_paths,
2253 trigger,
2254 diagnostic_search_range: diagnostic_search_range,
2255 debug_tx,
2256 user_actions,
2257 can_collect_data,
2258 is_open_source,
2259 };
2260
2261 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2262
2263 let task = match self.edit_prediction_model {
2264 EditPredictionModel::Zeta => {
2265 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2266 }
2267 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2268 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2269 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2270 };
2271
2272 cx.spawn(async move |this, cx| {
2273 let prediction = task.await?;
2274
2275 // Only fall back to diagnostics-based prediction if we got a
2276 // the model had nothing to suggest for the buffer
2277 if prediction.is_none()
2278 && allow_jump
2279 && has_events
2280 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2281 {
2282 this.update(cx, |this, cx| {
2283 this.refresh_prediction_from_diagnostics(
2284 project,
2285 DiagnosticSearchScope::Local,
2286 cx,
2287 );
2288 })?;
2289 return anyhow::Ok(None);
2290 }
2291
2292 Ok(prediction)
2293 })
2294 }
2295
2296 pub(crate) async fn next_diagnostic_location(
2297 active_buffer: Entity<Buffer>,
2298 active_buffer_snapshot: &BufferSnapshot,
2299 active_buffer_diagnostic_search_range: Range<Point>,
2300 active_buffer_cursor_point: Point,
2301 project: &Entity<Project>,
2302 cx: &mut AsyncApp,
2303 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2304 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2305 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2306 .flat_map(|(_, _, _, selections)| {
2307 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2308 })
2309 .collect();
2310
2311 let mut jump_location = active_buffer_snapshot
2312 .diagnostic_groups(None)
2313 .into_iter()
2314 .filter_map(|(_, group)| {
2315 let range = &group.entries[group.primary_ix]
2316 .range
2317 .to_point(&active_buffer_snapshot);
2318 if range.overlaps(&active_buffer_diagnostic_search_range) {
2319 return None;
2320 }
2321 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2322 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2323 });
2324 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2325 <= DIAGNOSTIC_LINES_RANGE;
2326 if near_collaborator && !near_local {
2327 return None;
2328 }
2329 Some(range.start)
2330 })
2331 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2332 .map(|position| {
2333 (
2334 active_buffer.clone(),
2335 active_buffer_snapshot.anchor_before(position),
2336 )
2337 });
2338
2339 if jump_location.is_none() {
2340 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2341 let file = buffer.file()?;
2342
2343 Some(ProjectPath {
2344 worktree_id: file.worktree_id(cx),
2345 path: file.path().clone(),
2346 })
2347 });
2348
2349 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2350 project
2351 .diagnostic_summaries(false, cx)
2352 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2353 .map(|(path, _, _)| {
2354 let shared_prefix = path
2355 .path
2356 .components()
2357 .zip(
2358 active_buffer_path
2359 .as_ref()
2360 .map(|p| p.path.components())
2361 .unwrap_or_default(),
2362 )
2363 .take_while(|(a, b)| a == b)
2364 .count();
2365 (path, shared_prefix)
2366 })
2367 .collect()
2368 });
2369
2370 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2371
2372 for (path, _) in candidates {
2373 let candidate_buffer = project
2374 .update(cx, |project, cx| project.open_buffer(path, cx))
2375 .await?;
2376
2377 let (has_collaborators, diagnostic_position) =
2378 candidate_buffer.read_with(cx, |buffer, _cx| {
2379 let snapshot = buffer.snapshot();
2380 let has_collaborators = snapshot
2381 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2382 .next()
2383 .is_some();
2384 let position = buffer
2385 .buffer_diagnostics(None)
2386 .into_iter()
2387 .min_by_key(|entry| entry.diagnostic.severity)
2388 .map(|entry| entry.range.start);
2389 (has_collaborators, position)
2390 });
2391
2392 if has_collaborators {
2393 continue;
2394 }
2395
2396 if let Some(position) = diagnostic_position {
2397 jump_location = Some((candidate_buffer, position));
2398 break;
2399 }
2400 }
2401 }
2402
2403 anyhow::Ok(jump_location)
2404 }
2405
2406 async fn send_raw_llm_request(
2407 request: RawCompletionRequest,
2408 client: Arc<Client>,
2409 custom_url: Option<Arc<Url>>,
2410 llm_token: LlmApiToken,
2411 organization_id: Option<OrganizationId>,
2412 app_version: Version,
2413 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2414 let url = if let Some(custom_url) = custom_url {
2415 custom_url.as_ref().clone()
2416 } else {
2417 client
2418 .http_client()
2419 .build_zed_llm_url("/predict_edits/raw", &[])?
2420 };
2421
2422 Self::send_api_request(
2423 |builder| {
2424 let req = builder
2425 .uri(url.as_ref())
2426 .body(serde_json::to_string(&request)?.into());
2427 Ok(req?)
2428 },
2429 client,
2430 llm_token,
2431 organization_id,
2432 app_version,
2433 true,
2434 )
2435 .await
2436 }
2437
2438 pub(crate) async fn send_v3_request(
2439 input: ZetaPromptInput,
2440 client: Arc<Client>,
2441 llm_token: LlmApiToken,
2442 organization_id: Option<OrganizationId>,
2443 app_version: Version,
2444 trigger: PredictEditsRequestTrigger,
2445 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2446 let url = client
2447 .http_client()
2448 .build_zed_llm_url("/predict_edits/v3", &[])?;
2449
2450 let request = PredictEditsV3Request { input, trigger };
2451
2452 let json_bytes = serde_json::to_vec(&request)?;
2453 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2454
2455 Self::send_api_request(
2456 |builder| {
2457 let req = builder
2458 .uri(url.as_ref())
2459 .header("Content-Encoding", "zstd")
2460 .body(compressed.clone().into());
2461 Ok(req?)
2462 },
2463 client,
2464 llm_token,
2465 organization_id,
2466 app_version,
2467 true,
2468 )
2469 .await
2470 }
2471
2472 async fn send_api_request<Res>(
2473 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2474 client: Arc<Client>,
2475 llm_token: LlmApiToken,
2476 organization_id: Option<OrganizationId>,
2477 app_version: Version,
2478 require_auth: bool,
2479 ) -> Result<(Res, Option<EditPredictionUsage>)>
2480 where
2481 Res: DeserializeOwned,
2482 {
2483 let http_client = client.http_client();
2484
2485 let mut token = if require_auth {
2486 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2487 } else {
2488 llm_token
2489 .acquire(&client, organization_id.clone())
2490 .await
2491 .ok()
2492 };
2493 let mut did_retry = false;
2494
2495 loop {
2496 let request_builder = http_client::Request::builder().method(Method::POST);
2497
2498 let mut request_builder = request_builder
2499 .header("Content-Type", "application/json")
2500 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2501
2502 // Only add Authorization header if we have a token
2503 if let Some(ref token_value) = token {
2504 request_builder =
2505 request_builder.header("Authorization", format!("Bearer {}", token_value));
2506 }
2507
2508 let request = build(request_builder)?;
2509
2510 let mut response = http_client.send(request).await?;
2511
2512 if let Some(minimum_required_version) = response
2513 .headers()
2514 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2515 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2516 {
2517 anyhow::ensure!(
2518 app_version >= minimum_required_version,
2519 ZedUpdateRequiredError {
2520 minimum_version: minimum_required_version
2521 }
2522 );
2523 }
2524
2525 if response.status().is_success() {
2526 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2527
2528 let mut body = Vec::new();
2529 response.body_mut().read_to_end(&mut body).await?;
2530 return Ok((serde_json::from_slice(&body)?, usage));
2531 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2532 did_retry = true;
2533 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2534 } else {
2535 let mut body = String::new();
2536 response.body_mut().read_to_string(&mut body).await?;
2537 anyhow::bail!(
2538 "Request failed with status: {:?}\nBody: {}",
2539 response.status(),
2540 body
2541 );
2542 }
2543 }
2544 }
2545
2546 pub fn refresh_context(
2547 &mut self,
2548 project: &Entity<Project>,
2549 buffer: &Entity<language::Buffer>,
2550 cursor_position: language::Anchor,
2551 cx: &mut Context<Self>,
2552 ) {
2553 self.get_or_init_project(project, cx)
2554 .context
2555 .update(cx, |store, cx| {
2556 store.refresh(buffer.clone(), cursor_position, cx);
2557 });
2558 }
2559
2560 #[cfg(feature = "cli-support")]
2561 pub fn set_context_for_buffer(
2562 &mut self,
2563 project: &Entity<Project>,
2564 related_files: Vec<RelatedFile>,
2565 cx: &mut Context<Self>,
2566 ) {
2567 self.get_or_init_project(project, cx)
2568 .context
2569 .update(cx, |store, cx| {
2570 store.set_related_files(related_files, cx);
2571 });
2572 }
2573
2574 #[cfg(feature = "cli-support")]
2575 pub fn set_recent_paths_for_project(
2576 &mut self,
2577 project: &Entity<Project>,
2578 paths: impl IntoIterator<Item = project::ProjectPath>,
2579 cx: &mut Context<Self>,
2580 ) {
2581 let project_state = self.get_or_init_project(project, cx);
2582 project_state.recent_paths = paths.into_iter().collect();
2583 }
2584
2585 fn is_file_open_source(
2586 &self,
2587 project: &Entity<Project>,
2588 file: &Arc<dyn File>,
2589 cx: &App,
2590 ) -> bool {
2591 if !file.is_local() || file.is_private() {
2592 return false;
2593 }
2594 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2595 return false;
2596 };
2597 project_state
2598 .license_detection_watchers
2599 .get(&file.worktree_id(cx))
2600 .as_ref()
2601 .is_some_and(|watcher| watcher.is_project_open_source())
2602 }
2603
2604 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2605 self.data_collection_choice.is_enabled(cx)
2606 }
2607
2608 fn load_data_collection_choice() -> DataCollectionChoice {
2609 let choice = KEY_VALUE_STORE
2610 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2611 .log_err()
2612 .flatten();
2613
2614 match choice.as_deref() {
2615 Some("true") => DataCollectionChoice::Enabled,
2616 Some("false") => DataCollectionChoice::Disabled,
2617 Some(_) => {
2618 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2619 DataCollectionChoice::NotAnswered
2620 }
2621 None => DataCollectionChoice::NotAnswered,
2622 }
2623 }
2624
2625 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2626 self.data_collection_choice = self.data_collection_choice.toggle();
2627 let new_choice = self.data_collection_choice;
2628 let is_enabled = new_choice.is_enabled(cx);
2629 db::write_and_log(cx, move || {
2630 KEY_VALUE_STORE.write_kvp(
2631 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2632 is_enabled.to_string(),
2633 )
2634 });
2635 }
2636
2637 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2638 self.shown_predictions.iter()
2639 }
2640
2641 pub fn shown_completions_len(&self) -> usize {
2642 self.shown_predictions.len()
2643 }
2644
2645 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2646 self.rated_predictions.contains(id)
2647 }
2648
2649 pub fn rate_prediction(
2650 &mut self,
2651 prediction: &EditPrediction,
2652 rating: EditPredictionRating,
2653 feedback: String,
2654 cx: &mut Context<Self>,
2655 ) {
2656 let organization = self.user_store.read(cx).current_organization();
2657
2658 self.rated_predictions.insert(prediction.id.clone());
2659
2660 cx.background_spawn({
2661 let client = self.client.clone();
2662 let prediction_id = prediction.id.to_string();
2663 let inputs = serde_json::to_value(&prediction.inputs);
2664 let output = prediction
2665 .edit_preview
2666 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2667 async move {
2668 client
2669 .cloud_client()
2670 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2671 organization_id: organization.map(|organization| organization.id.clone()),
2672 request_id: prediction_id,
2673 rating: match rating {
2674 EditPredictionRating::Positive => "positive".to_string(),
2675 EditPredictionRating::Negative => "negative".to_string(),
2676 },
2677 inputs: inputs?,
2678 output,
2679 feedback,
2680 })
2681 .await?;
2682
2683 anyhow::Ok(())
2684 }
2685 })
2686 .detach_and_log_err(cx);
2687
2688 cx.notify();
2689 }
2690}
2691
2692fn merge_trailing_events_if_needed(
2693 events: &mut VecDeque<StoredEvent>,
2694 end_snapshot: &TextBufferSnapshot,
2695 latest_snapshot: &TextBufferSnapshot,
2696 latest_edit_range: &Range<Anchor>,
2697) {
2698 if let Some(last_event) = events.back() {
2699 if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2700 return;
2701 }
2702 }
2703
2704 let mut next_old_event = None;
2705 let mut mergeable_count = 0;
2706 for old_event in events.iter().rev() {
2707 if let Some(next_old_event) = &next_old_event
2708 && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2709 {
2710 break;
2711 }
2712 mergeable_count += 1;
2713 next_old_event = Some(old_event);
2714 }
2715
2716 if mergeable_count <= 1 {
2717 return;
2718 }
2719
2720 let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2721 let oldest_event = events_to_merge.peek().unwrap();
2722 let oldest_snapshot = oldest_event.old_snapshot.clone();
2723
2724 if let Some((diff, edited_range)) =
2725 compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2726 {
2727 let merged_event = match oldest_event.event.as_ref() {
2728 zeta_prompt::Event::BufferChange {
2729 old_path,
2730 path,
2731 in_open_source_repo,
2732 ..
2733 } => StoredEvent {
2734 event: Arc::new(zeta_prompt::Event::BufferChange {
2735 old_path: old_path.clone(),
2736 path: path.clone(),
2737 diff,
2738 in_open_source_repo: *in_open_source_repo,
2739 predicted: events_to_merge.all(|e| {
2740 matches!(
2741 e.event.as_ref(),
2742 zeta_prompt::Event::BufferChange {
2743 predicted: true,
2744 ..
2745 }
2746 )
2747 }),
2748 }),
2749 old_snapshot: oldest_snapshot.clone(),
2750 edit_range: end_snapshot.anchor_before(edited_range.start)
2751 ..end_snapshot.anchor_before(edited_range.end),
2752 },
2753 };
2754 events.truncate(events.len() - mergeable_count);
2755 events.push_back(merged_event);
2756 }
2757}
2758
2759#[derive(Error, Debug)]
2760#[error(
2761 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2762)]
2763pub struct ZedUpdateRequiredError {
2764 minimum_version: Version,
2765}
2766
/// The user's answer to the edit prediction data collection prompt.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly (for example with `==`
/// or in `assert_eq!`) instead of only through pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No usable answer is stored: nothing was persisted, or the stored value was not
    /// recognized.
    NotAnswered,
    /// Data collection was opted into.
    Enabled,
    /// Data collection was declined.
    Disabled,
}
2773
2774impl DataCollectionChoice {
2775 pub fn is_enabled(self, cx: &App) -> bool {
2776 if cx.is_staff() {
2777 return true;
2778 }
2779 match self {
2780 Self::Enabled => true,
2781 Self::NotAnswered | Self::Disabled => false,
2782 }
2783 }
2784
2785 #[must_use]
2786 pub fn toggle(&self) -> DataCollectionChoice {
2787 match self {
2788 Self::Enabled => Self::Disabled,
2789 Self::Disabled => Self::Enabled,
2790 Self::NotAnswered => Self::Enabled,
2791 }
2792 }
2793}
2794
2795impl From<bool> for DataCollectionChoice {
2796 fn from(value: bool) -> Self {
2797 match value {
2798 true => DataCollectionChoice::Enabled,
2799 false => DataCollectionChoice::Disabled,
2800 }
2801 }
2802}
2803
2804struct ZedPredictUpsell;
2805
2806impl Dismissable for ZedPredictUpsell {
2807 const KEY: &'static str = "dismissed-edit-predict-upsell";
2808
2809 fn dismissed() -> bool {
2810 // To make this backwards compatible with older versions of Zed, we
2811 // check if the user has seen the previous Edit Prediction Onboarding
2812 // before, by checking the data collection choice which was written to
2813 // the database once the user clicked on "Accept and Enable"
2814 if KEY_VALUE_STORE
2815 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2816 .log_err()
2817 .is_some_and(|s| s.is_some())
2818 {
2819 return true;
2820 }
2821
2822 KEY_VALUE_STORE
2823 .read_kvp(Self::KEY)
2824 .log_err()
2825 .is_some_and(|s| s.is_some())
2826 }
2827}
2828
2829pub fn should_show_upsell_modal() -> bool {
2830 !ZedPredictUpsell::dismissed()
2831}
2832
2833pub fn init(cx: &mut App) {
2834 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2835 workspace.register_action(
2836 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2837 ZedPredictModal::toggle(
2838 workspace,
2839 workspace.user_store().clone(),
2840 workspace.client().clone(),
2841 window,
2842 cx,
2843 )
2844 },
2845 );
2846
2847 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2848 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2849 settings
2850 .project
2851 .all_languages
2852 .edit_predictions
2853 .get_or_insert_default()
2854 .provider = Some(EditPredictionProvider::None)
2855 });
2856 });
2857 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2858 EditPredictionStore::try_global(cx).and_then(|store| {
2859 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2860 })
2861 }
2862
2863 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2864 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2865 copilot_ui::initiate_sign_in(copilot, window, cx);
2866 }
2867 });
2868 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2869 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2870 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2871 }
2872 });
2873 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2874 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2875 copilot_ui::initiate_sign_out(copilot, window, cx);
2876 }
2877 });
2878 })
2879 .detach();
2880}