1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72pub mod open_ai_compatible;
73mod zed_edit_prediction_delegate;
74pub mod zeta;
75
76#[cfg(test)]
77mod edit_prediction_tests;
78
79use crate::license_detection::LicenseDetectionWatcher;
80use crate::mercury::Mercury;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Declares the `ResetOnboarding` and `ClearHistory` actions under the
// `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance at which two edit ranges count as "close"; used by
/// `StoredEvent::can_merge` both against the latest edit and between the two
/// candidate events.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window for grouping consecutive changes.
/// NOTE(review): not referenced in the visible part of this file — confirm usage.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the data-collection choice is persisted (presumably in the
/// key-value store — confirm against `load_data_collection_choice`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions (assumed — confirm in
/// `handle_rejected_predictions`).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the event reported when a prediction settles (presumably telemetry).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Upper bound on how long a prediction is tracked for settling (assumed — confirm).
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Quiet period after which a prediction counts as settled (assumed — confirm).
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
110
/// Feature flag named `edit_prediction_jumps`.
/// NOTE(review): presumably gates predictions away from the cursor (see
/// `BufferEditPrediction::Jump`) — confirm.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
116
/// App-global wrapper around the shared `EditPredictionStore` entity; installed by
/// `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): this struct has no `version` field; the sentence above may predate
/// the separate `environment` field — confirm and update.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds this from the
/// `ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` environment
/// variables.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier (`ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Environment name (`ZED_ZETA_ENVIRONMENT` when loaded from the environment).
    pub environment: Option<String>,
    /// Zeta format (parsed from `ZED_ZETA_FORMAT` when loaded from the environment).
    pub format: ZetaFormat,
}
131
/// Central store for edit predictions. Holds per-project state, the selected
/// prediction model and its configuration, and channels to the background workers
/// spawned in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Task that refreshes the available experiments once a current user exists.
    _fetch_experiments_task: Task<()>,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// NOTE(review): presumably set when the server demands a newer client
    /// version — confirm.
    update_required: bool,
    /// Model used for predictions; a new store starts with `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Experiment names fetched by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background rejection handler spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the settled-predictions worker spawned in `new`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook (presumably invoked when a prediction settles — confirm).
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
154
/// A rejection queued on `EditPredictionStore::reject_predictions_tx`.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    /// Organization associated with the rejection, if any (presumably used to
    /// scope the LLM token when reporting — confirm).
    organization_id: Option<OrganizationId>,
}
159
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model; the default for a newly created store.
    Zeta,
    /// Fill-in-the-middle prompting using `format`. `icons` picks the icon from the
    /// edit-prediction provider in language settings (Ollama or OpenAI-compatible).
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury (displayed with the Inception icon).
    Mercury,
}
167
/// Inputs handed to a model for a single prediction request (presumably assembled
/// by the store per request — confirm at call sites).
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the prediction is computed against.
    snapshot: BufferSnapshot,
    /// Request position within `snapshot` (presumably the cursor).
    position: Anchor,
    /// Recent edit history.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts retrieved as extra context.
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered this request.
    trigger: PredictEditsRequestTrigger,
    /// Rows searched for diagnostics (assumed; see `DIAGNOSTIC_LINES_RANGE`).
    diagnostic_search_range: Range<Point>,
    /// Sink for debug events, when a debug listener is attached.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
184
/// Diagnostic events sent over a project's `debug_tx` when a debug listener is
/// attached.
#[derive(Debug)]
pub enum DebugEvent {
    /// Context (related excerpt) retrieval began.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Context retrieval completed.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    /// A prediction request began.
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    /// A prediction request completed.
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
192
/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Retrieval query; `handle_excerpt_store_event` currently sends it empty.
    pub search_prompt: String,
}

/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable `(label, value)` metrics such as "Cache Hits".
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
220
/// Maximum number of entries kept by `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// Row of the action (presumably zero-based — confirm).
    pub line_number: u32,
    /// Offset of the action in the buffer (assumed byte offset — confirm).
    pub offset: usize,
    /// Time of the action, presumably milliseconds since the Unix epoch (per the name).
    pub timestamp_epoch_ms: u64,
}
231
/// Kind of action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
240
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event (e.g. a buffer change carrying its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Range touched by the change, anchored in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
248
249impl StoredEvent {
250 fn can_merge(
251 &self,
252 next_old_event: &&&StoredEvent,
253 new_snapshot: &TextBufferSnapshot,
254 last_edit_range: &Range<Anchor>,
255 ) -> bool {
256 // Events must be for the same buffer
257 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
258 return false;
259 }
260 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
261 return false;
262 }
263
264 let a_is_predicted = matches!(
265 self.event.as_ref(),
266 zeta_prompt::Event::BufferChange {
267 predicted: true,
268 ..
269 }
270 );
271 let b_is_predicted = matches!(
272 next_old_event.event.as_ref(),
273 zeta_prompt::Event::BufferChange {
274 predicted: true,
275 ..
276 }
277 );
278
279 // If events come from the same source (both predicted or both manual) then
280 // we would have coalesced them already.
281 if a_is_predicted == b_is_predicted {
282 return false;
283 }
284
285 let left_range = self.edit_range.to_point(new_snapshot);
286 let right_range = next_old_event.edit_range.to_point(new_snapshot);
287 let latest_range = last_edit_range.to_point(&new_snapshot);
288
289 // Events near to the latest edit are not merged if their sources differ.
290 if lines_between_ranges(&left_range, &latest_range)
291 .min(lines_between_ranges(&right_range, &latest_range))
292 <= CHANGE_GROUPING_LINE_SPAN
293 {
294 return false;
295 }
296
297 // Events that are distant from each other are not merged.
298 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
299 return false;
300 }
301
302 true
303 }
304}
305
306fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
307 if left.start > right.end {
308 return left.start.row - right.end.row;
309 }
310 if right.start > left.end {
311 return right.start.row - left.end.row;
312 }
313 0
314}
315
/// Per-project edit prediction state, stored in `EditPredictionStore::projects`
/// keyed by the project's entity id.
struct ProjectState {
    /// Finalized edit events, oldest first (presumably capped at `EVENT_COUNT_MAX`).
    events: VecDeque<StoredEvent>,
    /// In-progress event not yet moved into `events`; `events()` finalizes it on
    /// the fly.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with the store, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending prediction (presumably incremented per request).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Sink for debug events, when a debug listener is attached.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt retrieval for this project.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers, used to decide whether files are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (see `record_user_action`).
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project-release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot started via `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
334
335impl ProjectState {
336 fn record_user_action(&mut self, action: UserActionRecord) {
337 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
338 self.user_actions.pop_front();
339 }
340 self.user_actions.push_back(action);
341 }
342
343 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
344 self.events
345 .iter()
346 .cloned()
347 .chain(self.last_event.as_ref().iter().flat_map(|event| {
348 let (one, two) = event.split_by_pause();
349 let one = one.finalize(&self.license_detection_watchers, cx);
350 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
351 one.into_iter().chain(two)
352 }))
353 .collect()
354 }
355
356 fn cancel_pending_prediction(
357 &mut self,
358 pending_prediction: PendingPrediction,
359 cx: &mut Context<EditPredictionStore>,
360 ) {
361 self.cancelled_predictions.insert(pending_prediction.id);
362
363 if pending_prediction.drop_on_cancel {
364 drop(pending_prediction.task);
365 } else {
366 cx.spawn(async move |this, cx| {
367 let Some(prediction_id) = pending_prediction.task.await else {
368 return;
369 };
370
371 this.update(cx, |this, cx| {
372 this.reject_prediction(
373 prediction_id,
374 EditPredictionRejectReason::Canceled,
375 false,
376 None,
377 cx,
378 );
379 })
380 .ok();
381 })
382 .detach()
383 }
384 }
385
386 fn active_buffer(
387 &self,
388 project: &Entity<Project>,
389 cx: &App,
390 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
391 let project = project.read(cx);
392 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
393 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
394 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
395 Some((active_buffer, registered_buffer.last_position))
396 }
397}
398
/// The prediction currently offered for a project, plus how it was requested and
/// whether/how it was displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed.
    pub was_shown: bool,
    /// How it was displayed, if known.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
406
407impl CurrentEditPrediction {
408 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
409 let Some(new_edits) = self
410 .prediction
411 .interpolate(&self.prediction.buffer.read(cx))
412 else {
413 return false;
414 };
415
416 if self.prediction.buffer != old_prediction.prediction.buffer {
417 return true;
418 }
419
420 let Some(old_edits) = old_prediction
421 .prediction
422 .interpolate(&old_prediction.prediction.buffer.read(cx))
423 else {
424 return true;
425 };
426
427 let requested_by_buffer_id = self.requested_by.buffer_id();
428
429 // This reduces the occurrence of UI thrash from replacing edits
430 //
431 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
432 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
433 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
434 && old_edits.len() == 1
435 && new_edits.len() == 1
436 {
437 let (old_range, old_text) = &old_edits[0];
438 let (new_range, new_text) = &new_edits[0];
439 new_range == old_range && new_text.starts_with(old_text.as_ref())
440 } else {
441 true
442 }
443 }
444}
445
/// Origin of a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than a specific buffer.
    DiagnosticsUpdate,
    /// Triggered from the buffer with this entity id.
    Buffer(EntityId),
}
451
452impl PredictionRequestedBy {
453 pub fn buffer_id(&self) -> Option<EntityId> {
454 match self {
455 PredictionRequestedBy::DiagnosticsUpdate => None,
456 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
457 }
458 }
459}
460
/// Number of rows used for diagnostic searches (assumed: rows around the cursor —
/// confirm at the use site).
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
/// NOTE(review): presumably `Local` is near the cursor and `Global` is project-wide —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
468
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier; inserted into `cancelled_predictions` when cancelled.
    id: usize,
    /// Resolves to the produced prediction's id (`None` presumably means no
    /// prediction was produced).
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
477
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies in this buffer (presumably at/near the cursor).
    Local { prediction: &'a EditPrediction },
    /// The prediction requires a jump (presumably to another location or buffer).
    Jump { prediction: &'a EditPrediction },
}
484
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants wrap a prediction reference, so dereferencing yields it
    /// regardless of whether the prediction is local or a jump.
    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
496
/// A prediction awaiting the "settled" check for its editable region
/// (presumably evaluated against `EDIT_PREDICTION_SETTLED_QUIESCENCE` and
/// `EDIT_PREDICTION_SETTLED_TTL` — confirm in the settled-predictions worker).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Buffer region the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    /// When tracking began.
    enqueued_at: Instant,
    /// Time of the most recent relevant edit (presumably inside the editable range).
    last_edit_at: Instant,
}
504
/// Bookkeeping for a buffer registered with the store.
struct RegisteredBuffer {
    /// The buffer's file, if any.
    file: Option<Arc<dyn File>>,
    /// Last snapshot observed (presumably the baseline for diffing later edits).
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer awaiting the settled check.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
512
/// The edit event currently being accumulated for a buffer; `finalize` turns it
/// into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the latest edit included in the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Propagated into the stored event's `predicted` flag (presumably true for
    /// changes from accepted predictions).
    predicted: bool,
    /// Snapshot after the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
524
525impl LastEvent {
526 pub fn finalize(
527 &self,
528 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
529 cx: &App,
530 ) -> Option<StoredEvent> {
531 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
532 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
533
534 let in_open_source_repo =
535 [self.new_file.as_ref(), self.old_file.as_ref()]
536 .iter()
537 .all(|file| {
538 file.is_some_and(|file| {
539 license_detection_watchers
540 .get(&file.worktree_id(cx))
541 .is_some_and(|watcher| watcher.is_project_open_source())
542 })
543 });
544
545 let (diff, edit_range) =
546 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
547
548 if path == old_path && diff.is_empty() {
549 None
550 } else {
551 Some(StoredEvent {
552 event: Arc::new(zeta_prompt::Event::BufferChange {
553 old_path,
554 path,
555 diff,
556 in_open_source_repo,
557 predicted: self.predicted,
558 }),
559 edit_range: self.new_snapshot.anchor_before(edit_range.start)
560 ..self.new_snapshot.anchor_before(edit_range.end),
561 old_snapshot: self.old_snapshot.clone(),
562 })
563 }
564 }
565
566 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
567 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
568 return (self.clone(), None);
569 };
570
571 let before = LastEvent {
572 old_snapshot: self.old_snapshot.clone(),
573 new_snapshot: boundary_snapshot.clone(),
574 old_file: self.old_file.clone(),
575 new_file: self.new_file.clone(),
576 edit_range: None,
577 predicted: self.predicted,
578 snapshot_after_last_editing_pause: None,
579 last_edit_time: self.last_edit_time,
580 };
581
582 let after = LastEvent {
583 old_snapshot: boundary_snapshot.clone(),
584 new_snapshot: self.new_snapshot.clone(),
585 old_file: self.old_file.clone(),
586 new_file: self.new_file.clone(),
587 edit_range: None,
588 predicted: self.predicted,
589 snapshot_after_last_editing_pause: None,
590 last_edit_time: self.last_edit_time,
591 };
592
593 (before, Some(after))
594 }
595}
596
/// Computes a unified diff describing how `new_snapshot` differs from `old_snapshot`.
///
/// Returns `None` when `new_snapshot` has no edits since `old_snapshot.version`.
/// Otherwise returns the diff text together with the span from the start of the
/// first edit to the end of the last edit, in `new_snapshot` points.
///
/// Only the region covering all edits, padded with surrounding rows, is diffed.
/// The region's starting rows are passed to `unified_diff_with_offsets`,
/// presumably so hunk line numbers refer to the full buffers — confirm.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits: nothing to describe.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    // Bounds of the edited span, each in its own snapshot's coordinates.
    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Leading context: up to CONTEXT_LINES rows before the first edit.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Trailing context row, clamped to each snapshot's last row.
    // NOTE(review): `row + 1 + CONTEXT_LINES` combined with the `+ 1` in the end
    // offsets below includes CONTEXT_LINES + 1 rows after the last edit, one more
    // than on the leading side — confirm whether this asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Offsets of the padded regions: from the start of the first context row to the
    // start of the row after the last context row, clamped to the buffer end.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
642
643fn buffer_path_with_id_fallback(
644 file: Option<&Arc<dyn File>>,
645 snapshot: &TextBufferSnapshot,
646 cx: &App,
647) -> Arc<Path> {
648 if let Some(file) = file {
649 file.full_path(cx).into()
650 } else {
651 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
652 }
653}
654
655impl EditPredictionStore {
656 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
657 cx.try_global::<EditPredictionStoreGlobal>()
658 .map(|global| global.0.clone())
659 }
660
661 pub fn global(
662 client: &Arc<Client>,
663 user_store: &Entity<UserStore>,
664 cx: &mut App,
665 ) -> Entity<Self> {
666 cx.try_global::<EditPredictionStoreGlobal>()
667 .map(|global| global.0.clone())
668 .unwrap_or_else(|| {
669 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
670 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
671 ep_store
672 })
673 }
674
675 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
676 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
677 let data_collection_choice = Self::load_data_collection_choice();
678
679 let llm_token = LlmApiToken::default();
680
681 let (reject_tx, reject_rx) = mpsc::unbounded();
682 cx.background_spawn({
683 let client = client.clone();
684 let llm_token = llm_token.clone();
685 let app_version = AppVersion::global(cx);
686 let background_executor = cx.background_executor().clone();
687 async move {
688 Self::handle_rejected_predictions(
689 reject_rx,
690 client,
691 llm_token,
692 app_version,
693 background_executor,
694 )
695 .await
696 }
697 })
698 .detach();
699
700 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
701 cx.spawn(async move |this, cx| {
702 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
703 })
704 .detach();
705
706 let mut current_user = user_store.read(cx).watch_current_user();
707 let fetch_experiments_task = cx.spawn(async move |this, cx| {
708 while current_user.borrow().is_none() {
709 current_user.next().await;
710 }
711 this.update(cx, |this, cx| {
712 this.refresh_available_experiments(cx);
713 })
714 .log_err();
715 });
716
717 let this = Self {
718 projects: HashMap::default(),
719 client,
720 user_store,
721 llm_token,
722 _fetch_experiments_task: fetch_experiments_task,
723 _llm_token_subscription: cx.subscribe(
724 &refresh_llm_token_listener,
725 |this, _listener, _event, cx| {
726 let client = this.client.clone();
727 let llm_token = this.llm_token.clone();
728 let organization_id = this
729 .user_store
730 .read(cx)
731 .current_organization()
732 .map(|organization| organization.id.clone());
733 cx.spawn(async move |_this, _cx| {
734 llm_token.refresh(&client, organization_id).await?;
735 anyhow::Ok(())
736 })
737 .detach_and_log_err(cx);
738 },
739 ),
740 update_required: false,
741 edit_prediction_model: EditPredictionModel::Zeta,
742 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
743 preferred_experiment: None,
744 available_experiments: Vec::new(),
745 sweep_ai: SweepAi::new(cx),
746 mercury: Mercury::new(cx),
747
748 data_collection_choice,
749 reject_predictions_tx: reject_tx,
750 settled_predictions_tx,
751 rated_predictions: Default::default(),
752 shown_predictions: Default::default(),
753 #[cfg(test)]
754 settled_event_callback: None,
755 };
756
757 this
758 }
759
760 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
761 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
762 let format = ZetaFormat::parse(&version_str).ok()?;
763 let model_id = env::var("ZED_ZETA_MODEL").ok();
764 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
765 Some(Zeta2RawConfig {
766 model_id,
767 environment,
768 format,
769 })
770 }
771
772 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
773 self.edit_prediction_model = model;
774 }
775
776 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
777 self.zeta2_raw_config = Some(config);
778 }
779
780 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
781 self.zeta2_raw_config.as_ref()
782 }
783
784 pub fn preferred_experiment(&self) -> Option<&str> {
785 self.preferred_experiment.as_deref()
786 }
787
788 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
789 self.preferred_experiment = experiment;
790 }
791
792 pub fn available_experiments(&self) -> &[String] {
793 &self.available_experiments
794 }
795
796 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
797 let client = self.client.clone();
798 let llm_token = self.llm_token.clone();
799 let app_version = AppVersion::global(cx);
800 let organization_id = self
801 .user_store
802 .read(cx)
803 .current_organization()
804 .map(|organization| organization.id.clone());
805
806 cx.spawn(async move |this, cx| {
807 let experiments = cx
808 .background_spawn(async move {
809 let http_client = client.http_client();
810 let token = llm_token.acquire(&client, organization_id).await?;
811 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
812 let request = http_client::Request::builder()
813 .method(Method::GET)
814 .uri(url.as_ref())
815 .header("Authorization", format!("Bearer {}", token))
816 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
817 .body(Default::default())?;
818 let mut response = http_client.send(request).await?;
819 if response.status().is_success() {
820 let mut body = Vec::new();
821 response.body_mut().read_to_end(&mut body).await?;
822 let experiments: Vec<String> = serde_json::from_slice(&body)?;
823 Ok(experiments)
824 } else {
825 let mut body = String::new();
826 response.body_mut().read_to_string(&mut body).await?;
827 anyhow::bail!(
828 "Failed to fetch experiments: {:?}\nBody: {}",
829 response.status(),
830 body
831 );
832 }
833 })
834 .await?;
835 this.update(cx, |this, cx| {
836 this.available_experiments = experiments;
837 cx.notify();
838 })?;
839 anyhow::Ok(())
840 })
841 .detach_and_log_err(cx);
842 }
843
844 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
845 use ui::IconName;
846 match self.edit_prediction_model {
847 EditPredictionModel::Sweep => {
848 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
849 .with_disabled(IconName::SweepAiDisabled)
850 .with_up(IconName::SweepAiUp)
851 .with_down(IconName::SweepAiDown)
852 .with_error(IconName::SweepAiError)
853 }
854 EditPredictionModel::Mercury => {
855 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
856 }
857 EditPredictionModel::Zeta => {
858 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
859 .with_disabled(IconName::ZedPredictDisabled)
860 .with_up(IconName::ZedPredictUp)
861 .with_down(IconName::ZedPredictDown)
862 .with_error(IconName::ZedPredictError)
863 }
864 EditPredictionModel::Fim { .. } => {
865 let settings = &all_language_settings(None, cx).edit_predictions;
866 match settings.provider {
867 EditPredictionProvider::Ollama => {
868 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
869 }
870 _ => {
871 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
872 }
873 }
874 }
875 }
876 }
877
878 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
879 self.sweep_ai.api_token.read(cx).has_key()
880 }
881
882 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
883 self.mercury.api_token.read(cx).has_key()
884 }
885
886 pub fn clear_history(&mut self) {
887 for project_state in self.projects.values_mut() {
888 project_state.events.clear();
889 project_state.last_event.take();
890 }
891 }
892
893 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
894 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
895 project_state.events.clear();
896 project_state.last_event.take();
897 }
898 }
899
900 pub fn edit_history_for_project(
901 &self,
902 project: &Entity<Project>,
903 cx: &App,
904 ) -> Vec<StoredEvent> {
905 self.projects
906 .get(&project.entity_id())
907 .map(|project_state| project_state.events(cx))
908 .unwrap_or_default()
909 }
910
911 pub fn context_for_project<'a>(
912 &'a self,
913 project: &Entity<Project>,
914 cx: &'a mut App,
915 ) -> Vec<RelatedFile> {
916 self.projects
917 .get(&project.entity_id())
918 .map(|project_state| {
919 project_state.context.update(cx, |context, cx| {
920 context
921 .related_files_with_buffers(cx)
922 .map(|(mut related_file, buffer)| {
923 related_file.in_open_source_repo = buffer
924 .read(cx)
925 .file()
926 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
927 related_file
928 })
929 .collect()
930 })
931 })
932 .unwrap_or_default()
933 }
934
935 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
936 self.projects
937 .get(&project.entity_id())
938 .and_then(|project| project.copilot.clone())
939 }
940
941 pub fn start_copilot_for_project(
942 &mut self,
943 project: &Entity<Project>,
944 cx: &mut Context<Self>,
945 ) -> Option<Entity<Copilot>> {
946 if DisableAiSettings::get(None, cx).disable_ai {
947 return None;
948 }
949 let state = self.get_or_init_project(project, cx);
950
951 if state.copilot.is_some() {
952 return state.copilot.clone();
953 }
954 let _project = project.clone();
955 let project = project.read(cx);
956
957 let node = project.node_runtime().cloned();
958 if let Some(node) = node {
959 let next_id = project.languages().next_language_server_id();
960 let fs = project.fs().clone();
961
962 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
963 state.copilot = Some(copilot.clone());
964 Some(copilot)
965 } else {
966 None
967 }
968 }
969
970 pub fn context_for_project_with_buffers<'a>(
971 &'a self,
972 project: &Entity<Project>,
973 cx: &'a mut App,
974 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
975 self.projects
976 .get(&project.entity_id())
977 .map(|project| {
978 project.context.update(cx, |context, cx| {
979 context.related_files_with_buffers(cx).collect()
980 })
981 })
982 .unwrap_or_default()
983 }
984
985 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
986 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
987 self.user_store.read(cx).edit_prediction_usage()
988 } else {
989 None
990 }
991 }
992
993 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
994 self.get_or_init_project(project, cx);
995 }
996
997 pub fn register_buffer(
998 &mut self,
999 buffer: &Entity<Buffer>,
1000 project: &Entity<Project>,
1001 cx: &mut Context<Self>,
1002 ) {
1003 let project_state = self.get_or_init_project(project, cx);
1004 Self::register_buffer_impl(project_state, buffer, project, cx);
1005 }
1006
    /// Returns the state for `project`, creating it on first access.
    ///
    /// Creating the state also:
    /// - creates a `RelatedExcerptStore` for the project and forwards its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to project events via `handle_project_event`;
    /// - observes the project's release: the state is removed, the Copilot
    ///   language server (if one was started) is shut down in the background, and
    ///   observers are notified.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than stored in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // On release, tear down this project's state and its Copilot.
                    cx.observe_release(&project, move |this, _, cx| {
                        if let Some(state) = this.projects.remove(&entity_id) {
                            if let Some(copilot) = state.copilot {
                                let shutdown = copilot
                                    .update(cx, |copilot, cx| copilot.shutdown_language_server(cx));
                                cx.background_spawn(shutdown).detach();
                            }
                        }
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1053
1054 pub fn remove_project(&mut self, project: &Entity<Project>) {
1055 self.projects.remove(&project.entity_id());
1056 }
1057
1058 fn handle_excerpt_store_event(
1059 &mut self,
1060 project_entity_id: EntityId,
1061 event: &RelatedExcerptStoreEvent,
1062 ) {
1063 if let Some(project_state) = self.projects.get(&project_entity_id) {
1064 if let Some(debug_tx) = project_state.debug_tx.clone() {
1065 match event {
1066 RelatedExcerptStoreEvent::StartedRefresh => {
1067 debug_tx
1068 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1069 ContextRetrievalStartedDebugEvent {
1070 project_entity_id: project_entity_id,
1071 timestamp: Instant::now(),
1072 search_prompt: String::new(),
1073 },
1074 ))
1075 .ok();
1076 }
1077 RelatedExcerptStoreEvent::FinishedRefresh {
1078 cache_hit_count,
1079 cache_miss_count,
1080 mean_definition_latency,
1081 max_definition_latency,
1082 } => {
1083 debug_tx
1084 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1085 ContextRetrievalFinishedDebugEvent {
1086 project_entity_id: project_entity_id,
1087 timestamp: Instant::now(),
1088 metadata: vec![
1089 (
1090 "Cache Hits",
1091 format!(
1092 "{}/{}",
1093 cache_hit_count,
1094 cache_hit_count + cache_miss_count
1095 )
1096 .into(),
1097 ),
1098 (
1099 "Max LSP Time",
1100 format!("{} ms", max_definition_latency.as_millis())
1101 .into(),
1102 ),
1103 (
1104 "Mean LSP Time",
1105 format!("{} ms", mean_definition_latency.as_millis())
1106 .into(),
1107 ),
1108 ],
1109 },
1110 ))
1111 .ok();
1112 }
1113 }
1114 }
1115 }
1116 }
1117
1118 pub fn debug_info(
1119 &mut self,
1120 project: &Entity<Project>,
1121 cx: &mut Context<Self>,
1122 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1123 let project_state = self.get_or_init_project(project, cx);
1124 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1125 project_state.debug_tx = Some(debug_watch_tx);
1126 debug_watch_rx
1127 }
1128
1129 fn handle_project_event(
1130 &mut self,
1131 project: Entity<Project>,
1132 event: &project::Event,
1133 cx: &mut Context<Self>,
1134 ) {
1135 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1136 return;
1137 }
1138 // TODO [zeta2] init with recent paths
1139 match event {
1140 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1141 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1142 return;
1143 };
1144 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1145 if let Some(path) = path {
1146 if let Some(ix) = project_state
1147 .recent_paths
1148 .iter()
1149 .position(|probe| probe == &path)
1150 {
1151 project_state.recent_paths.remove(ix);
1152 }
1153 project_state.recent_paths.push_front(path);
1154 }
1155 }
1156 project::Event::DiagnosticsUpdated { .. } => {
1157 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1158 self.refresh_prediction_from_diagnostics(
1159 project,
1160 DiagnosticSearchScope::Global,
1161 cx,
1162 );
1163 }
1164 }
1165 _ => (),
1166 }
1167 }
1168
    /// Registers `buffer` with `project_state` and returns its `RegisteredBuffer`.
    ///
    /// Calling this again for an already-registered buffer returns the existing entry.
    ///
    /// If the buffer has a file in one of `project`'s worktrees, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time it is seen.
    /// The watcher entry is removed again when the worktree is released.
    ///
    /// A newly registered buffer stores its current text snapshot and file, and
    /// subscribes to:
    /// - `BufferEvent::Edited`, forwarded to `report_changes_for_buffer` while the
    ///   project is still alive (it is held weakly);
    /// - the buffer's release, which removes the registration.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly so this subscription does not keep
                            // it alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1235
    /// Records the latest changes to `buffer` (in `project`) into the project's edit
    /// history.
    ///
    /// The buffer's current text snapshot is compared with the one stored when it was
    /// last seen. If the version differs, this:
    /// - marks pending settled predictions whose editable range overlaps the edits as
    ///   edited now;
    /// - records a `UserActionRecord` (insert/delete, char/selection) at the end of the
    ///   last edit;
    /// - either coalesces the change into `last_event` or finalizes `last_event` into
    ///   `events` and starts a new `LastEvent`.
    ///
    /// A change is coalesced when all of the following hold:
    /// - it continues the same buffer's snapshot;
    /// - it has the same `is_predicted` origin;
    /// - its range is within `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
    ///
    /// `is_predicted` is `true` when called while accepting a prediction. Buffer
    /// `Edited` subscriptions pass `false`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since this buffer was last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the previously stored version. The combined range
        // runs from the first edit's start to the last edit's end (assumes edits are
        // yielded in buffer order; TODO confirm).
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit touching a pending settled prediction's editable range restarts its
        // quiescence window.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify the change. When the total inserted (or deleted) length equals the
        // number of edits it counts as a *Char action, otherwise as a *Selection action.
        // Mixed replacements are classified by their inserted length.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock time; falls back to 0 if the system clock is before the epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused for at least LAST_CHANGE_GROUPING_TIME, remember the
                // snapshot from before this edit as the post-pause snapshot.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event so `events` stays below EVENT_COUNT_MAX entries.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1369
1370 fn prediction_at(
1371 &mut self,
1372 buffer: &Entity<Buffer>,
1373 position: Option<language::Anchor>,
1374 project: &Entity<Project>,
1375 cx: &App,
1376 ) -> Option<BufferEditPrediction<'_>> {
1377 let project_state = self.projects.get_mut(&project.entity_id())?;
1378 if let Some(position) = position
1379 && let Some(buffer) = project_state
1380 .registered_buffers
1381 .get_mut(&buffer.entity_id())
1382 {
1383 buffer.last_position = Some(position);
1384 }
1385
1386 let CurrentEditPrediction {
1387 requested_by,
1388 prediction,
1389 ..
1390 } = project_state.current_prediction.as_ref()?;
1391
1392 if prediction.targets_buffer(buffer.read(cx)) {
1393 Some(BufferEditPrediction::Local { prediction })
1394 } else {
1395 let show_jump = match requested_by {
1396 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1397 requested_by_buffer_id == &buffer.entity_id()
1398 }
1399 PredictionRequestedBy::DiagnosticsUpdate => true,
1400 };
1401
1402 if show_jump {
1403 Some(BufferEditPrediction::Jump { prediction })
1404 } else {
1405 None
1406 }
1407 }
1408 }
1409
1410 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1411 let Some(current_prediction) = self
1412 .projects
1413 .get_mut(&project.entity_id())
1414 .and_then(|project_state| project_state.current_prediction.take())
1415 else {
1416 return;
1417 };
1418
1419 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1420
1421 // can't hold &mut project_state ref across report_changes_for_buffer_call
1422 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1423 return;
1424 };
1425
1426 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1427 project_state.cancel_pending_prediction(pending_prediction, cx);
1428 }
1429
1430 match self.edit_prediction_model {
1431 EditPredictionModel::Sweep => {
1432 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1433 }
1434 EditPredictionModel::Mercury => {
1435 mercury::edit_prediction_accepted(
1436 current_prediction.prediction.id,
1437 self.client.http_client(),
1438 cx,
1439 );
1440 }
1441 EditPredictionModel::Zeta => {
1442 let is_cloud = !matches!(
1443 all_language_settings(None, cx).edit_predictions.provider,
1444 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1445 );
1446 if is_cloud {
1447 zeta::edit_prediction_accepted(self, current_prediction, cx)
1448 }
1449 }
1450 EditPredictionModel::Fim { .. } => {}
1451 }
1452 }
1453
    /// Background loop that batches prediction rejections and posts them to the
    /// `/predict_edits/reject` LLM endpoint.
    ///
    /// While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections
    /// are batched, the loop waits up to `REJECT_REQUEST_DEBOUNCE` for more. If another
    /// payload arrives first, it loops to collect it. Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most recently batched items.
    /// Those items are removed only when the request succeeds; failed items stay
    /// batched and are included in later flushes.
    ///
    /// The loop ends when `rx` is closed and drained.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        // NOTE(review): the whole batch is sent with the `organization_id` of the most
        // recently received payload.
        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Debounce small batches. `peek` resolving to Some means another rejection is
            // ready, so go collect it. Resolving to None (channel closed) or the timer
            // firing falls through to a flush.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics this task if the URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the sent items on success. NOTE(review): nothing bounds
            // `batched` if requests keep failing.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1517
1518 async fn run_settled_predictions_worker(
1519 this: WeakEntity<Self>,
1520 mut rx: UnboundedReceiver<Instant>,
1521 cx: &mut AsyncApp,
1522 ) {
1523 let mut next_wake_time: Option<Instant> = None;
1524 loop {
1525 let now = cx.background_executor().now();
1526 if let Some(wake_time) = next_wake_time.take() {
1527 cx.background_executor()
1528 .timer(wake_time.duration_since(now))
1529 .await;
1530 } else {
1531 let Some(new_enqueue_time) = rx.next().await else {
1532 break;
1533 };
1534 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1535 while rx.next().now_or_never().flatten().is_some() {}
1536 continue;
1537 }
1538
1539 let Some(this) = this.upgrade() else {
1540 break;
1541 };
1542
1543 let now = cx.background_executor().now();
1544
1545 let mut oldest_edited_at = None;
1546
1547 this.update(cx, |this, _| {
1548 for (_, project_state) in this.projects.iter_mut() {
1549 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1550 registered_buffer
1551 .pending_predictions
1552 .retain_mut(|pending_prediction| {
1553 let age =
1554 now.saturating_duration_since(pending_prediction.enqueued_at);
1555 if age >= EDIT_PREDICTION_SETTLED_TTL {
1556 return false;
1557 }
1558
1559 let quiet_for =
1560 now.saturating_duration_since(pending_prediction.last_edit_at);
1561 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1562 let settled_editable_region = registered_buffer
1563 .snapshot
1564 .text_for_range(
1565 pending_prediction.editable_anchor_range.clone(),
1566 )
1567 .collect::<String>();
1568
1569 #[cfg(test)]
1570 if let Some(callback) = &this.settled_event_callback {
1571 callback(
1572 pending_prediction.request_id.clone(),
1573 settled_editable_region.clone(),
1574 );
1575 }
1576
1577 telemetry::event!(
1578 EDIT_PREDICTION_SETTLED_EVENT,
1579 request_id = pending_prediction.request_id.0.clone(),
1580 settled_editable_region,
1581 );
1582
1583 return false;
1584 }
1585
1586 if oldest_edited_at
1587 .is_none_or(|t| pending_prediction.last_edit_at < t)
1588 {
1589 oldest_edited_at = Some(pending_prediction.last_edit_at);
1590 }
1591
1592 true
1593 });
1594 }
1595 }
1596 });
1597
1598 next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1599 }
1600 }
1601
1602 pub(crate) fn enqueue_settled_prediction(
1603 &mut self,
1604 request_id: EditPredictionId,
1605 project: &Entity<Project>,
1606 edited_buffer: &Entity<Buffer>,
1607 edited_buffer_snapshot: &BufferSnapshot,
1608 editable_offset_range: Range<usize>,
1609 cx: &mut Context<Self>,
1610 ) {
1611 let project_state = self.get_or_init_project(project, cx);
1612 if let Some(buffer) = project_state
1613 .registered_buffers
1614 .get_mut(&edited_buffer.entity_id())
1615 {
1616 let now = cx.background_executor().now();
1617 buffer.pending_predictions.push(PendingSettledPrediction {
1618 request_id,
1619 editable_anchor_range: edited_buffer_snapshot
1620 .anchor_range_around(editable_offset_range),
1621 enqueued_at: now,
1622 last_edit_at: now,
1623 });
1624 self.settled_predictions_tx.unbounded_send(now).ok();
1625 }
1626 }
1627
1628 fn reject_current_prediction(
1629 &mut self,
1630 reason: EditPredictionRejectReason,
1631 project: &Entity<Project>,
1632 cx: &App,
1633 ) {
1634 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1635 project_state.pending_predictions.clear();
1636 if let Some(prediction) = project_state.current_prediction.take() {
1637 let model_version = prediction.prediction.model_version.clone();
1638 self.reject_prediction(
1639 prediction.prediction.id,
1640 reason,
1641 prediction.was_shown,
1642 model_version,
1643 cx,
1644 );
1645 }
1646 };
1647 }
1648
1649 fn did_show_current_prediction(
1650 &mut self,
1651 project: &Entity<Project>,
1652 display_type: edit_prediction_types::SuggestionDisplayType,
1653 cx: &mut Context<Self>,
1654 ) {
1655 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1656 return;
1657 };
1658
1659 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1660 return;
1661 };
1662
1663 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1664 let previous_shown_with = current_prediction.shown_with;
1665
1666 if previous_shown_with.is_none() || !is_jump {
1667 current_prediction.shown_with = Some(display_type);
1668 }
1669
1670 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1671
1672 if is_first_non_jump_show {
1673 current_prediction.was_shown = true;
1674 }
1675
1676 let display_type_changed = previous_shown_with != Some(display_type);
1677
1678 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1679 sweep_ai::edit_prediction_shown(
1680 &self.sweep_ai,
1681 self.client.clone(),
1682 ¤t_prediction.prediction,
1683 display_type,
1684 cx,
1685 );
1686 }
1687
1688 if is_first_non_jump_show {
1689 self.shown_predictions
1690 .push_front(current_prediction.prediction.clone());
1691 if self.shown_predictions.len() > 50 {
1692 let completion = self.shown_predictions.pop_back().unwrap();
1693 self.rated_predictions.remove(&completion.id);
1694 }
1695 }
1696 }
1697
1698 fn reject_prediction(
1699 &mut self,
1700 prediction_id: EditPredictionId,
1701 reason: EditPredictionRejectReason,
1702 was_shown: bool,
1703 model_version: Option<String>,
1704 cx: &App,
1705 ) {
1706 match self.edit_prediction_model {
1707 EditPredictionModel::Zeta => {
1708 let is_cloud = !matches!(
1709 all_language_settings(None, cx).edit_predictions.provider,
1710 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1711 );
1712
1713 if is_cloud {
1714 let organization_id = self
1715 .user_store
1716 .read(cx)
1717 .current_organization()
1718 .map(|organization| organization.id.clone());
1719
1720 self.reject_predictions_tx
1721 .unbounded_send(EditPredictionRejectionPayload {
1722 rejection: EditPredictionRejection {
1723 request_id: prediction_id.to_string(),
1724 reason,
1725 was_shown,
1726 model_version,
1727 },
1728 organization_id,
1729 })
1730 .log_err();
1731 }
1732 }
1733 EditPredictionModel::Mercury => {
1734 mercury::edit_prediction_rejected(
1735 prediction_id,
1736 was_shown,
1737 reason,
1738 self.client.http_client(),
1739 cx,
1740 );
1741 }
1742 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1743 }
1744 }
1745
1746 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1747 self.projects
1748 .get(&project.entity_id())
1749 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1750 }
1751
    /// Queues an edit-triggered (`PredictEditsRequestTrigger::Other`) prediction refresh
    /// for `buffer` at `position`.
    ///
    /// Throttling is keyed on the buffer's entity id. A successful result is attributed
    /// to `PredictionRequestedBy::Buffer(buffer)`. If the store can no longer be updated
    /// when the refresh runs (e.g. it was released), the refresh resolves to `Ok(None)`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                // Tag the finished prediction with the buffer that requested it.
                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1793
    /// Queues a diagnostics-triggered prediction refresh for `project`.
    ///
    /// This is a no-op when:
    /// - the configured provider is not served by this store;
    /// - `project` has no state here;
    /// - the project already has a current prediction (buffer-originated predictions
    ///   are preferred).
    ///
    /// The queued refresh runs these steps:
    /// 1. Reads the project's active buffer and cursor, bailing out if predictions are
    ///    disabled there.
    /// 2. Finds the next diagnostic location via `next_diagnostic_location`. With
    ///    `Local`, the search range is `DIAGNOSTIC_LINES_RANGE` rows around the cursor.
    ///    `Global` passes a default range, whose meaning is defined by that helper.
    /// 3. Requests a prediction at that location.
    ///
    /// If a current prediction appeared while the request was in flight, the result is
    /// turned into a `CurrentPreferred` rejection instead. Throttling is keyed on the
    /// project's entity id.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known cursor position, start from the buffer origin.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction may have become current while this request was in
                    // flight. If so, keep it and reject this result instead.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1902
1903 fn predictions_enabled_at(
1904 snapshot: &BufferSnapshot,
1905 position: Option<language::Anchor>,
1906 cx: &App,
1907 ) -> bool {
1908 let file = snapshot.file();
1909 let all_settings = all_language_settings(file, cx);
1910 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1911 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1912 {
1913 return false;
1914 }
1915
1916 if let Some(last_position) = position {
1917 let settings = snapshot.settings_at(last_position, cx);
1918
1919 if !settings.edit_predictions_disabled_in.is_empty()
1920 && let Some(scope) = snapshot.language_scope_at(last_position)
1921 && let Some(scope_name) = scope.override_name()
1922 && settings
1923 .edit_predictions_disabled_in
1924 .iter()
1925 .any(|s| s == scope_name)
1926 {
1927 return false;
1928 }
1929 }
1930
1931 true
1932 }
1933
    /// Minimum spacing between consecutive prediction refreshes of the same trigger kind
    /// for the same throttle entity. `queue_prediction_refresh` delays a new request
    /// until this much time has passed since the previous one.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1935}
1936
1937fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1938 match provider {
1939 EditPredictionProvider::Zed
1940 | EditPredictionProvider::Sweep
1941 | EditPredictionProvider::Mercury
1942 | EditPredictionProvider::Ollama
1943 | EditPredictionProvider::OpenAiCompatibleApi
1944 | EditPredictionProvider::Experimental(_) => true,
1945 EditPredictionProvider::None
1946 | EditPredictionProvider::Copilot
1947 | EditPredictionProvider::Codestral => false,
1948 }
1949}
1950
1951impl EditPredictionStore {
1952 fn queue_prediction_refresh(
1953 &mut self,
1954 project: Entity<Project>,
1955 request_trigger: PredictEditsRequestTrigger,
1956 throttle_entity: EntityId,
1957 cx: &mut Context<Self>,
1958 do_refresh: impl FnOnce(
1959 WeakEntity<Self>,
1960 &mut AsyncApp,
1961 )
1962 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1963 + 'static,
1964 ) {
1965 fn select_throttle(
1966 project_state: &mut ProjectState,
1967 request_trigger: PredictEditsRequestTrigger,
1968 ) -> &mut Option<(EntityId, Instant)> {
1969 match request_trigger {
1970 PredictEditsRequestTrigger::Diagnostics => {
1971 &mut project_state.last_jump_prediction_refresh
1972 }
1973 _ => &mut project_state.last_edit_prediction_refresh,
1974 }
1975 }
1976
1977 let (needs_acceptance_tracking, max_pending_predictions) =
1978 match all_language_settings(None, cx).edit_predictions.provider {
1979 EditPredictionProvider::Zed
1980 | EditPredictionProvider::Sweep
1981 | EditPredictionProvider::Mercury
1982 | EditPredictionProvider::Experimental(_) => (true, 2),
1983 EditPredictionProvider::Ollama => (false, 1),
1984 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1985 EditPredictionProvider::None
1986 | EditPredictionProvider::Copilot
1987 | EditPredictionProvider::Codestral => {
1988 log::error!("queue_prediction_refresh called with non-store provider");
1989 return;
1990 }
1991 };
1992
1993 let drop_on_cancel = !needs_acceptance_tracking;
1994 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1995 let project_state = self.get_or_init_project(&project, cx);
1996 let pending_prediction_id = project_state.next_pending_prediction_id;
1997 project_state.next_pending_prediction_id += 1;
1998 let last_request = *select_throttle(project_state, request_trigger);
1999
2000 let task = cx.spawn(async move |this, cx| {
2001 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
2002 if throttle_entity != last_entity {
2003 return None;
2004 }
2005 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2006 }) {
2007 cx.background_executor().timer(timeout).await;
2008 }
2009
2010 // If this task was cancelled before the throttle timeout expired,
2011 // do not perform a request.
2012 let mut is_cancelled = true;
2013 this.update(cx, |this, cx| {
2014 let project_state = this.get_or_init_project(&project, cx);
2015 let was_cancelled = project_state
2016 .cancelled_predictions
2017 .remove(&pending_prediction_id);
2018 if !was_cancelled {
2019 let new_refresh = (throttle_entity, Instant::now());
2020 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2021 is_cancelled = false;
2022 }
2023 })
2024 .ok();
2025 if is_cancelled {
2026 return None;
2027 }
2028
2029 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2030 let new_prediction_id = new_prediction_result
2031 .as_ref()
2032 .map(|(prediction, _)| prediction.id.clone());
2033
2034 // When a prediction completes, remove it from the pending list, and cancel
2035 // any pending predictions that were enqueued before it.
2036 this.update(cx, |this, cx| {
2037 let project_state = this.get_or_init_project(&project, cx);
2038
2039 let is_cancelled = project_state
2040 .cancelled_predictions
2041 .remove(&pending_prediction_id);
2042
2043 let new_current_prediction = if !is_cancelled
2044 && let Some((prediction_result, requested_by)) = new_prediction_result
2045 {
2046 match prediction_result.prediction {
2047 Ok(prediction) => {
2048 let new_prediction = CurrentEditPrediction {
2049 requested_by,
2050 prediction,
2051 was_shown: false,
2052 shown_with: None,
2053 };
2054
2055 if let Some(current_prediction) =
2056 project_state.current_prediction.as_ref()
2057 {
2058 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2059 {
2060 this.reject_current_prediction(
2061 EditPredictionRejectReason::Replaced,
2062 &project,
2063 cx,
2064 );
2065
2066 Some(new_prediction)
2067 } else {
2068 this.reject_prediction(
2069 new_prediction.prediction.id,
2070 EditPredictionRejectReason::CurrentPreferred,
2071 false,
2072 new_prediction.prediction.model_version,
2073 cx,
2074 );
2075 None
2076 }
2077 } else {
2078 Some(new_prediction)
2079 }
2080 }
2081 Err(reject_reason) => {
2082 this.reject_prediction(
2083 prediction_result.id,
2084 reject_reason,
2085 false,
2086 None,
2087 cx,
2088 );
2089 None
2090 }
2091 }
2092 } else {
2093 None
2094 };
2095
2096 let project_state = this.get_or_init_project(&project, cx);
2097
2098 if let Some(new_prediction) = new_current_prediction {
2099 project_state.current_prediction = Some(new_prediction);
2100 }
2101
2102 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2103 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2104 if pending_prediction.id == pending_prediction_id {
2105 pending_predictions.remove(ix);
2106 for pending_prediction in pending_predictions.drain(0..ix) {
2107 project_state.cancel_pending_prediction(pending_prediction, cx)
2108 }
2109 break;
2110 }
2111 }
2112 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2113 cx.notify();
2114 })
2115 .ok();
2116
2117 new_prediction_id
2118 });
2119
2120 if project_state.pending_predictions.len() < max_pending_predictions {
2121 project_state.pending_predictions.push(PendingPrediction {
2122 id: pending_prediction_id,
2123 task,
2124 drop_on_cancel,
2125 });
2126 } else {
2127 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2128 project_state.pending_predictions.push(PendingPrediction {
2129 id: pending_prediction_id,
2130 task,
2131 drop_on_cancel,
2132 });
2133 project_state.cancel_pending_prediction(pending_prediction, cx);
2134 }
2135 }
2136
2137 pub fn request_prediction(
2138 &mut self,
2139 project: &Entity<Project>,
2140 active_buffer: &Entity<Buffer>,
2141 position: language::Anchor,
2142 trigger: PredictEditsRequestTrigger,
2143 cx: &mut Context<Self>,
2144 ) -> Task<Result<Option<EditPredictionResult>>> {
2145 self.request_prediction_internal(
2146 project.clone(),
2147 active_buffer.clone(),
2148 position,
2149 trigger,
2150 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2151 cx,
2152 )
2153 }
2154
2155 fn request_prediction_internal(
2156 &mut self,
2157 project: Entity<Project>,
2158 active_buffer: Entity<Buffer>,
2159 position: language::Anchor,
2160 trigger: PredictEditsRequestTrigger,
2161 allow_jump: bool,
2162 cx: &mut Context<Self>,
2163 ) -> Task<Result<Option<EditPredictionResult>>> {
2164 self.get_or_init_project(&project, cx);
2165 let project_state = self.projects.get(&project.entity_id()).unwrap();
2166 let stored_events = project_state.events(cx);
2167 let has_events = !stored_events.is_empty();
2168 let events: Vec<Arc<zeta_prompt::Event>> =
2169 stored_events.iter().map(|e| e.event.clone()).collect();
2170 let debug_tx = project_state.debug_tx.clone();
2171
2172 let snapshot = active_buffer.read(cx).snapshot();
2173 let cursor_point = position.to_point(&snapshot);
2174 let current_offset = position.to_offset(&snapshot);
2175
2176 let mut user_actions: Vec<UserActionRecord> =
2177 project_state.user_actions.iter().cloned().collect();
2178
2179 if let Some(last_action) = user_actions.last() {
2180 if last_action.buffer_id == active_buffer.entity_id()
2181 && current_offset != last_action.offset
2182 {
2183 let timestamp_epoch_ms = SystemTime::now()
2184 .duration_since(UNIX_EPOCH)
2185 .map(|d| d.as_millis() as u64)
2186 .unwrap_or(0);
2187 user_actions.push(UserActionRecord {
2188 action_type: UserActionType::CursorMovement,
2189 buffer_id: active_buffer.entity_id(),
2190 line_number: cursor_point.row,
2191 offset: current_offset,
2192 timestamp_epoch_ms,
2193 });
2194 }
2195 }
2196 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2197 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2198 let diagnostic_search_range =
2199 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2200
2201 let related_files = self.context_for_project(&project, cx);
2202
2203 let is_open_source = snapshot
2204 .file()
2205 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2206 && events.iter().all(|event| event.in_open_source_repo())
2207 && related_files.iter().all(|file| file.in_open_source_repo);
2208
2209 let can_collect_data = !cfg!(test)
2210 && is_open_source
2211 && self.is_data_collection_enabled(cx)
2212 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2213
2214 let inputs = EditPredictionModelInput {
2215 project: project.clone(),
2216 buffer: active_buffer.clone(),
2217 snapshot: snapshot,
2218 position,
2219 events,
2220 related_files,
2221 recent_paths: project_state.recent_paths.clone(),
2222 trigger,
2223 diagnostic_search_range: diagnostic_search_range,
2224 debug_tx,
2225 user_actions,
2226 can_collect_data,
2227 is_open_source,
2228 };
2229
2230 if can_collect_data && rand::random_ratio(1, 1000) {
2231 if let Some(task) = capture_example(
2232 project.clone(),
2233 active_buffer,
2234 position,
2235 stored_events,
2236 false,
2237 cx,
2238 ) {
2239 task.detach();
2240 }
2241 }
2242
2243 let task = match self.edit_prediction_model {
2244 EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx),
2245 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2246 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2247 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2248 };
2249
2250 cx.spawn(async move |this, cx| {
2251 let prediction = task.await?;
2252
2253 if prediction.is_none() && allow_jump && has_events {
2254 this.update(cx, |this, cx| {
2255 this.refresh_prediction_from_diagnostics(
2256 project,
2257 DiagnosticSearchScope::Local,
2258 cx,
2259 );
2260 })?;
2261 return anyhow::Ok(None);
2262 }
2263
2264 Ok(prediction)
2265 })
2266 }
2267
2268 pub(crate) async fn next_diagnostic_location(
2269 active_buffer: Entity<Buffer>,
2270 active_buffer_snapshot: &BufferSnapshot,
2271 active_buffer_diagnostic_search_range: Range<Point>,
2272 active_buffer_cursor_point: Point,
2273 project: &Entity<Project>,
2274 cx: &mut AsyncApp,
2275 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2276 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2277 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2278 .flat_map(|(_, _, _, selections)| {
2279 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2280 })
2281 .collect();
2282
2283 let mut jump_location = active_buffer_snapshot
2284 .diagnostic_groups(None)
2285 .into_iter()
2286 .filter_map(|(_, group)| {
2287 let range = &group.entries[group.primary_ix]
2288 .range
2289 .to_point(&active_buffer_snapshot);
2290 if range.overlaps(&active_buffer_diagnostic_search_range) {
2291 return None;
2292 }
2293 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2294 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2295 });
2296 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2297 <= DIAGNOSTIC_LINES_RANGE;
2298 if near_collaborator && !near_local {
2299 return None;
2300 }
2301 Some(range.start)
2302 })
2303 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2304 .map(|position| {
2305 (
2306 active_buffer.clone(),
2307 active_buffer_snapshot.anchor_before(position),
2308 )
2309 });
2310
2311 if jump_location.is_none() {
2312 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2313 let file = buffer.file()?;
2314
2315 Some(ProjectPath {
2316 worktree_id: file.worktree_id(cx),
2317 path: file.path().clone(),
2318 })
2319 });
2320
2321 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2322 project
2323 .diagnostic_summaries(false, cx)
2324 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2325 .map(|(path, _, _)| {
2326 let shared_prefix = path
2327 .path
2328 .components()
2329 .zip(
2330 active_buffer_path
2331 .as_ref()
2332 .map(|p| p.path.components())
2333 .unwrap_or_default(),
2334 )
2335 .take_while(|(a, b)| a == b)
2336 .count();
2337 (path, shared_prefix)
2338 })
2339 .collect()
2340 });
2341
2342 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2343
2344 for (path, _) in candidates {
2345 let candidate_buffer = project
2346 .update(cx, |project, cx| project.open_buffer(path, cx))
2347 .await?;
2348
2349 let (has_collaborators, diagnostic_position) =
2350 candidate_buffer.read_with(cx, |buffer, _cx| {
2351 let snapshot = buffer.snapshot();
2352 let has_collaborators = snapshot
2353 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2354 .next()
2355 .is_some();
2356 let position = buffer
2357 .buffer_diagnostics(None)
2358 .into_iter()
2359 .min_by_key(|entry| entry.diagnostic.severity)
2360 .map(|entry| entry.range.start);
2361 (has_collaborators, position)
2362 });
2363
2364 if has_collaborators {
2365 continue;
2366 }
2367
2368 if let Some(position) = diagnostic_position {
2369 jump_location = Some((candidate_buffer, position));
2370 break;
2371 }
2372 }
2373 }
2374
2375 anyhow::Ok(jump_location)
2376 }
2377
2378 async fn send_raw_llm_request(
2379 request: RawCompletionRequest,
2380 client: Arc<Client>,
2381 custom_url: Option<Arc<Url>>,
2382 llm_token: LlmApiToken,
2383 organization_id: Option<OrganizationId>,
2384 app_version: Version,
2385 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2386 let url = if let Some(custom_url) = custom_url {
2387 custom_url.as_ref().clone()
2388 } else {
2389 client
2390 .http_client()
2391 .build_zed_llm_url("/predict_edits/raw", &[])?
2392 };
2393
2394 Self::send_api_request(
2395 |builder| {
2396 let req = builder
2397 .uri(url.as_ref())
2398 .body(serde_json::to_string(&request)?.into());
2399 Ok(req?)
2400 },
2401 client,
2402 llm_token,
2403 organization_id,
2404 app_version,
2405 true,
2406 )
2407 .await
2408 }
2409
2410 pub(crate) async fn send_v3_request(
2411 input: ZetaPromptInput,
2412 client: Arc<Client>,
2413 llm_token: LlmApiToken,
2414 organization_id: Option<OrganizationId>,
2415 app_version: Version,
2416 trigger: PredictEditsRequestTrigger,
2417 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2418 let url = client
2419 .http_client()
2420 .build_zed_llm_url("/predict_edits/v3", &[])?;
2421
2422 let request = PredictEditsV3Request { input, trigger };
2423
2424 let json_bytes = serde_json::to_vec(&request)?;
2425 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2426
2427 Self::send_api_request(
2428 |builder| {
2429 let req = builder
2430 .uri(url.as_ref())
2431 .header("Content-Encoding", "zstd")
2432 .body(compressed.clone().into());
2433 Ok(req?)
2434 },
2435 client,
2436 llm_token,
2437 organization_id,
2438 app_version,
2439 true,
2440 )
2441 .await
2442 }
2443
2444 fn handle_api_response<T>(
2445 this: &WeakEntity<Self>,
2446 response: Result<(T, Option<EditPredictionUsage>)>,
2447 cx: &mut gpui::AsyncApp,
2448 ) -> Result<T> {
2449 match response {
2450 Ok((data, usage)) => {
2451 if let Some(usage) = usage {
2452 this.update(cx, |this, cx| {
2453 this.user_store.update(cx, |user_store, cx| {
2454 user_store.update_edit_prediction_usage(usage, cx);
2455 });
2456 })
2457 .ok();
2458 }
2459 Ok(data)
2460 }
2461 Err(err) => {
2462 if err.is::<ZedUpdateRequiredError>() {
2463 cx.update(|cx| {
2464 this.update(cx, |this, _cx| {
2465 this.update_required = true;
2466 })
2467 .ok();
2468
2469 let error_message: SharedString = err.to_string().into();
2470 show_app_notification(
2471 NotificationId::unique::<ZedUpdateRequiredError>(),
2472 cx,
2473 move |cx| {
2474 cx.new(|cx| {
2475 ErrorMessagePrompt::new(error_message.clone(), cx)
2476 .with_link_button("Update Zed", "https://zed.dev/releases")
2477 })
2478 },
2479 );
2480 });
2481 }
2482 Err(err)
2483 }
2484 }
2485 }
2486
2487 async fn send_api_request<Res>(
2488 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2489 client: Arc<Client>,
2490 llm_token: LlmApiToken,
2491 organization_id: Option<OrganizationId>,
2492 app_version: Version,
2493 require_auth: bool,
2494 ) -> Result<(Res, Option<EditPredictionUsage>)>
2495 where
2496 Res: DeserializeOwned,
2497 {
2498 let http_client = client.http_client();
2499
2500 let mut token = if require_auth {
2501 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2502 } else {
2503 llm_token
2504 .acquire(&client, organization_id.clone())
2505 .await
2506 .ok()
2507 };
2508 let mut did_retry = false;
2509
2510 loop {
2511 let request_builder = http_client::Request::builder().method(Method::POST);
2512
2513 let mut request_builder = request_builder
2514 .header("Content-Type", "application/json")
2515 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2516
2517 // Only add Authorization header if we have a token
2518 if let Some(ref token_value) = token {
2519 request_builder =
2520 request_builder.header("Authorization", format!("Bearer {}", token_value));
2521 }
2522
2523 let request = build(request_builder)?;
2524
2525 let mut response = http_client.send(request).await?;
2526
2527 if let Some(minimum_required_version) = response
2528 .headers()
2529 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2530 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2531 {
2532 anyhow::ensure!(
2533 app_version >= minimum_required_version,
2534 ZedUpdateRequiredError {
2535 minimum_version: minimum_required_version
2536 }
2537 );
2538 }
2539
2540 if response.status().is_success() {
2541 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2542
2543 let mut body = Vec::new();
2544 response.body_mut().read_to_end(&mut body).await?;
2545 return Ok((serde_json::from_slice(&body)?, usage));
2546 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2547 did_retry = true;
2548 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2549 } else {
2550 let mut body = String::new();
2551 response.body_mut().read_to_string(&mut body).await?;
2552 anyhow::bail!(
2553 "Request failed with status: {:?}\nBody: {}",
2554 response.status(),
2555 body
2556 );
2557 }
2558 }
2559 }
2560
2561 pub fn refresh_context(
2562 &mut self,
2563 project: &Entity<Project>,
2564 buffer: &Entity<language::Buffer>,
2565 cursor_position: language::Anchor,
2566 cx: &mut Context<Self>,
2567 ) {
2568 self.get_or_init_project(project, cx)
2569 .context
2570 .update(cx, |store, cx| {
2571 store.refresh(buffer.clone(), cursor_position, cx);
2572 });
2573 }
2574
2575 #[cfg(feature = "cli-support")]
2576 pub fn set_context_for_buffer(
2577 &mut self,
2578 project: &Entity<Project>,
2579 related_files: Vec<RelatedFile>,
2580 cx: &mut Context<Self>,
2581 ) {
2582 self.get_or_init_project(project, cx)
2583 .context
2584 .update(cx, |store, cx| {
2585 store.set_related_files(related_files, cx);
2586 });
2587 }
2588
2589 #[cfg(feature = "cli-support")]
2590 pub fn set_recent_paths_for_project(
2591 &mut self,
2592 project: &Entity<Project>,
2593 paths: impl IntoIterator<Item = project::ProjectPath>,
2594 cx: &mut Context<Self>,
2595 ) {
2596 let project_state = self.get_or_init_project(project, cx);
2597 project_state.recent_paths = paths.into_iter().collect();
2598 }
2599
2600 fn is_file_open_source(
2601 &self,
2602 project: &Entity<Project>,
2603 file: &Arc<dyn File>,
2604 cx: &App,
2605 ) -> bool {
2606 if !file.is_local() || file.is_private() {
2607 return false;
2608 }
2609 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2610 return false;
2611 };
2612 project_state
2613 .license_detection_watchers
2614 .get(&file.worktree_id(cx))
2615 .as_ref()
2616 .is_some_and(|watcher| watcher.is_project_open_source())
2617 }
2618
2619 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2620 self.data_collection_choice.is_enabled(cx)
2621 }
2622
2623 fn load_data_collection_choice() -> DataCollectionChoice {
2624 let choice = KEY_VALUE_STORE
2625 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2626 .log_err()
2627 .flatten();
2628
2629 match choice.as_deref() {
2630 Some("true") => DataCollectionChoice::Enabled,
2631 Some("false") => DataCollectionChoice::Disabled,
2632 Some(_) => {
2633 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2634 DataCollectionChoice::NotAnswered
2635 }
2636 None => DataCollectionChoice::NotAnswered,
2637 }
2638 }
2639
2640 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2641 self.data_collection_choice = self.data_collection_choice.toggle();
2642 let new_choice = self.data_collection_choice;
2643 let is_enabled = new_choice.is_enabled(cx);
2644 db::write_and_log(cx, move || {
2645 KEY_VALUE_STORE.write_kvp(
2646 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2647 is_enabled.to_string(),
2648 )
2649 });
2650 }
2651
2652 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2653 self.shown_predictions.iter()
2654 }
2655
2656 pub fn shown_completions_len(&self) -> usize {
2657 self.shown_predictions.len()
2658 }
2659
2660 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2661 self.rated_predictions.contains(id)
2662 }
2663
2664 pub fn rate_prediction(
2665 &mut self,
2666 prediction: &EditPrediction,
2667 rating: EditPredictionRating,
2668 feedback: String,
2669 cx: &mut Context<Self>,
2670 ) {
2671 let organization = self.user_store.read(cx).current_organization();
2672
2673 self.rated_predictions.insert(prediction.id.clone());
2674
2675 cx.background_spawn({
2676 let client = self.client.clone();
2677 let prediction_id = prediction.id.to_string();
2678 let inputs = serde_json::to_value(&prediction.inputs);
2679 let output = prediction
2680 .edit_preview
2681 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2682 async move {
2683 client
2684 .cloud_client()
2685 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2686 organization_id: organization.map(|organization| organization.id.clone()),
2687 request_id: prediction_id,
2688 rating: match rating {
2689 EditPredictionRating::Positive => "positive".to_string(),
2690 EditPredictionRating::Negative => "negative".to_string(),
2691 },
2692 inputs: inputs?,
2693 output,
2694 feedback,
2695 })
2696 .await?;
2697
2698 anyhow::Ok(())
2699 }
2700 })
2701 .detach_and_log_err(cx);
2702
2703 cx.notify();
2704 }
2705}
2706
/// Collapses the trailing run of mergeable events in `events` into one `StoredEvent`.
///
/// Walking backwards from the newest event, each event is checked against its newer
/// neighbour with `StoredEvent::can_merge`, given `latest_snapshot` and
/// `latest_edit_range`. The run stops at the first pair that cannot be merged.
///
/// If the run holds at least two events, they are replaced by a single event:
/// - its diff spans from the oldest event's `old_snapshot` to `end_snapshot`;
/// - its edit range is anchored in `end_snapshot`;
/// - its paths and open-source flag are copied from the oldest event.
///
/// `events` is left unchanged when `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out when the newest stored event is for a different buffer (different
    // `remote_id`) than the latest edit.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count how many events at the back form a chain of pairwise-mergeable neighbours.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one event: nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    // The range is non-empty here (mergeable_count >= 2), so `peek().unwrap()` cannot fail.
    // `peek` does not advance the iterator, so the `all` below also inspects the oldest event.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every merged event was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the whole merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2773
2774pub(crate) fn filter_redundant_excerpts(
2775 mut related_files: Vec<RelatedFile>,
2776 cursor_path: &Path,
2777 cursor_row_range: Range<u32>,
2778) -> Vec<RelatedFile> {
2779 for file in &mut related_files {
2780 if file.path.as_ref() == cursor_path {
2781 file.excerpts.retain(|excerpt| {
2782 excerpt.row_range.start < cursor_row_range.start
2783 || excerpt.row_range.end > cursor_row_range.end
2784 });
2785 }
2786 }
2787 related_files.retain(|file| !file.excerpts.is_empty());
2788 related_files
2789}
2790
2791#[derive(Error, Debug)]
2792#[error(
2793 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2794)]
2795pub struct ZedUpdateRequiredError {
2796 minimum_version: Version,
2797}
2798
2799#[derive(Debug, Clone, Copy)]
2800pub enum DataCollectionChoice {
2801 NotAnswered,
2802 Enabled,
2803 Disabled,
2804}
2805
2806impl DataCollectionChoice {
2807 pub fn is_enabled(self, cx: &App) -> bool {
2808 if cx.is_staff() {
2809 return true;
2810 }
2811 match self {
2812 Self::Enabled => true,
2813 Self::NotAnswered | Self::Disabled => false,
2814 }
2815 }
2816
2817 #[must_use]
2818 pub fn toggle(&self) -> DataCollectionChoice {
2819 match self {
2820 Self::Enabled => Self::Disabled,
2821 Self::Disabled => Self::Enabled,
2822 Self::NotAnswered => Self::Enabled,
2823 }
2824 }
2825}
2826
2827impl From<bool> for DataCollectionChoice {
2828 fn from(value: bool) -> Self {
2829 match value {
2830 true => DataCollectionChoice::Enabled,
2831 false => DataCollectionChoice::Disabled,
2832 }
2833 }
2834}
2835
2836struct ZedPredictUpsell;
2837
2838impl Dismissable for ZedPredictUpsell {
2839 const KEY: &'static str = "dismissed-edit-predict-upsell";
2840
2841 fn dismissed() -> bool {
2842 // To make this backwards compatible with older versions of Zed, we
2843 // check if the user has seen the previous Edit Prediction Onboarding
2844 // before, by checking the data collection choice which was written to
2845 // the database once the user clicked on "Accept and Enable"
2846 if KEY_VALUE_STORE
2847 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2848 .log_err()
2849 .is_some_and(|s| s.is_some())
2850 {
2851 return true;
2852 }
2853
2854 KEY_VALUE_STORE
2855 .read_kvp(Self::KEY)
2856 .log_err()
2857 .is_some_and(|s| s.is_some())
2858 }
2859}
2860
2861pub fn should_show_upsell_modal() -> bool {
2862 !ZedPredictUpsell::dismissed()
2863}
2864
2865pub fn init(cx: &mut App) {
2866 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2867 workspace.register_action(
2868 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2869 ZedPredictModal::toggle(
2870 workspace,
2871 workspace.user_store().clone(),
2872 workspace.client().clone(),
2873 window,
2874 cx,
2875 )
2876 },
2877 );
2878
2879 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2880 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2881 settings
2882 .project
2883 .all_languages
2884 .edit_predictions
2885 .get_or_insert_default()
2886 .provider = Some(EditPredictionProvider::None)
2887 });
2888 });
2889 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2890 EditPredictionStore::try_global(cx).and_then(|store| {
2891 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2892 })
2893 }
2894
2895 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2896 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2897 copilot_ui::initiate_sign_in(copilot, window, cx);
2898 }
2899 });
2900 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2901 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2902 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2903 }
2904 });
2905 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2906 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2907 copilot_ui::initiate_sign_out(copilot, window, cx);
2908 }
2909 });
2910 })
2911 .detach();
2912}