1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72pub mod open_ai_compatible;
73mod zed_edit_prediction_delegate;
74pub mod zeta;
75
76#[cfg(test)]
77mod edit_prediction_tests;
78
79use crate::license_detection::LicenseDetectionWatcher;
80use crate::mercury::Mercury;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line distance within which two edits count as "near" each other when
/// deciding whether stored events may be grouped (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window used when grouping consecutive changes.
// NOTE(review): not read in this part of the file — confirm semantics at its use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);
/// Storage key for the user's data-collection choice.
// NOTE(review): presumably a key in the key-value store — confirm in `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval applied to rejection reporting.
// NOTE(review): presumably used by `handle_rejected_predictions` — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name reported when a prediction settles (presumably telemetry — confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// How long a prediction may stay pending before it stops being tracked as settled (five minutes).
// NOTE(review): exact use lives in the settled-predictions worker — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(5 * 60);
/// Quiet period without edits after which a prediction is considered settled.
// NOTE(review): inferred from the name; confirm in the settled-predictions worker.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
110
/// Feature flag registered under the name `edit_prediction_jumps`.
// NOTE(review): presumably gates jump-style predictions (cf. `BufferEditPrediction::Jump`)
// — confirm at the flag's use sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
116
/// gpui global holding the app-wide [`EditPredictionStore`] entity.
///
/// Installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
121
/// Configuration for using the raw Zeta2 endpoint.
///
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// `EditPredictionStore::zeta2_raw_config_from_env` builds it from the
/// `ZED_ZETA_FORMAT` (required), `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT`
/// environment variables; it can also be set with `set_zeta2_raw_config`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier override (`ZED_ZETA_MODEL`).
    pub model_id: Option<String>,
    /// Optional environment name override (`ZED_ZETA_ENVIRONMENT`).
    // NOTE(review): an earlier comment said the version doubled as the Baseten
    // environment name (lowercased); that predates this dedicated field — confirm
    // how the request builder consumes it.
    pub environment: Option<String>,
    /// Prompt format to construct (`ZED_ZETA_FORMAT`, parsed by `ZetaFormat::parse`).
    pub format: ZetaFormat,
}
131
/// Central edit-prediction state: per-project tracking, provider selection, and
/// the channels feeding the workers spawned in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Waits for a signed-in user, then refreshes the experiment list.
    _fetch_experiments_task: Task<()>,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): initialized to `false`; not read in this part of the file.
    update_required: bool,
    /// Backend that produces predictions (initialized to `Zeta`).
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Experiment names fetched from `/edit_prediction_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender feeding the background `handle_rejected_predictions` task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender feeding `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
154
/// A rejection queued on `EditPredictionStore::reject_predictions_tx`.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    /// Organization the rejection is attributed to, if any.
    organization_id: Option<OrganizationId>,
}
159
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model (the default chosen in `EditPredictionStore::new`).
    Zeta,
    /// Model prompted through the `fim` module using `format`; the icon set depends on
    /// whether the configured provider is Ollama.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI backend.
    Sweep,
    /// Mercury backend.
    Mercury,
}
167
/// Everything captured for a single prediction request.
// NOTE(review): presumably consumed by the per-model request builders, which are not
// shown here — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Anchor the prediction is requested at.
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
184
/// Debug events sent on a project's `debug_tx` channel when one is attached.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
192
/// Emitted when a project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Sent as an empty string from `handle_excerpt_store_event`.
    pub search_prompt: String,
}
199
/// Emitted when a project's related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display (e.g. cache hits and LSP timings).
    pub metadata: Vec<(&'static str, SharedString)>,
}
206
/// Debug event for the start of a prediction request.
// NOTE(review): not constructed in this part of the file — confirm when it is sent.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when one is available.
    pub prompt: Option<String>,
}
213
/// Debug event for the end of a prediction request.
// NOTE(review): not constructed in this part of the file — confirm when it is sent.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
220
/// Capacity of each project's user-action history; older entries are evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;
222
/// One entry in a project's bounded user-action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): presumably milliseconds since the Unix epoch — confirm at the producer.
    pub timestamp_epoch_ms: u64,
}
231
/// Kind of user action recorded in [`UserActionRecord`].
// NOTE(review): the classification rules live outside this part of the file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
240
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The edit event itself (e.g. a `BufferChange` carrying a unified diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Edited range, anchored in the snapshot taken after the change.
    pub edit_range: Range<Anchor>,
}
248
249impl StoredEvent {
250 fn can_merge(
251 &self,
252 next_old_event: &&&StoredEvent,
253 new_snapshot: &TextBufferSnapshot,
254 last_edit_range: &Range<Anchor>,
255 ) -> bool {
256 // Events must be for the same buffer
257 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
258 return false;
259 }
260 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
261 return false;
262 }
263
264 let a_is_predicted = matches!(
265 self.event.as_ref(),
266 zeta_prompt::Event::BufferChange {
267 predicted: true,
268 ..
269 }
270 );
271 let b_is_predicted = matches!(
272 next_old_event.event.as_ref(),
273 zeta_prompt::Event::BufferChange {
274 predicted: true,
275 ..
276 }
277 );
278
279 // If events come from the same source (both predicted or both manual) then
280 // we would have coalesced them already.
281 if a_is_predicted == b_is_predicted {
282 return false;
283 }
284
285 let left_range = self.edit_range.to_point(new_snapshot);
286 let right_range = next_old_event.edit_range.to_point(new_snapshot);
287 let latest_range = last_edit_range.to_point(&new_snapshot);
288
289 // Events near to the latest edit are not merged if their sources differ.
290 if lines_between_ranges(&left_range, &latest_range)
291 .min(lines_between_ranges(&right_range, &latest_range))
292 <= CHANGE_GROUPING_LINE_SPAN
293 {
294 return false;
295 }
296
297 // Events that are distant from each other are not merged.
298 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
299 return false;
300 }
301
302 true
303 }
304}
305
306fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
307 if left.start > right.end {
308 return left.start.row - right.end.row;
309 }
310 if right.start > left.end {
311 return right.start.row - left.end.row;
312 }
313 0
314}
315
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Completed edit events; `last_event` holds the one still being accumulated.
    events: VecDeque<StoredEvent>,
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (fixed capacity of two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When set, receives `DebugEvent`s such as context-retrieval start/finish.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably the buffer and time of the last refresh of each kind.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context files for this project.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers used to decide `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded (`USER_ACTION_HISTORY_SIZE`) history of recent user actions.
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance created lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
334
335impl ProjectState {
336 fn record_user_action(&mut self, action: UserActionRecord) {
337 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
338 self.user_actions.pop_front();
339 }
340 self.user_actions.push_back(action);
341 }
342
343 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
344 self.events
345 .iter()
346 .cloned()
347 .chain(self.last_event.as_ref().iter().flat_map(|event| {
348 let (one, two) = event.split_by_pause();
349 let one = one.finalize(&self.license_detection_watchers, cx);
350 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
351 one.into_iter().chain(two)
352 }))
353 .collect()
354 }
355
356 fn cancel_pending_prediction(
357 &mut self,
358 pending_prediction: PendingPrediction,
359 cx: &mut Context<EditPredictionStore>,
360 ) {
361 self.cancelled_predictions.insert(pending_prediction.id);
362
363 if pending_prediction.drop_on_cancel {
364 drop(pending_prediction.task);
365 } else {
366 cx.spawn(async move |this, cx| {
367 let Some(prediction_id) = pending_prediction.task.await else {
368 return;
369 };
370
371 this.update(cx, |this, cx| {
372 this.reject_prediction(
373 prediction_id,
374 EditPredictionRejectReason::Canceled,
375 false,
376 None,
377 cx,
378 );
379 })
380 .ok();
381 })
382 .detach()
383 }
384 }
385
386 fn active_buffer(
387 &self,
388 project: &Entity<Project>,
389 cx: &App,
390 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
391 let project = project.read(cx);
392 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
393 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
394 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
395 Some((active_buffer, registered_buffer.last_position))
396 }
397}
398
/// The prediction currently held for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a buffer or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    // NOTE(review): presumably how the prediction was displayed, once shown — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
406
407impl CurrentEditPrediction {
408 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
409 let Some(new_edits) = self
410 .prediction
411 .interpolate(&self.prediction.buffer.read(cx))
412 else {
413 return false;
414 };
415
416 if self.prediction.buffer != old_prediction.prediction.buffer {
417 return true;
418 }
419
420 let Some(old_edits) = old_prediction
421 .prediction
422 .interpolate(&old_prediction.prediction.buffer.read(cx))
423 else {
424 return true;
425 };
426
427 let requested_by_buffer_id = self.requested_by.buffer_id();
428
429 // This reduces the occurrence of UI thrash from replacing edits
430 //
431 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
432 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
433 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
434 && old_edits.len() == 1
435 && new_edits.len() == 1
436 {
437 let (old_range, old_text) = &old_edits[0];
438 let (new_range, new_text) = &new_edits[0];
439 new_range == old_range && new_text.starts_with(old_text.as_ref())
440 } else {
441 true
442 }
443 }
444}
445
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update; there is no originating buffer.
    DiagnosticsUpdate,
    /// The buffer with this entity id.
    Buffer(EntityId),
}
451
452impl PredictionRequestedBy {
453 pub fn buffer_id(&self) -> Option<EntityId> {
454 match self {
455 PredictionRequestedBy::DiagnosticsUpdate => None,
456 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
457 }
458 }
459}
460
// NOTE(review): line radius for diagnostic searches, judging by the name; not used in
// this part of the file — confirm at its use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
462
/// Scope of a diagnostic search.
// NOTE(review): not used in this part of the file; presumably `Local` limits the search
// near the cursor and `Global` widens it — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
468
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Id recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
477
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // NOTE(review): presumably a prediction that applies in this buffer — confirm.
    Local { prediction: &'a EditPrediction },
    // NOTE(review): presumably a prediction elsewhere that requires jumping — confirm.
    Jump { prediction: &'a EditPrediction },
}
484
/// Test-only convenience: lets either variant be used as its inner prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction } | BufferEditPrediction::Jump { prediction }) =
            self;
        prediction
    }
}
496
/// A shown prediction awaiting "settled" reporting.
// NOTE(review): lifecycle lives in the settled-predictions worker, not shown here — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Buffer region the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}
504
/// Per-buffer tracking state kept in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Snapshot last observed for this buffer.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position, returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
512
/// An in-progress edit event that has not yet been turned into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer state at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state after the most recent change in the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Whether the change came from an accepted prediction.
    predicted: bool,
    /// Snapshot at the last pause in editing; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
524
525impl LastEvent {
526 pub fn finalize(
527 &self,
528 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
529 cx: &App,
530 ) -> Option<StoredEvent> {
531 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
532 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
533
534 let in_open_source_repo =
535 [self.new_file.as_ref(), self.old_file.as_ref()]
536 .iter()
537 .all(|file| {
538 file.is_some_and(|file| {
539 license_detection_watchers
540 .get(&file.worktree_id(cx))
541 .is_some_and(|watcher| watcher.is_project_open_source())
542 })
543 });
544
545 let (diff, edit_range) =
546 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
547
548 if path == old_path && diff.is_empty() {
549 None
550 } else {
551 Some(StoredEvent {
552 event: Arc::new(zeta_prompt::Event::BufferChange {
553 old_path,
554 path,
555 diff,
556 in_open_source_repo,
557 predicted: self.predicted,
558 }),
559 edit_range: self.new_snapshot.anchor_before(edit_range.start)
560 ..self.new_snapshot.anchor_before(edit_range.end),
561 old_snapshot: self.old_snapshot.clone(),
562 })
563 }
564 }
565
566 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
567 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
568 return (self.clone(), None);
569 };
570
571 let before = LastEvent {
572 old_snapshot: self.old_snapshot.clone(),
573 new_snapshot: boundary_snapshot.clone(),
574 old_file: self.old_file.clone(),
575 new_file: self.new_file.clone(),
576 edit_range: None,
577 predicted: self.predicted,
578 snapshot_after_last_editing_pause: None,
579 last_edit_time: self.last_edit_time,
580 };
581
582 let after = LastEvent {
583 old_snapshot: boundary_snapshot.clone(),
584 new_snapshot: self.new_snapshot.clone(),
585 old_file: self.old_file.clone(),
586 new_file: self.new_file.clone(),
587 edit_range: None,
588 predicted: self.predicted,
589 snapshot_after_last_editing_pause: None,
590 last_edit_time: self.last_edit_time,
591 };
592
593 (before, Some(after))
594 }
595}
596
597pub(crate) fn compute_diff_between_snapshots(
598 old_snapshot: &TextBufferSnapshot,
599 new_snapshot: &TextBufferSnapshot,
600) -> Option<(String, Range<Point>)> {
601 let edits: Vec<Edit<usize>> = new_snapshot
602 .edits_since::<usize>(&old_snapshot.version)
603 .collect();
604
605 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
606
607 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
608 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
609 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
610 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
611
612 const CONTEXT_LINES: u32 = 3;
613
614 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
615 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
616 let old_context_end_row =
617 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
618 let new_context_end_row =
619 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
620
621 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
622 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
623 let old_end_line_offset = old_snapshot
624 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
625 let new_end_line_offset = new_snapshot
626 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
627 let old_edit_range = old_start_line_offset..old_end_line_offset;
628 let new_edit_range = new_start_line_offset..new_end_line_offset;
629
630 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
631 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
632
633 let diff = language::unified_diff_with_offsets(
634 &old_region_text,
635 &new_region_text,
636 old_context_start_row,
637 new_context_start_row,
638 );
639
640 Some((diff, new_start_point..new_end_point))
641}
642
643fn buffer_path_with_id_fallback(
644 file: Option<&Arc<dyn File>>,
645 snapshot: &TextBufferSnapshot,
646 cx: &App,
647) -> Arc<Path> {
648 if let Some(file) = file {
649 file.full_path(cx).into()
650 } else {
651 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
652 }
653}
654
655impl EditPredictionStore {
656 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
657 cx.try_global::<EditPredictionStoreGlobal>()
658 .map(|global| global.0.clone())
659 }
660
661 pub fn global(
662 client: &Arc<Client>,
663 user_store: &Entity<UserStore>,
664 cx: &mut App,
665 ) -> Entity<Self> {
666 cx.try_global::<EditPredictionStoreGlobal>()
667 .map(|global| global.0.clone())
668 .unwrap_or_else(|| {
669 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
670 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
671 ep_store
672 })
673 }
674
675 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
676 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
677 let data_collection_choice = Self::load_data_collection_choice();
678
679 let llm_token = LlmApiToken::default();
680
681 let (reject_tx, reject_rx) = mpsc::unbounded();
682 cx.background_spawn({
683 let client = client.clone();
684 let llm_token = llm_token.clone();
685 let app_version = AppVersion::global(cx);
686 let background_executor = cx.background_executor().clone();
687 async move {
688 Self::handle_rejected_predictions(
689 reject_rx,
690 client,
691 llm_token,
692 app_version,
693 background_executor,
694 )
695 .await
696 }
697 })
698 .detach();
699
700 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
701 cx.spawn(async move |this, cx| {
702 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
703 })
704 .detach();
705
706 let mut current_user = user_store.read(cx).watch_current_user();
707 let fetch_experiments_task = cx.spawn(async move |this, cx| {
708 while current_user.borrow().is_none() {
709 current_user.next().await;
710 }
711 this.update(cx, |this, cx| {
712 this.refresh_available_experiments(cx);
713 })
714 .log_err();
715 });
716
717 let this = Self {
718 projects: HashMap::default(),
719 client,
720 user_store,
721 llm_token,
722 _fetch_experiments_task: fetch_experiments_task,
723 _llm_token_subscription: cx.subscribe(
724 &refresh_llm_token_listener,
725 |this, _listener, _event, cx| {
726 let client = this.client.clone();
727 let llm_token = this.llm_token.clone();
728 let organization_id = this
729 .user_store
730 .read(cx)
731 .current_organization()
732 .map(|organization| organization.id.clone());
733 cx.spawn(async move |_this, _cx| {
734 llm_token.refresh(&client, organization_id).await?;
735 anyhow::Ok(())
736 })
737 .detach_and_log_err(cx);
738 },
739 ),
740 update_required: false,
741 edit_prediction_model: EditPredictionModel::Zeta,
742 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
743 preferred_experiment: None,
744 available_experiments: Vec::new(),
745 sweep_ai: SweepAi::new(cx),
746 mercury: Mercury::new(cx),
747
748 data_collection_choice,
749 reject_predictions_tx: reject_tx,
750 settled_predictions_tx,
751 rated_predictions: Default::default(),
752 shown_predictions: Default::default(),
753 #[cfg(test)]
754 settled_event_callback: None,
755 };
756
757 this
758 }
759
760 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
761 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
762 let format = ZetaFormat::parse(&version_str).ok()?;
763 let model_id = env::var("ZED_ZETA_MODEL").ok();
764 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
765 Some(Zeta2RawConfig {
766 model_id,
767 environment,
768 format,
769 })
770 }
771
772 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
773 self.edit_prediction_model = model;
774 }
775
776 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
777 self.zeta2_raw_config = Some(config);
778 }
779
780 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
781 self.zeta2_raw_config.as_ref()
782 }
783
784 pub fn preferred_experiment(&self) -> Option<&str> {
785 self.preferred_experiment.as_deref()
786 }
787
788 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
789 self.preferred_experiment = experiment;
790 }
791
792 pub fn available_experiments(&self) -> &[String] {
793 &self.available_experiments
794 }
795
    /// Fetches the experiment list from the Zed LLM service and stores it.
    ///
    /// Sends `GET /edit_prediction_experiments` with the LLM token (for the current
    /// organization, if any) and the app version header. On success the JSON list of
    /// names replaces `available_experiments` and observers are notified. Failures
    /// (including non-success statuses, reported with the response body) are logged
    /// and the previous list is kept.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            // The HTTP round-trip runs off the main thread.
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    let token = llm_token.acquire(&client, organization_id).await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        // Include the body in the error to aid diagnosis.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
843
844 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
845 use ui::IconName;
846 match self.edit_prediction_model {
847 EditPredictionModel::Sweep => {
848 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
849 .with_disabled(IconName::SweepAiDisabled)
850 .with_up(IconName::SweepAiUp)
851 .with_down(IconName::SweepAiDown)
852 .with_error(IconName::SweepAiError)
853 }
854 EditPredictionModel::Mercury => {
855 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
856 }
857 EditPredictionModel::Zeta => {
858 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
859 .with_disabled(IconName::ZedPredictDisabled)
860 .with_up(IconName::ZedPredictUp)
861 .with_down(IconName::ZedPredictDown)
862 .with_error(IconName::ZedPredictError)
863 }
864 EditPredictionModel::Fim { .. } => {
865 let settings = &all_language_settings(None, cx).edit_predictions;
866 match settings.provider {
867 EditPredictionProvider::Ollama => {
868 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
869 }
870 _ => {
871 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
872 }
873 }
874 }
875 }
876 }
877
878 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
879 self.sweep_ai.api_token.read(cx).has_key()
880 }
881
882 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
883 self.mercury.api_token.read(cx).has_key()
884 }
885
886 pub fn clear_history(&mut self) {
887 for project_state in self.projects.values_mut() {
888 project_state.events.clear();
889 project_state.last_event.take();
890 }
891 }
892
893 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
894 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
895 project_state.events.clear();
896 project_state.last_event.take();
897 }
898 }
899
900 pub fn edit_history_for_project(
901 &self,
902 project: &Entity<Project>,
903 cx: &App,
904 ) -> Vec<StoredEvent> {
905 self.projects
906 .get(&project.entity_id())
907 .map(|project_state| project_state.events(cx))
908 .unwrap_or_default()
909 }
910
911 pub fn context_for_project<'a>(
912 &'a self,
913 project: &Entity<Project>,
914 cx: &'a mut App,
915 ) -> Vec<RelatedFile> {
916 self.projects
917 .get(&project.entity_id())
918 .map(|project_state| {
919 project_state.context.update(cx, |context, cx| {
920 context
921 .related_files_with_buffers(cx)
922 .map(|(mut related_file, buffer)| {
923 related_file.in_open_source_repo = buffer
924 .read(cx)
925 .file()
926 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
927 related_file
928 })
929 .collect()
930 })
931 })
932 .unwrap_or_default()
933 }
934
935 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
936 self.projects
937 .get(&project.entity_id())
938 .and_then(|project| project.copilot.clone())
939 }
940
941 pub fn start_copilot_for_project(
942 &mut self,
943 project: &Entity<Project>,
944 cx: &mut Context<Self>,
945 ) -> Option<Entity<Copilot>> {
946 if DisableAiSettings::get(None, cx).disable_ai {
947 return None;
948 }
949 let state = self.get_or_init_project(project, cx);
950
951 if state.copilot.is_some() {
952 return state.copilot.clone();
953 }
954 let _project = project.clone();
955 let project = project.read(cx);
956
957 let node = project.node_runtime().cloned();
958 if let Some(node) = node {
959 let next_id = project.languages().next_language_server_id();
960 let fs = project.fs().clone();
961
962 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
963 state.copilot = Some(copilot.clone());
964 Some(copilot)
965 } else {
966 None
967 }
968 }
969
970 pub fn context_for_project_with_buffers<'a>(
971 &'a self,
972 project: &Entity<Project>,
973 cx: &'a mut App,
974 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
975 self.projects
976 .get(&project.entity_id())
977 .map(|project| {
978 project.context.update(cx, |context, cx| {
979 context.related_files_with_buffers(cx).collect()
980 })
981 })
982 .unwrap_or_default()
983 }
984
985 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
986 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
987 self.user_store.read(cx).edit_prediction_usage()
988 } else {
989 None
990 }
991 }
992
993 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
994 self.get_or_init_project(project, cx);
995 }
996
997 pub fn register_buffer(
998 &mut self,
999 buffer: &Entity<Buffer>,
1000 project: &Entity<Project>,
1001 cx: &mut Context<Self>,
1002 ) {
1003 let project_state = self.get_or_init_project(project, cx);
1004 Self::register_buffer_impl(project_state, buffer, project, cx);
1005 }
1006
    /// Returns the state for `project`, creating it on first use.
    ///
    /// On creation, this:
    /// - builds a `RelatedExcerptStore` for the project and forwards its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to project events;
    /// - removes the state again when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // The closure only runs when the project is not yet tracked, so the
            // excerpt store and subscriptions are created once per project.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1047
1048 pub fn remove_project(&mut self, project: &Entity<Project>) {
1049 self.projects.remove(&project.entity_id());
1050 }
1051
1052 fn handle_excerpt_store_event(
1053 &mut self,
1054 project_entity_id: EntityId,
1055 event: &RelatedExcerptStoreEvent,
1056 ) {
1057 if let Some(project_state) = self.projects.get(&project_entity_id) {
1058 if let Some(debug_tx) = project_state.debug_tx.clone() {
1059 match event {
1060 RelatedExcerptStoreEvent::StartedRefresh => {
1061 debug_tx
1062 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1063 ContextRetrievalStartedDebugEvent {
1064 project_entity_id: project_entity_id,
1065 timestamp: Instant::now(),
1066 search_prompt: String::new(),
1067 },
1068 ))
1069 .ok();
1070 }
1071 RelatedExcerptStoreEvent::FinishedRefresh {
1072 cache_hit_count,
1073 cache_miss_count,
1074 mean_definition_latency,
1075 max_definition_latency,
1076 } => {
1077 debug_tx
1078 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1079 ContextRetrievalFinishedDebugEvent {
1080 project_entity_id: project_entity_id,
1081 timestamp: Instant::now(),
1082 metadata: vec![
1083 (
1084 "Cache Hits",
1085 format!(
1086 "{}/{}",
1087 cache_hit_count,
1088 cache_hit_count + cache_miss_count
1089 )
1090 .into(),
1091 ),
1092 (
1093 "Max LSP Time",
1094 format!("{} ms", max_definition_latency.as_millis())
1095 .into(),
1096 ),
1097 (
1098 "Mean LSP Time",
1099 format!("{} ms", mean_definition_latency.as_millis())
1100 .into(),
1101 ),
1102 ],
1103 },
1104 ))
1105 .ok();
1106 }
1107 }
1108 }
1109 }
1110 }
1111
1112 pub fn debug_info(
1113 &mut self,
1114 project: &Entity<Project>,
1115 cx: &mut Context<Self>,
1116 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1117 let project_state = self.get_or_init_project(project, cx);
1118 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1119 project_state.debug_tx = Some(debug_watch_tx);
1120 debug_watch_rx
1121 }
1122
1123 fn handle_project_event(
1124 &mut self,
1125 project: Entity<Project>,
1126 event: &project::Event,
1127 cx: &mut Context<Self>,
1128 ) {
1129 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1130 return;
1131 }
1132 // TODO [zeta2] init with recent paths
1133 match event {
1134 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1135 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1136 return;
1137 };
1138 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1139 if let Some(path) = path {
1140 if let Some(ix) = project_state
1141 .recent_paths
1142 .iter()
1143 .position(|probe| probe == &path)
1144 {
1145 project_state.recent_paths.remove(ix);
1146 }
1147 project_state.recent_paths.push_front(path);
1148 }
1149 }
1150 project::Event::DiagnosticsUpdated { .. } => {
1151 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1152 self.refresh_prediction_from_diagnostics(
1153 project,
1154 DiagnosticSearchScope::Global,
1155 cx,
1156 );
1157 }
1158 }
1159 _ => (),
1160 }
1161 }
1162
    /// Returns the `RegisteredBuffer` entry for `buffer` in `project_state`,
    /// creating it the first time the buffer is seen.
    ///
    /// Takes a `&mut ProjectState` rather than `&mut self`; callers first
    /// obtain it from `get_or_init_project`.
    ///
    /// Side effects:
    /// - If the buffer has a file whose worktree exists in `project`, a
    ///   `LicenseDetectionWatcher` is created for that worktree (at most one per
    ///   worktree id). It is removed again when the worktree entity is released.
    /// - A newly created entry stores the buffer's current text snapshot and
    ///   file. While the project is still alive, it calls
    ///   `report_changes_for_buffer` (with `is_predicted = false`) on every
    ///   `BufferEvent::Edited`. It removes itself from the project's registered
    ///   buffers when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // License detection is set up per worktree on every call, regardless of
        // whether this buffer was registered before.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Only a weak project handle is captured. Edits that arrive
                        // after the project is gone are ignored.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1229
    /// Records the edits made to `buffer` since the snapshot last stored for
    /// it, and updates the project's edit history.
    ///
    /// - The buffer is registered if needed. A buffer registered by this call
    ///   starts from its current snapshot, so nothing is reported for it yet.
    /// - Returns early if the buffer version is unchanged or no edits are found
    ///   since the stored snapshot.
    /// - Bumps `last_edit_at` on settled-prediction trackers whose editable
    ///   range overlaps the edited span.
    /// - Classifies the change as a `UserActionType` and records it with the
    ///   row and offset where the last edit ends.
    /// - Coalesces the change into `last_event` when all of these hold:
    ///   - it directly continues that event's snapshot of the same buffer;
    ///   - it comes from the same source (`is_predicted`);
    ///   - it lies within `CHANGE_GROUPING_LINE_SPAN` lines of the previous
    ///     edit range.
    ///
    ///   Otherwise the previous `last_event` is finalized into `events` and a
    ///   new one is started.
    ///
    /// `is_predicted` is `true` when called from `accept_current_prediction`
    /// and `false` for ordinary buffer edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Accumulate edit statistics and one range spanning from the first
        // edit's start to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Editing inside a tracked prediction's region restarts its quiescence clock.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify by total lengths. When the total inserted (or deleted) length
        // equals the number of edits, the change is a "char" action; otherwise it
        // is a "selection" action.
        // NOTE(review): lengths are `usize` offsets, presumably bytes, so typing a
        // multi-byte character would count as a selection. Confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // snapshot from just before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Finalize the previous event, keeping `events` below EVENT_COUNT_MAX
        // entries (`last_event` is held separately).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1363
1364 fn prediction_at(
1365 &mut self,
1366 buffer: &Entity<Buffer>,
1367 position: Option<language::Anchor>,
1368 project: &Entity<Project>,
1369 cx: &App,
1370 ) -> Option<BufferEditPrediction<'_>> {
1371 let project_state = self.projects.get_mut(&project.entity_id())?;
1372 if let Some(position) = position
1373 && let Some(buffer) = project_state
1374 .registered_buffers
1375 .get_mut(&buffer.entity_id())
1376 {
1377 buffer.last_position = Some(position);
1378 }
1379
1380 let CurrentEditPrediction {
1381 requested_by,
1382 prediction,
1383 ..
1384 } = project_state.current_prediction.as_ref()?;
1385
1386 if prediction.targets_buffer(buffer.read(cx)) {
1387 Some(BufferEditPrediction::Local { prediction })
1388 } else {
1389 let show_jump = match requested_by {
1390 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1391 requested_by_buffer_id == &buffer.entity_id()
1392 }
1393 PredictionRequestedBy::DiagnosticsUpdate => true,
1394 };
1395
1396 if show_jump {
1397 Some(BufferEditPrediction::Jump { prediction })
1398 } else {
1399 None
1400 }
1401 }
1402 }
1403
1404 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1405 let Some(current_prediction) = self
1406 .projects
1407 .get_mut(&project.entity_id())
1408 .and_then(|project_state| project_state.current_prediction.take())
1409 else {
1410 return;
1411 };
1412
1413 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1414
1415 // can't hold &mut project_state ref across report_changes_for_buffer_call
1416 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1417 return;
1418 };
1419
1420 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1421 project_state.cancel_pending_prediction(pending_prediction, cx);
1422 }
1423
1424 match self.edit_prediction_model {
1425 EditPredictionModel::Sweep => {
1426 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1427 }
1428 EditPredictionModel::Mercury => {
1429 mercury::edit_prediction_accepted(
1430 current_prediction.prediction.id,
1431 self.client.http_client(),
1432 cx,
1433 );
1434 }
1435 EditPredictionModel::Zeta => {
1436 let is_cloud = !matches!(
1437 all_language_settings(None, cx).edit_predictions.provider,
1438 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1439 );
1440 if is_cloud {
1441 zeta::edit_prediction_accepted(self, current_prediction, cx)
1442 }
1443 }
1444 EditPredictionModel::Fim { .. } => {}
1445 }
1446 }
1447
    /// Background loop that batches rejected predictions and sends them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// Batching:
    /// - After receiving a rejection, if the batch holds fewer than half of
    ///   `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, the loop waits for
    ///   either another queued rejection (which it keeps batching) or
    ///   `REJECT_REQUEST_DEBOUNCE` to elapse.
    /// - It then sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of
    ///   the most recently batched rejections.
    /// - Sent items are removed only when the request succeeds.
    ///
    /// The loop ends when the channel closes.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // For a small batch, wait until another rejection is already queued
            // (then loop again to take it) or the debounce timer fires,
            // whichever happens first. `select_biased!` checks the peek branch
            // first. Larger batches are flushed right away.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): `unwrap` panics this background task if the URL
            // cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            // NOTE(review): the whole batch is sent with the organization id of
            // the most recently received rejection only.
            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only successfully sent rejections are removed. Failed ones stay in
            // `batched` and are retried with later flushes.
            // NOTE(review): nothing bounds `batched` if requests keep failing.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1511
1512 async fn run_settled_predictions_worker(
1513 this: WeakEntity<Self>,
1514 mut rx: UnboundedReceiver<Instant>,
1515 cx: &mut AsyncApp,
1516 ) {
1517 let mut next_wake_time: Option<Instant> = None;
1518 loop {
1519 let now = cx.background_executor().now();
1520 if let Some(wake_time) = next_wake_time.take() {
1521 cx.background_executor()
1522 .timer(wake_time.duration_since(now))
1523 .await;
1524 } else {
1525 let Some(new_enqueue_time) = rx.next().await else {
1526 break;
1527 };
1528 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1529 while rx.next().now_or_never().flatten().is_some() {}
1530 continue;
1531 }
1532
1533 let Some(this) = this.upgrade() else {
1534 break;
1535 };
1536
1537 let now = cx.background_executor().now();
1538
1539 let mut oldest_edited_at = None;
1540
1541 this.update(cx, |this, _| {
1542 for (_, project_state) in this.projects.iter_mut() {
1543 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1544 registered_buffer
1545 .pending_predictions
1546 .retain_mut(|pending_prediction| {
1547 let age =
1548 now.saturating_duration_since(pending_prediction.enqueued_at);
1549 if age >= EDIT_PREDICTION_SETTLED_TTL {
1550 return false;
1551 }
1552
1553 let quiet_for =
1554 now.saturating_duration_since(pending_prediction.last_edit_at);
1555 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1556 let settled_editable_region = registered_buffer
1557 .snapshot
1558 .text_for_range(
1559 pending_prediction.editable_anchor_range.clone(),
1560 )
1561 .collect::<String>();
1562
1563 #[cfg(test)]
1564 if let Some(callback) = &this.settled_event_callback {
1565 callback(
1566 pending_prediction.request_id.clone(),
1567 settled_editable_region.clone(),
1568 );
1569 }
1570
1571 telemetry::event!(
1572 EDIT_PREDICTION_SETTLED_EVENT,
1573 request_id = pending_prediction.request_id.0.clone(),
1574 settled_editable_region,
1575 );
1576
1577 return false;
1578 }
1579
1580 if oldest_edited_at
1581 .is_none_or(|t| pending_prediction.last_edit_at < t)
1582 {
1583 oldest_edited_at = Some(pending_prediction.last_edit_at);
1584 }
1585
1586 true
1587 });
1588 }
1589 }
1590 });
1591
1592 next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1593 }
1594 }
1595
1596 pub(crate) fn enqueue_settled_prediction(
1597 &mut self,
1598 request_id: EditPredictionId,
1599 project: &Entity<Project>,
1600 edited_buffer: &Entity<Buffer>,
1601 edited_buffer_snapshot: &BufferSnapshot,
1602 editable_offset_range: Range<usize>,
1603 cx: &mut Context<Self>,
1604 ) {
1605 let project_state = self.get_or_init_project(project, cx);
1606 if let Some(buffer) = project_state
1607 .registered_buffers
1608 .get_mut(&edited_buffer.entity_id())
1609 {
1610 let now = cx.background_executor().now();
1611 buffer.pending_predictions.push(PendingSettledPrediction {
1612 request_id,
1613 editable_anchor_range: edited_buffer_snapshot
1614 .anchor_range_around(editable_offset_range),
1615 enqueued_at: now,
1616 last_edit_at: now,
1617 });
1618 self.settled_predictions_tx.unbounded_send(now).ok();
1619 }
1620 }
1621
1622 fn reject_current_prediction(
1623 &mut self,
1624 reason: EditPredictionRejectReason,
1625 project: &Entity<Project>,
1626 cx: &App,
1627 ) {
1628 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1629 project_state.pending_predictions.clear();
1630 if let Some(prediction) = project_state.current_prediction.take() {
1631 let model_version = prediction.prediction.model_version.clone();
1632 self.reject_prediction(
1633 prediction.prediction.id,
1634 reason,
1635 prediction.was_shown,
1636 model_version,
1637 cx,
1638 );
1639 }
1640 };
1641 }
1642
1643 fn did_show_current_prediction(
1644 &mut self,
1645 project: &Entity<Project>,
1646 display_type: edit_prediction_types::SuggestionDisplayType,
1647 cx: &mut Context<Self>,
1648 ) {
1649 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1650 return;
1651 };
1652
1653 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1654 return;
1655 };
1656
1657 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1658 let previous_shown_with = current_prediction.shown_with;
1659
1660 if previous_shown_with.is_none() || !is_jump {
1661 current_prediction.shown_with = Some(display_type);
1662 }
1663
1664 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1665
1666 if is_first_non_jump_show {
1667 current_prediction.was_shown = true;
1668 }
1669
1670 let display_type_changed = previous_shown_with != Some(display_type);
1671
1672 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1673 sweep_ai::edit_prediction_shown(
1674 &self.sweep_ai,
1675 self.client.clone(),
1676 ¤t_prediction.prediction,
1677 display_type,
1678 cx,
1679 );
1680 }
1681
1682 if is_first_non_jump_show {
1683 self.shown_predictions
1684 .push_front(current_prediction.prediction.clone());
1685 if self.shown_predictions.len() > 50 {
1686 let completion = self.shown_predictions.pop_back().unwrap();
1687 self.rated_predictions.remove(&completion.id);
1688 }
1689 }
1690 }
1691
1692 fn reject_prediction(
1693 &mut self,
1694 prediction_id: EditPredictionId,
1695 reason: EditPredictionRejectReason,
1696 was_shown: bool,
1697 model_version: Option<String>,
1698 cx: &App,
1699 ) {
1700 match self.edit_prediction_model {
1701 EditPredictionModel::Zeta => {
1702 let is_cloud = !matches!(
1703 all_language_settings(None, cx).edit_predictions.provider,
1704 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1705 );
1706
1707 if is_cloud {
1708 let organization_id = self
1709 .user_store
1710 .read(cx)
1711 .current_organization()
1712 .map(|organization| organization.id.clone());
1713
1714 self.reject_predictions_tx
1715 .unbounded_send(EditPredictionRejectionPayload {
1716 rejection: EditPredictionRejection {
1717 request_id: prediction_id.to_string(),
1718 reason,
1719 was_shown,
1720 model_version,
1721 },
1722 organization_id,
1723 })
1724 .log_err();
1725 }
1726 }
1727 EditPredictionModel::Mercury => {
1728 mercury::edit_prediction_rejected(
1729 prediction_id,
1730 was_shown,
1731 reason,
1732 self.client.http_client(),
1733 cx,
1734 );
1735 }
1736 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1737 }
1738 }
1739
1740 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1741 self.projects
1742 .get(&project.entity_id())
1743 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1744 }
1745
1746 pub fn refresh_prediction_from_buffer(
1747 &mut self,
1748 project: Entity<Project>,
1749 buffer: Entity<Buffer>,
1750 position: language::Anchor,
1751 cx: &mut Context<Self>,
1752 ) {
1753 self.queue_prediction_refresh(
1754 project.clone(),
1755 PredictEditsRequestTrigger::Other,
1756 buffer.entity_id(),
1757 cx,
1758 move |this, cx| {
1759 let Some(request_task) = this
1760 .update(cx, |this, cx| {
1761 this.request_prediction(
1762 &project,
1763 &buffer,
1764 position,
1765 PredictEditsRequestTrigger::Other,
1766 cx,
1767 )
1768 })
1769 .log_err()
1770 else {
1771 return Task::ready(anyhow::Ok(None));
1772 };
1773
1774 cx.spawn(async move |_cx| {
1775 request_task.await.map(|prediction_result| {
1776 prediction_result.map(|prediction_result| {
1777 (
1778 prediction_result,
1779 PredictionRequestedBy::Buffer(buffer.entity_id()),
1780 )
1781 })
1782 })
1783 })
1784 },
1785 )
1786 }
1787
    /// Queues a prediction request at the next diagnostic location.
    ///
    /// The request is throttled per project and tagged with the `Diagnostics`
    /// trigger.
    ///
    /// Nothing is queued unless all of these hold:
    /// - a store-backed provider is configured;
    /// - the project is already tracked;
    /// - the project has no current prediction.
    ///
    /// When the queued request runs, it:
    /// - Takes the buffer and optional cursor reported by
    ///   `ProjectState::active_buffer`, and skips if predictions are disabled
    ///   there.
    /// - Looks up the next diagnostic location via `next_diagnostic_location`.
    ///   For `DiagnosticSearchScope::Local`, the search range spans
    ///   `DIAGNOSTIC_LINES_RANGE` rows around the cursor. For `Global`, it is
    ///   `Default::default()`, presumably meaning "unrestricted" to that
    ///   helper (TODO confirm).
    /// - Requests a prediction at that location. If a current prediction
    ///   appeared in the meantime, the result is turned into a
    ///   `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a cursor position, fall back to `Point::default()`.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction may have become current while this request was in
                    // flight. Keep it, and mark this result as a `CurrentPreferred`
                    // rejection.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1899
1900 fn predictions_enabled_at(
1901 snapshot: &BufferSnapshot,
1902 position: Option<language::Anchor>,
1903 cx: &App,
1904 ) -> bool {
1905 let file = snapshot.file();
1906 let all_settings = all_language_settings(file, cx);
1907 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1908 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1909 {
1910 return false;
1911 }
1912
1913 if let Some(last_position) = position {
1914 let settings = snapshot.settings_at(last_position, cx);
1915
1916 if !settings.edit_predictions_disabled_in.is_empty()
1917 && let Some(scope) = snapshot.language_scope_at(last_position)
1918 && let Some(scope_name) = scope.override_name()
1919 && settings
1920 .edit_predictions_disabled_in
1921 .iter()
1922 .any(|s| s == scope_name)
1923 {
1924 return false;
1925 }
1926 }
1927
1928 true
1929 }
1930
    /// Minimum delay between two prediction requests for the same throttle
    /// entity. Diagnostics-triggered and other requests are throttled
    /// independently (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1932}
1933
1934fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1935 match provider {
1936 EditPredictionProvider::Zed
1937 | EditPredictionProvider::Sweep
1938 | EditPredictionProvider::Mercury
1939 | EditPredictionProvider::Ollama
1940 | EditPredictionProvider::OpenAiCompatibleApi
1941 | EditPredictionProvider::Experimental(_) => true,
1942 EditPredictionProvider::None
1943 | EditPredictionProvider::Copilot
1944 | EditPredictionProvider::Codestral => false,
1945 }
1946}
1947
1948impl EditPredictionStore {
1949 fn queue_prediction_refresh(
1950 &mut self,
1951 project: Entity<Project>,
1952 request_trigger: PredictEditsRequestTrigger,
1953 throttle_entity: EntityId,
1954 cx: &mut Context<Self>,
1955 do_refresh: impl FnOnce(
1956 WeakEntity<Self>,
1957 &mut AsyncApp,
1958 )
1959 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1960 + 'static,
1961 ) {
1962 fn select_throttle(
1963 project_state: &mut ProjectState,
1964 request_trigger: PredictEditsRequestTrigger,
1965 ) -> &mut Option<(EntityId, Instant)> {
1966 match request_trigger {
1967 PredictEditsRequestTrigger::Diagnostics => {
1968 &mut project_state.last_jump_prediction_refresh
1969 }
1970 _ => &mut project_state.last_edit_prediction_refresh,
1971 }
1972 }
1973
1974 let (needs_acceptance_tracking, max_pending_predictions) =
1975 match all_language_settings(None, cx).edit_predictions.provider {
1976 EditPredictionProvider::Zed
1977 | EditPredictionProvider::Sweep
1978 | EditPredictionProvider::Mercury
1979 | EditPredictionProvider::Experimental(_) => (true, 2),
1980 EditPredictionProvider::Ollama => (false, 1),
1981 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1982 EditPredictionProvider::None
1983 | EditPredictionProvider::Copilot
1984 | EditPredictionProvider::Codestral => {
1985 log::error!("queue_prediction_refresh called with non-store provider");
1986 return;
1987 }
1988 };
1989
1990 let drop_on_cancel = !needs_acceptance_tracking;
1991 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1992 let project_state = self.get_or_init_project(&project, cx);
1993 let pending_prediction_id = project_state.next_pending_prediction_id;
1994 project_state.next_pending_prediction_id += 1;
1995 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
1996
1997 let task = cx.spawn(async move |this, cx| {
1998 let throttle_wait = this
1999 .update(cx, |this, cx| {
2000 let project_state = this.get_or_init_project(&project, cx);
2001 let throttle = *select_throttle(project_state, request_trigger);
2002
2003 throttle.and_then(|(last_entity, last_timestamp)| {
2004 if throttle_entity != last_entity {
2005 return None;
2006 }
2007 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2008 })
2009 })
2010 .ok()
2011 .flatten();
2012
2013 if let Some(timeout) = throttle_wait {
2014 cx.background_executor().timer(timeout).await;
2015 }
2016
2017 // If this task was cancelled before the throttle timeout expired,
2018 // do not perform a request. Also skip if another task already
2019 // proceeded since we were enqueued (duplicate).
2020 let mut is_cancelled = true;
2021 this.update(cx, |this, cx| {
2022 let project_state = this.get_or_init_project(&project, cx);
2023 let was_cancelled = project_state
2024 .cancelled_predictions
2025 .remove(&pending_prediction_id);
2026 if was_cancelled {
2027 return;
2028 }
2029
2030 // Another request has been already sent since this was enqueued
2031 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2032 return;
2033 }
2034
2035 let new_refresh = (throttle_entity, Instant::now());
2036 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2037 is_cancelled = false;
2038 })
2039 .ok();
2040 if is_cancelled {
2041 return None;
2042 }
2043
2044 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2045 let new_prediction_id = new_prediction_result
2046 .as_ref()
2047 .map(|(prediction, _)| prediction.id.clone());
2048
2049 // When a prediction completes, remove it from the pending list, and cancel
2050 // any pending predictions that were enqueued before it.
2051 this.update(cx, |this, cx| {
2052 let project_state = this.get_or_init_project(&project, cx);
2053
2054 let is_cancelled = project_state
2055 .cancelled_predictions
2056 .remove(&pending_prediction_id);
2057
2058 let new_current_prediction = if !is_cancelled
2059 && let Some((prediction_result, requested_by)) = new_prediction_result
2060 {
2061 match prediction_result.prediction {
2062 Ok(prediction) => {
2063 let new_prediction = CurrentEditPrediction {
2064 requested_by,
2065 prediction,
2066 was_shown: false,
2067 shown_with: None,
2068 };
2069
2070 if let Some(current_prediction) =
2071 project_state.current_prediction.as_ref()
2072 {
2073 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2074 {
2075 this.reject_current_prediction(
2076 EditPredictionRejectReason::Replaced,
2077 &project,
2078 cx,
2079 );
2080
2081 Some(new_prediction)
2082 } else {
2083 this.reject_prediction(
2084 new_prediction.prediction.id,
2085 EditPredictionRejectReason::CurrentPreferred,
2086 false,
2087 new_prediction.prediction.model_version,
2088 cx,
2089 );
2090 None
2091 }
2092 } else {
2093 Some(new_prediction)
2094 }
2095 }
2096 Err(reject_reason) => {
2097 this.reject_prediction(
2098 prediction_result.id,
2099 reject_reason,
2100 false,
2101 None,
2102 cx,
2103 );
2104 None
2105 }
2106 }
2107 } else {
2108 None
2109 };
2110
2111 let project_state = this.get_or_init_project(&project, cx);
2112
2113 if let Some(new_prediction) = new_current_prediction {
2114 project_state.current_prediction = Some(new_prediction);
2115 }
2116
2117 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2118 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2119 if pending_prediction.id == pending_prediction_id {
2120 pending_predictions.remove(ix);
2121 for pending_prediction in pending_predictions.drain(0..ix) {
2122 project_state.cancel_pending_prediction(pending_prediction, cx)
2123 }
2124 break;
2125 }
2126 }
2127 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2128 cx.notify();
2129 })
2130 .ok();
2131
2132 new_prediction_id
2133 });
2134
2135 if project_state.pending_predictions.len() < max_pending_predictions {
2136 project_state.pending_predictions.push(PendingPrediction {
2137 id: pending_prediction_id,
2138 task,
2139 drop_on_cancel,
2140 });
2141 } else {
2142 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2143 project_state.pending_predictions.push(PendingPrediction {
2144 id: pending_prediction_id,
2145 task,
2146 drop_on_cancel,
2147 });
2148 project_state.cancel_pending_prediction(pending_prediction, cx);
2149 }
2150 }
2151
2152 pub fn request_prediction(
2153 &mut self,
2154 project: &Entity<Project>,
2155 active_buffer: &Entity<Buffer>,
2156 position: language::Anchor,
2157 trigger: PredictEditsRequestTrigger,
2158 cx: &mut Context<Self>,
2159 ) -> Task<Result<Option<EditPredictionResult>>> {
2160 self.request_prediction_internal(
2161 project.clone(),
2162 active_buffer.clone(),
2163 position,
2164 trigger,
2165 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2166 cx,
2167 )
2168 }
2169
2170 fn request_prediction_internal(
2171 &mut self,
2172 project: Entity<Project>,
2173 active_buffer: Entity<Buffer>,
2174 position: language::Anchor,
2175 trigger: PredictEditsRequestTrigger,
2176 allow_jump: bool,
2177 cx: &mut Context<Self>,
2178 ) -> Task<Result<Option<EditPredictionResult>>> {
2179 self.get_or_init_project(&project, cx);
2180 let project_state = self.projects.get(&project.entity_id()).unwrap();
2181 let stored_events = project_state.events(cx);
2182 let has_events = !stored_events.is_empty();
2183 let events: Vec<Arc<zeta_prompt::Event>> =
2184 stored_events.iter().map(|e| e.event.clone()).collect();
2185 let debug_tx = project_state.debug_tx.clone();
2186
2187 let snapshot = active_buffer.read(cx).snapshot();
2188 let cursor_point = position.to_point(&snapshot);
2189 let current_offset = position.to_offset(&snapshot);
2190
2191 let mut user_actions: Vec<UserActionRecord> =
2192 project_state.user_actions.iter().cloned().collect();
2193
2194 if let Some(last_action) = user_actions.last() {
2195 if last_action.buffer_id == active_buffer.entity_id()
2196 && current_offset != last_action.offset
2197 {
2198 let timestamp_epoch_ms = SystemTime::now()
2199 .duration_since(UNIX_EPOCH)
2200 .map(|d| d.as_millis() as u64)
2201 .unwrap_or(0);
2202 user_actions.push(UserActionRecord {
2203 action_type: UserActionType::CursorMovement,
2204 buffer_id: active_buffer.entity_id(),
2205 line_number: cursor_point.row,
2206 offset: current_offset,
2207 timestamp_epoch_ms,
2208 });
2209 }
2210 }
2211 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2212 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2213 let diagnostic_search_range =
2214 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2215
2216 let related_files = self.context_for_project(&project, cx);
2217
2218 let is_open_source = snapshot
2219 .file()
2220 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2221 && events.iter().all(|event| event.in_open_source_repo())
2222 && related_files.iter().all(|file| file.in_open_source_repo);
2223
2224 let can_collect_data = !cfg!(test)
2225 && is_open_source
2226 && self.is_data_collection_enabled(cx)
2227 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2228
2229 let inputs = EditPredictionModelInput {
2230 project: project.clone(),
2231 buffer: active_buffer.clone(),
2232 snapshot: snapshot,
2233 position,
2234 events,
2235 related_files,
2236 recent_paths: project_state.recent_paths.clone(),
2237 trigger,
2238 diagnostic_search_range: diagnostic_search_range,
2239 debug_tx,
2240 user_actions,
2241 can_collect_data,
2242 is_open_source,
2243 };
2244
2245 if can_collect_data && rand::random_ratio(1, 1000) {
2246 if let Some(task) = capture_example(
2247 project.clone(),
2248 active_buffer,
2249 position,
2250 stored_events,
2251 false,
2252 cx,
2253 ) {
2254 task.detach();
2255 }
2256 }
2257
2258 let task = match self.edit_prediction_model {
2259 EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx),
2260 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2261 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2262 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2263 };
2264
2265 cx.spawn(async move |this, cx| {
2266 let prediction = task.await?;
2267
2268 // Only fall back to diagnostics-based prediction if we got a
2269 // the model had nothing to suggest for the buffer
2270 if prediction.is_none()
2271 && allow_jump
2272 && has_events
2273 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2274 {
2275 this.update(cx, |this, cx| {
2276 this.refresh_prediction_from_diagnostics(
2277 project,
2278 DiagnosticSearchScope::Local,
2279 cx,
2280 );
2281 })?;
2282 return anyhow::Ok(None);
2283 }
2284
2285 Ok(prediction)
2286 })
2287 }
2288
2289 pub(crate) async fn next_diagnostic_location(
2290 active_buffer: Entity<Buffer>,
2291 active_buffer_snapshot: &BufferSnapshot,
2292 active_buffer_diagnostic_search_range: Range<Point>,
2293 active_buffer_cursor_point: Point,
2294 project: &Entity<Project>,
2295 cx: &mut AsyncApp,
2296 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2297 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2298 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2299 .flat_map(|(_, _, _, selections)| {
2300 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2301 })
2302 .collect();
2303
2304 let mut jump_location = active_buffer_snapshot
2305 .diagnostic_groups(None)
2306 .into_iter()
2307 .filter_map(|(_, group)| {
2308 let range = &group.entries[group.primary_ix]
2309 .range
2310 .to_point(&active_buffer_snapshot);
2311 if range.overlaps(&active_buffer_diagnostic_search_range) {
2312 return None;
2313 }
2314 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2315 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2316 });
2317 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2318 <= DIAGNOSTIC_LINES_RANGE;
2319 if near_collaborator && !near_local {
2320 return None;
2321 }
2322 Some(range.start)
2323 })
2324 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2325 .map(|position| {
2326 (
2327 active_buffer.clone(),
2328 active_buffer_snapshot.anchor_before(position),
2329 )
2330 });
2331
2332 if jump_location.is_none() {
2333 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2334 let file = buffer.file()?;
2335
2336 Some(ProjectPath {
2337 worktree_id: file.worktree_id(cx),
2338 path: file.path().clone(),
2339 })
2340 });
2341
2342 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2343 project
2344 .diagnostic_summaries(false, cx)
2345 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2346 .map(|(path, _, _)| {
2347 let shared_prefix = path
2348 .path
2349 .components()
2350 .zip(
2351 active_buffer_path
2352 .as_ref()
2353 .map(|p| p.path.components())
2354 .unwrap_or_default(),
2355 )
2356 .take_while(|(a, b)| a == b)
2357 .count();
2358 (path, shared_prefix)
2359 })
2360 .collect()
2361 });
2362
2363 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2364
2365 for (path, _) in candidates {
2366 let candidate_buffer = project
2367 .update(cx, |project, cx| project.open_buffer(path, cx))
2368 .await?;
2369
2370 let (has_collaborators, diagnostic_position) =
2371 candidate_buffer.read_with(cx, |buffer, _cx| {
2372 let snapshot = buffer.snapshot();
2373 let has_collaborators = snapshot
2374 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2375 .next()
2376 .is_some();
2377 let position = buffer
2378 .buffer_diagnostics(None)
2379 .into_iter()
2380 .min_by_key(|entry| entry.diagnostic.severity)
2381 .map(|entry| entry.range.start);
2382 (has_collaborators, position)
2383 });
2384
2385 if has_collaborators {
2386 continue;
2387 }
2388
2389 if let Some(position) = diagnostic_position {
2390 jump_location = Some((candidate_buffer, position));
2391 break;
2392 }
2393 }
2394 }
2395
2396 anyhow::Ok(jump_location)
2397 }
2398
2399 async fn send_raw_llm_request(
2400 request: RawCompletionRequest,
2401 client: Arc<Client>,
2402 custom_url: Option<Arc<Url>>,
2403 llm_token: LlmApiToken,
2404 organization_id: Option<OrganizationId>,
2405 app_version: Version,
2406 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2407 let url = if let Some(custom_url) = custom_url {
2408 custom_url.as_ref().clone()
2409 } else {
2410 client
2411 .http_client()
2412 .build_zed_llm_url("/predict_edits/raw", &[])?
2413 };
2414
2415 Self::send_api_request(
2416 |builder| {
2417 let req = builder
2418 .uri(url.as_ref())
2419 .body(serde_json::to_string(&request)?.into());
2420 Ok(req?)
2421 },
2422 client,
2423 llm_token,
2424 organization_id,
2425 app_version,
2426 true,
2427 )
2428 .await
2429 }
2430
2431 pub(crate) async fn send_v3_request(
2432 input: ZetaPromptInput,
2433 client: Arc<Client>,
2434 llm_token: LlmApiToken,
2435 organization_id: Option<OrganizationId>,
2436 app_version: Version,
2437 trigger: PredictEditsRequestTrigger,
2438 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2439 let url = client
2440 .http_client()
2441 .build_zed_llm_url("/predict_edits/v3", &[])?;
2442
2443 let request = PredictEditsV3Request { input, trigger };
2444
2445 let json_bytes = serde_json::to_vec(&request)?;
2446 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2447
2448 Self::send_api_request(
2449 |builder| {
2450 let req = builder
2451 .uri(url.as_ref())
2452 .header("Content-Encoding", "zstd")
2453 .body(compressed.clone().into());
2454 Ok(req?)
2455 },
2456 client,
2457 llm_token,
2458 organization_id,
2459 app_version,
2460 true,
2461 )
2462 .await
2463 }
2464
2465 fn handle_api_response<T>(
2466 this: &WeakEntity<Self>,
2467 response: Result<(T, Option<EditPredictionUsage>)>,
2468 cx: &mut gpui::AsyncApp,
2469 ) -> Result<T> {
2470 match response {
2471 Ok((data, usage)) => {
2472 if let Some(usage) = usage {
2473 this.update(cx, |this, cx| {
2474 this.user_store.update(cx, |user_store, cx| {
2475 user_store.update_edit_prediction_usage(usage, cx);
2476 });
2477 })
2478 .ok();
2479 }
2480 Ok(data)
2481 }
2482 Err(err) => {
2483 if err.is::<ZedUpdateRequiredError>() {
2484 cx.update(|cx| {
2485 this.update(cx, |this, _cx| {
2486 this.update_required = true;
2487 })
2488 .ok();
2489
2490 let error_message: SharedString = err.to_string().into();
2491 show_app_notification(
2492 NotificationId::unique::<ZedUpdateRequiredError>(),
2493 cx,
2494 move |cx| {
2495 cx.new(|cx| {
2496 ErrorMessagePrompt::new(error_message.clone(), cx)
2497 .with_link_button("Update Zed", "https://zed.dev/releases")
2498 })
2499 },
2500 );
2501 });
2502 }
2503 Err(err)
2504 }
2505 }
2506 }
2507
2508 async fn send_api_request<Res>(
2509 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2510 client: Arc<Client>,
2511 llm_token: LlmApiToken,
2512 organization_id: Option<OrganizationId>,
2513 app_version: Version,
2514 require_auth: bool,
2515 ) -> Result<(Res, Option<EditPredictionUsage>)>
2516 where
2517 Res: DeserializeOwned,
2518 {
2519 let http_client = client.http_client();
2520
2521 let mut token = if require_auth {
2522 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2523 } else {
2524 llm_token
2525 .acquire(&client, organization_id.clone())
2526 .await
2527 .ok()
2528 };
2529 let mut did_retry = false;
2530
2531 loop {
2532 let request_builder = http_client::Request::builder().method(Method::POST);
2533
2534 let mut request_builder = request_builder
2535 .header("Content-Type", "application/json")
2536 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2537
2538 // Only add Authorization header if we have a token
2539 if let Some(ref token_value) = token {
2540 request_builder =
2541 request_builder.header("Authorization", format!("Bearer {}", token_value));
2542 }
2543
2544 let request = build(request_builder)?;
2545
2546 let mut response = http_client.send(request).await?;
2547
2548 if let Some(minimum_required_version) = response
2549 .headers()
2550 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2551 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2552 {
2553 anyhow::ensure!(
2554 app_version >= minimum_required_version,
2555 ZedUpdateRequiredError {
2556 minimum_version: minimum_required_version
2557 }
2558 );
2559 }
2560
2561 if response.status().is_success() {
2562 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2563
2564 let mut body = Vec::new();
2565 response.body_mut().read_to_end(&mut body).await?;
2566 return Ok((serde_json::from_slice(&body)?, usage));
2567 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2568 did_retry = true;
2569 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2570 } else {
2571 let mut body = String::new();
2572 response.body_mut().read_to_string(&mut body).await?;
2573 anyhow::bail!(
2574 "Request failed with status: {:?}\nBody: {}",
2575 response.status(),
2576 body
2577 );
2578 }
2579 }
2580 }
2581
2582 pub fn refresh_context(
2583 &mut self,
2584 project: &Entity<Project>,
2585 buffer: &Entity<language::Buffer>,
2586 cursor_position: language::Anchor,
2587 cx: &mut Context<Self>,
2588 ) {
2589 self.get_or_init_project(project, cx)
2590 .context
2591 .update(cx, |store, cx| {
2592 store.refresh(buffer.clone(), cursor_position, cx);
2593 });
2594 }
2595
2596 #[cfg(feature = "cli-support")]
2597 pub fn set_context_for_buffer(
2598 &mut self,
2599 project: &Entity<Project>,
2600 related_files: Vec<RelatedFile>,
2601 cx: &mut Context<Self>,
2602 ) {
2603 self.get_or_init_project(project, cx)
2604 .context
2605 .update(cx, |store, cx| {
2606 store.set_related_files(related_files, cx);
2607 });
2608 }
2609
2610 #[cfg(feature = "cli-support")]
2611 pub fn set_recent_paths_for_project(
2612 &mut self,
2613 project: &Entity<Project>,
2614 paths: impl IntoIterator<Item = project::ProjectPath>,
2615 cx: &mut Context<Self>,
2616 ) {
2617 let project_state = self.get_or_init_project(project, cx);
2618 project_state.recent_paths = paths.into_iter().collect();
2619 }
2620
2621 fn is_file_open_source(
2622 &self,
2623 project: &Entity<Project>,
2624 file: &Arc<dyn File>,
2625 cx: &App,
2626 ) -> bool {
2627 if !file.is_local() || file.is_private() {
2628 return false;
2629 }
2630 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2631 return false;
2632 };
2633 project_state
2634 .license_detection_watchers
2635 .get(&file.worktree_id(cx))
2636 .as_ref()
2637 .is_some_and(|watcher| watcher.is_project_open_source())
2638 }
2639
2640 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2641 self.data_collection_choice.is_enabled(cx)
2642 }
2643
2644 fn load_data_collection_choice() -> DataCollectionChoice {
2645 let choice = KEY_VALUE_STORE
2646 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2647 .log_err()
2648 .flatten();
2649
2650 match choice.as_deref() {
2651 Some("true") => DataCollectionChoice::Enabled,
2652 Some("false") => DataCollectionChoice::Disabled,
2653 Some(_) => {
2654 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2655 DataCollectionChoice::NotAnswered
2656 }
2657 None => DataCollectionChoice::NotAnswered,
2658 }
2659 }
2660
2661 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2662 self.data_collection_choice = self.data_collection_choice.toggle();
2663 let new_choice = self.data_collection_choice;
2664 let is_enabled = new_choice.is_enabled(cx);
2665 db::write_and_log(cx, move || {
2666 KEY_VALUE_STORE.write_kvp(
2667 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2668 is_enabled.to_string(),
2669 )
2670 });
2671 }
2672
2673 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2674 self.shown_predictions.iter()
2675 }
2676
2677 pub fn shown_completions_len(&self) -> usize {
2678 self.shown_predictions.len()
2679 }
2680
2681 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2682 self.rated_predictions.contains(id)
2683 }
2684
2685 pub fn rate_prediction(
2686 &mut self,
2687 prediction: &EditPrediction,
2688 rating: EditPredictionRating,
2689 feedback: String,
2690 cx: &mut Context<Self>,
2691 ) {
2692 let organization = self.user_store.read(cx).current_organization();
2693
2694 self.rated_predictions.insert(prediction.id.clone());
2695
2696 cx.background_spawn({
2697 let client = self.client.clone();
2698 let prediction_id = prediction.id.to_string();
2699 let inputs = serde_json::to_value(&prediction.inputs);
2700 let output = prediction
2701 .edit_preview
2702 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2703 async move {
2704 client
2705 .cloud_client()
2706 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2707 organization_id: organization.map(|organization| organization.id.clone()),
2708 request_id: prediction_id,
2709 rating: match rating {
2710 EditPredictionRating::Positive => "positive".to_string(),
2711 EditPredictionRating::Negative => "negative".to_string(),
2712 },
2713 inputs: inputs?,
2714 output,
2715 feedback,
2716 })
2717 .await?;
2718
2719 anyhow::Ok(())
2720 }
2721 })
2722 .detach_and_log_err(cx);
2723
2724 cx.notify();
2725 }
2726}
2727
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Nothing happens in these cases:
/// - the newest event's `old_snapshot` has a different `remote_id()` than `latest_snapshot`;
/// - fewer than two trailing events form a mergeable run;
/// - no diff can be computed between the run's oldest `old_snapshot` and `end_snapshot`.
///
/// Otherwise the events are walked from newest to oldest. The run grows while each older event
/// `can_merge` with the event that follows it, given `latest_snapshot` and `latest_edit_range`.
/// The run is then replaced by one `BufferChange`:
/// - its paths and `in_open_source_repo` come from the oldest event in the run;
/// - its diff spans from the oldest event's `old_snapshot` to `end_snapshot`;
/// - its `edit_range` is anchored in `end_snapshot`;
/// - it is marked `predicted` only if every merged event was predicted.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge a trailing run whose newest event matches the snapshot being edited now.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count how many events, from the back, can be merged with their successor.
    // `next_old_event` holds the previously visited (i.e. chronologically later) event.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A run of zero or one event has nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator, so the `all` check below includes it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every part of it was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Drop the merged run and append the single combined event in its place.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2794
2795pub(crate) fn filter_redundant_excerpts(
2796 mut related_files: Vec<RelatedFile>,
2797 cursor_path: &Path,
2798 cursor_row_range: Range<u32>,
2799) -> Vec<RelatedFile> {
2800 for file in &mut related_files {
2801 if file.path.as_ref() == cursor_path {
2802 file.excerpts.retain(|excerpt| {
2803 excerpt.row_range.start < cursor_row_range.start
2804 || excerpt.row_range.end > cursor_row_range.end
2805 });
2806 }
2807 }
2808 related_files.retain(|file| !file.excerpts.is_empty());
2809 related_files
2810}
2811
/// Error produced when a server response advertises a minimum required Zed version that is
/// newer than the running app's version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version reported by the server.
    minimum_version: Version,
}
2819
2820#[derive(Debug, Clone, Copy)]
2821pub enum DataCollectionChoice {
2822 NotAnswered,
2823 Enabled,
2824 Disabled,
2825}
2826
2827impl DataCollectionChoice {
2828 pub fn is_enabled(self, cx: &App) -> bool {
2829 if cx.is_staff() {
2830 return true;
2831 }
2832 match self {
2833 Self::Enabled => true,
2834 Self::NotAnswered | Self::Disabled => false,
2835 }
2836 }
2837
2838 #[must_use]
2839 pub fn toggle(&self) -> DataCollectionChoice {
2840 match self {
2841 Self::Enabled => Self::Disabled,
2842 Self::Disabled => Self::Enabled,
2843 Self::NotAnswered => Self::Enabled,
2844 }
2845 }
2846}
2847
2848impl From<bool> for DataCollectionChoice {
2849 fn from(value: bool) -> Self {
2850 match value {
2851 true => DataCollectionChoice::Enabled,
2852 false => DataCollectionChoice::Disabled,
2853 }
2854 }
2855}
2856
2857struct ZedPredictUpsell;
2858
2859impl Dismissable for ZedPredictUpsell {
2860 const KEY: &'static str = "dismissed-edit-predict-upsell";
2861
2862 fn dismissed() -> bool {
2863 // To make this backwards compatible with older versions of Zed, we
2864 // check if the user has seen the previous Edit Prediction Onboarding
2865 // before, by checking the data collection choice which was written to
2866 // the database once the user clicked on "Accept and Enable"
2867 if KEY_VALUE_STORE
2868 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2869 .log_err()
2870 .is_some_and(|s| s.is_some())
2871 {
2872 return true;
2873 }
2874
2875 KEY_VALUE_STORE
2876 .read_kvp(Self::KEY)
2877 .log_err()
2878 .is_some_and(|s| s.is_some())
2879 }
2880}
2881
2882pub fn should_show_upsell_modal() -> bool {
2883 !ZedPredictUpsell::dismissed()
2884}
2885
2886pub fn init(cx: &mut App) {
2887 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2888 workspace.register_action(
2889 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2890 ZedPredictModal::toggle(
2891 workspace,
2892 workspace.user_store().clone(),
2893 workspace.client().clone(),
2894 window,
2895 cx,
2896 )
2897 },
2898 );
2899
2900 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2901 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2902 settings
2903 .project
2904 .all_languages
2905 .edit_predictions
2906 .get_or_insert_default()
2907 .provider = Some(EditPredictionProvider::None)
2908 });
2909 });
2910 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2911 EditPredictionStore::try_global(cx).and_then(|store| {
2912 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2913 })
2914 }
2915
2916 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2917 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2918 copilot_ui::initiate_sign_in(copilot, window, cx);
2919 }
2920 });
2921 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2922 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2923 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2924 }
2925 });
2926 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2927 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2928 copilot_ui::initiate_sign_out(copilot, window, cx);
2929 }
2930 });
2931 })
2932 .detach();
2933}