1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
// Actions declared in the `edit_prediction` action namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line distance used to decide whether two edits are "near" each other when
/// grouping events; see `StoredEvent::can_merge`.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window used when grouping the most recent change. Used outside this
/// section; presumably edits closer together than this are coalesced — TODO confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key for the persisted data-collection choice. Presumably a key-value store
/// key — verify against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for rejection reporting. Used outside this section;
/// presumably by `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name reported when a prediction settles (presumably telemetry).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes. NOTE(review): presumably how long a prediction stays eligible
/// for the settled check; inferred from the name, confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten seconds. NOTE(review): presumably the edit-free period before a
/// prediction counts as settled; inferred from the name, confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
109
/// Feature flag registered under the name `zeta2`.
pub struct Zeta2FeatureFlag;
/// Feature flag registered under the name `edit_prediction_jumps`.
///
/// `handle_project_event` checks it before refreshing predictions in response
/// to diagnostics updates.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
120
/// Newtype wrapper that registers the single `EditPredictionStore` entity as a
/// gpui global; read by `EditPredictionStore::try_global` and installed by
/// `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// Populated from the environment by `zeta2_raw_config_from_env`
/// (`ZED_ZETA_FORMAT` for `format`, `ZED_ZETA_MODEL` for `model_id`), or set
/// explicitly with `set_zeta2_raw_config`.
///
/// NOTE(review): a previous version of this comment said "the version is also
/// used as the Baseten environment name (lowercased)", but this struct has no
/// `version` field. Confirm whether that remark now applies to `format`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (`ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when loaded from the environment).
    pub format: ZetaFormat,
}
134
/// Central store that tracks per-project edit history and prediction state.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Shared LLM token; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false` in `new`; updated outside this section.
    update_required: bool,
    /// Selected prediction backend; defaults to `Zeta2` in `new`.
    edit_prediction_model: EditPredictionModel,
    /// Optional raw-endpoint configuration (see `Zeta2RawConfig`).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender feeding the background `handle_rejected_predictions` task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Sender feeding the `run_settled_predictions_worker` task.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; initialized to `None` in `new`.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
154
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle backend using the given prompt format; `icons`
    /// picks its icon from the configured provider (Ollama or OpenAI-compatible).
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
163
/// Inputs gathered for a single prediction request (constructed outside this
/// section; presumably passed to the selected model backend).
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Position the prediction is requested at.
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional debug channel (see `EditPredictionStore::debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
180
/// Events sent over the channel returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when the related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when the related-excerpt store finishes a refresh; `metadata` holds
/// labelled display values (for example cache hits and LSP latencies).
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
216
/// Maximum number of user actions retained per project; enforced by
/// `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action recorded in `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
236
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (for example a buffer change with its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot taken before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Range of the change, anchored in the snapshot after the change.
    pub edit_range: Range<Anchor>,
}
244
245impl StoredEvent {
246 fn can_merge(
247 &self,
248 next_old_event: &&&StoredEvent,
249 new_snapshot: &TextBufferSnapshot,
250 last_edit_range: &Range<Anchor>,
251 ) -> bool {
252 // Events must be for the same buffer
253 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
254 return false;
255 }
256 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
257 return false;
258 }
259
260 let a_is_predicted = matches!(
261 self.event.as_ref(),
262 zeta_prompt::Event::BufferChange {
263 predicted: true,
264 ..
265 }
266 );
267 let b_is_predicted = matches!(
268 next_old_event.event.as_ref(),
269 zeta_prompt::Event::BufferChange {
270 predicted: true,
271 ..
272 }
273 );
274
275 // If events come from the same source (both predicted or both manual) then
276 // we would have coalesced them already.
277 if a_is_predicted == b_is_predicted {
278 return false;
279 }
280
281 let left_range = self.edit_range.to_point(new_snapshot);
282 let right_range = next_old_event.edit_range.to_point(new_snapshot);
283 let latest_range = last_edit_range.to_point(&new_snapshot);
284
285 // Events near to the latest edit are not merged if their sources differ.
286 if lines_between_ranges(&left_range, &latest_range)
287 .min(lines_between_ranges(&right_range, &latest_range))
288 <= CHANGE_GROUPING_LINE_SPAN
289 {
290 return false;
291 }
292
293 // Events that are distant from each other are not merged.
294 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
295 return false;
296 }
297
298 true
299 }
300}
301
302fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
303 if left.start > right.end {
304 return left.start.row - right.end.row;
305 }
306 if right.start > left.end {
307 return right.start.row - left.end.row;
308 }
309 0
310}
311
/// Per-project prediction state owned by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    /// The in-progress event that has not yet been finalized into `events`.
    last_event: Option<LastEvent>,
    /// Recently active paths, most recent at the front (see `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug channel installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree, removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of at most `USER_ACTION_HISTORY_SIZE` actions.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
330
331impl ProjectState {
332 fn record_user_action(&mut self, action: UserActionRecord) {
333 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
334 self.user_actions.pop_front();
335 }
336 self.user_actions.push_back(action);
337 }
338
339 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
340 self.events
341 .iter()
342 .cloned()
343 .chain(self.last_event.as_ref().iter().flat_map(|event| {
344 let (one, two) = event.split_by_pause();
345 let one = one.finalize(&self.license_detection_watchers, cx);
346 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
347 one.into_iter().chain(two)
348 }))
349 .collect()
350 }
351
352 fn cancel_pending_prediction(
353 &mut self,
354 pending_prediction: PendingPrediction,
355 cx: &mut Context<EditPredictionStore>,
356 ) {
357 self.cancelled_predictions.insert(pending_prediction.id);
358
359 if pending_prediction.drop_on_cancel {
360 drop(pending_prediction.task);
361 } else {
362 cx.spawn(async move |this, cx| {
363 let Some(prediction_id) = pending_prediction.task.await else {
364 return;
365 };
366
367 this.update(cx, |this, cx| {
368 this.reject_prediction(
369 prediction_id,
370 EditPredictionRejectReason::Canceled,
371 false,
372 None,
373 cx,
374 );
375 })
376 .ok();
377 })
378 .detach()
379 }
380 }
381
382 fn active_buffer(
383 &self,
384 project: &Entity<Project>,
385 cx: &App,
386 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
387 let project = project.read(cx);
388 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
389 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
390 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
391 Some((active_buffer, registered_buffer.last_position))
392 }
393}
394
/// The prediction currently held for a project, with display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
402
403impl CurrentEditPrediction {
404 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
405 let Some(new_edits) = self
406 .prediction
407 .interpolate(&self.prediction.buffer.read(cx))
408 else {
409 return false;
410 };
411
412 if self.prediction.buffer != old_prediction.prediction.buffer {
413 return true;
414 }
415
416 let Some(old_edits) = old_prediction
417 .prediction
418 .interpolate(&old_prediction.prediction.buffer.read(cx))
419 else {
420 return true;
421 };
422
423 let requested_by_buffer_id = self.requested_by.buffer_id();
424
425 // This reduces the occurrence of UI thrash from replacing edits
426 //
427 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
428 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
429 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
430 && old_edits.len() == 1
431 && new_edits.len() == 1
432 {
433 let (old_range, old_text) = &old_edits[0];
434 let (new_range, new_text) = &new_edits[0];
435 new_range == old_range && new_text.starts_with(old_text.as_ref())
436 } else {
437 true
438 }
439 }
440}
441
442#[derive(Debug, Clone)]
443enum PredictionRequestedBy {
444 DiagnosticsUpdate,
445 Buffer(EntityId),
446}
447
448impl PredictionRequestedBy {
449 pub fn buffer_id(&self) -> Option<EntityId> {
450 match self {
451 PredictionRequestedBy::DiagnosticsUpdate => None,
452 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
453 }
454 }
455}
456
/// Line window used for diagnostic searches (used outside this section; the
/// exact meaning is not visible here — TODO confirm).
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope for diagnostics-driven prediction refreshes. `Global` is used when a
/// project reports `DiagnosticsUpdated`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
464
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the resulting prediction id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
473
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// Presumably a prediction whose edits apply to this buffer — confirm.
    Local { prediction: &'a EditPrediction },
    /// Presumably a prediction located elsewhere that the user can jump to — confirm.
    Jump { prediction: &'a EditPrediction },
}
480
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Exposes the wrapped prediction regardless of variant.
    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
492
/// A prediction awaiting the "settled" check, stored per registered buffer.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Editable region of the prediction, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}
500
/// Bookkeeping for a buffer registered with a project's prediction state.
struct RegisteredBuffer {
    /// The buffer's file at registration time (`None` for untitled buffers).
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known position; `None` when first registered.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
508
/// The in-progress edit event for a buffer, finalized lazily into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Whether the change came from an accepted prediction.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
520
521impl LastEvent {
522 pub fn finalize(
523 &self,
524 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
525 cx: &App,
526 ) -> Option<StoredEvent> {
527 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
528 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
529
530 let in_open_source_repo =
531 [self.new_file.as_ref(), self.old_file.as_ref()]
532 .iter()
533 .all(|file| {
534 file.is_some_and(|file| {
535 license_detection_watchers
536 .get(&file.worktree_id(cx))
537 .is_some_and(|watcher| watcher.is_project_open_source())
538 })
539 });
540
541 let (diff, edit_range) =
542 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
543
544 if path == old_path && diff.is_empty() {
545 None
546 } else {
547 Some(StoredEvent {
548 event: Arc::new(zeta_prompt::Event::BufferChange {
549 old_path,
550 path,
551 diff,
552 in_open_source_repo,
553 predicted: self.predicted,
554 }),
555 edit_range: self.new_snapshot.anchor_before(edit_range.start)
556 ..self.new_snapshot.anchor_before(edit_range.end),
557 old_snapshot: self.old_snapshot.clone(),
558 })
559 }
560 }
561
562 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
563 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
564 return (self.clone(), None);
565 };
566
567 let before = LastEvent {
568 old_snapshot: self.old_snapshot.clone(),
569 new_snapshot: boundary_snapshot.clone(),
570 old_file: self.old_file.clone(),
571 new_file: self.new_file.clone(),
572 edit_range: None,
573 predicted: self.predicted,
574 snapshot_after_last_editing_pause: None,
575 last_edit_time: self.last_edit_time,
576 };
577
578 let after = LastEvent {
579 old_snapshot: boundary_snapshot.clone(),
580 new_snapshot: self.new_snapshot.clone(),
581 old_file: self.old_file.clone(),
582 new_file: self.new_file.clone(),
583 edit_range: None,
584 predicted: self.predicted,
585 snapshot_after_last_editing_pause: None,
586 last_edit_time: self.last_edit_time,
587 };
588
589 (before, Some(after))
590 }
591}
592
/// Builds a unified diff covering every edit made between `old_snapshot` and
/// `new_snapshot`.
///
/// All edits since `old_snapshot.version` are folded into one region. The
/// region runs from the first edit's start row to the last edit's end row,
/// padded with surrounding context rows. The old and new text of that region
/// are passed to `language::unified_diff_with_offsets` together with their
/// starting rows (presumably so hunk line numbers refer to positions in the
/// full buffers).
///
/// Returns `None` when there are no edits. Otherwise returns the diff text and
/// the range from the first edit's start to the last edit's end, in
/// `new_snapshot` coordinates and without the context padding.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // `?` bails out when there are no edits at all.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Leading context: up to CONTEXT_LINES rows above the first edited row.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Last context row, clamped to each buffer's last row.
    // NOTE(review): this is `end_row + 1 + CONTEXT_LINES`, and the region below
    // extends through the whole of that row, so trailing context is one row
    // longer than leading context — confirm whether that asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // The region spans from the start of the first context row to the start of
    // the row after the last context row, clamped to the end of the buffer.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
638
639fn buffer_path_with_id_fallback(
640 file: Option<&Arc<dyn File>>,
641 snapshot: &TextBufferSnapshot,
642 cx: &App,
643) -> Arc<Path> {
644 if let Some(file) = file {
645 file.full_path(cx).into()
646 } else {
647 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
648 }
649}
650
651impl EditPredictionStore {
652 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
653 cx.try_global::<EditPredictionStoreGlobal>()
654 .map(|global| global.0.clone())
655 }
656
657 pub fn global(
658 client: &Arc<Client>,
659 user_store: &Entity<UserStore>,
660 cx: &mut App,
661 ) -> Entity<Self> {
662 cx.try_global::<EditPredictionStoreGlobal>()
663 .map(|global| global.0.clone())
664 .unwrap_or_else(|| {
665 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
666 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
667 ep_store
668 })
669 }
670
671 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
672 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
673 let data_collection_choice = Self::load_data_collection_choice();
674
675 let llm_token = LlmApiToken::default();
676
677 let (reject_tx, reject_rx) = mpsc::unbounded();
678 cx.background_spawn({
679 let client = client.clone();
680 let llm_token = llm_token.clone();
681 let app_version = AppVersion::global(cx);
682 let background_executor = cx.background_executor().clone();
683 async move {
684 Self::handle_rejected_predictions(
685 reject_rx,
686 client,
687 llm_token,
688 app_version,
689 background_executor,
690 )
691 .await
692 }
693 })
694 .detach();
695
696 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
697 cx.spawn(async move |this, cx| {
698 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
699 })
700 .detach();
701
702 let this = Self {
703 projects: HashMap::default(),
704 client,
705 user_store,
706 llm_token,
707 _llm_token_subscription: cx.subscribe(
708 &refresh_llm_token_listener,
709 |this, _listener, _event, cx| {
710 let client = this.client.clone();
711 let llm_token = this.llm_token.clone();
712 cx.spawn(async move |_this, _cx| {
713 llm_token.refresh(&client).await?;
714 anyhow::Ok(())
715 })
716 .detach_and_log_err(cx);
717 },
718 ),
719 update_required: false,
720 edit_prediction_model: EditPredictionModel::Zeta2,
721 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
722 sweep_ai: SweepAi::new(cx),
723 mercury: Mercury::new(cx),
724
725 data_collection_choice,
726 reject_predictions_tx: reject_tx,
727 settled_predictions_tx,
728 rated_predictions: Default::default(),
729 shown_predictions: Default::default(),
730 #[cfg(test)]
731 settled_event_callback: None,
732 };
733
734 this
735 }
736
737 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
738 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
739 let format = ZetaFormat::parse(&version_str).ok()?;
740 let model_id = env::var("ZED_ZETA_MODEL").ok();
741 Some(Zeta2RawConfig { model_id, format })
742 }
743
    /// Selects the backend used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Installs a raw Zeta2 configuration, replacing any existing one.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
755
756 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
757 use ui::IconName;
758 match self.edit_prediction_model {
759 EditPredictionModel::Sweep => {
760 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
761 .with_disabled(IconName::SweepAiDisabled)
762 .with_up(IconName::SweepAiUp)
763 .with_down(IconName::SweepAiDown)
764 .with_error(IconName::SweepAiError)
765 }
766 EditPredictionModel::Mercury => {
767 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
768 }
769 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
770 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
771 .with_disabled(IconName::ZedPredictDisabled)
772 .with_up(IconName::ZedPredictUp)
773 .with_down(IconName::ZedPredictDown)
774 .with_error(IconName::ZedPredictError)
775 }
776 EditPredictionModel::Fim { .. } => {
777 let settings = &all_language_settings(None, cx).edit_predictions;
778 match settings.provider {
779 EditPredictionProvider::Ollama => {
780 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
781 }
782 _ => {
783 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
784 }
785 }
786 }
787 }
788 }
789
    /// Whether a Sweep API key is available.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Whether a Mercury API key is available.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
797
798 pub fn clear_history(&mut self) {
799 for project_state in self.projects.values_mut() {
800 project_state.events.clear();
801 project_state.last_event.take();
802 }
803 }
804
805 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
806 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
807 project_state.events.clear();
808 project_state.last_event.take();
809 }
810 }
811
812 pub fn edit_history_for_project(
813 &self,
814 project: &Entity<Project>,
815 cx: &App,
816 ) -> Vec<StoredEvent> {
817 self.projects
818 .get(&project.entity_id())
819 .map(|project_state| project_state.events(cx))
820 .unwrap_or_default()
821 }
822
823 pub fn context_for_project<'a>(
824 &'a self,
825 project: &Entity<Project>,
826 cx: &'a mut App,
827 ) -> Vec<RelatedFile> {
828 self.projects
829 .get(&project.entity_id())
830 .map(|project_state| {
831 project_state.context.update(cx, |context, cx| {
832 context
833 .related_files_with_buffers(cx)
834 .map(|(mut related_file, buffer)| {
835 related_file.in_open_source_repo = buffer
836 .read(cx)
837 .file()
838 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
839 related_file
840 })
841 .collect()
842 })
843 })
844 .unwrap_or_default()
845 }
846
847 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
848 self.projects
849 .get(&project.entity_id())
850 .and_then(|project| project.copilot.clone())
851 }
852
853 pub fn start_copilot_for_project(
854 &mut self,
855 project: &Entity<Project>,
856 cx: &mut Context<Self>,
857 ) -> Option<Entity<Copilot>> {
858 if DisableAiSettings::get(None, cx).disable_ai {
859 return None;
860 }
861 let state = self.get_or_init_project(project, cx);
862
863 if state.copilot.is_some() {
864 return state.copilot.clone();
865 }
866 let _project = project.clone();
867 let project = project.read(cx);
868
869 let node = project.node_runtime().cloned();
870 if let Some(node) = node {
871 let next_id = project.languages().next_language_server_id();
872 let fs = project.fs().clone();
873
874 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
875 state.copilot = Some(copilot.clone());
876 Some(copilot)
877 } else {
878 None
879 }
880 }
881
882 pub fn context_for_project_with_buffers<'a>(
883 &'a self,
884 project: &Entity<Project>,
885 cx: &'a mut App,
886 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
887 self.projects
888 .get(&project.entity_id())
889 .map(|project| {
890 project.context.update(cx, |context, cx| {
891 context.related_files_with_buffers(cx).collect()
892 })
893 })
894 .unwrap_or_default()
895 }
896
897 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
898 if matches!(
899 self.edit_prediction_model,
900 EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
901 ) {
902 self.user_store.read(cx).edit_prediction_usage()
903 } else {
904 None
905 }
906 }
907
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
911
912 pub fn register_buffer(
913 &mut self,
914 buffer: &Entity<Buffer>,
915 project: &Entity<Project>,
916 cx: &mut Context<Self>,
917 ) {
918 let project_state = self.get_or_init_project(project, cx);
919 Self::register_buffer_impl(project_state, buffer, project, cx);
920 }
921
922 fn get_or_init_project(
923 &mut self,
924 project: &Entity<Project>,
925 cx: &mut Context<Self>,
926 ) -> &mut ProjectState {
927 let entity_id = project.entity_id();
928 self.projects
929 .entry(entity_id)
930 .or_insert_with(|| ProjectState {
931 context: {
932 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
933 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
934 this.handle_excerpt_store_event(entity_id, event);
935 })
936 .detach();
937 related_excerpt_store
938 },
939 events: VecDeque::new(),
940 last_event: None,
941 recent_paths: VecDeque::new(),
942 debug_tx: None,
943 registered_buffers: HashMap::default(),
944 current_prediction: None,
945 cancelled_predictions: HashSet::default(),
946 pending_predictions: ArrayVec::new(),
947 next_pending_prediction_id: 0,
948 last_edit_prediction_refresh: None,
949 last_jump_prediction_refresh: None,
950 license_detection_watchers: HashMap::default(),
951 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
952 _subscriptions: [
953 cx.subscribe(&project, Self::handle_project_event),
954 cx.observe_release(&project, move |this, _, cx| {
955 this.projects.remove(&entity_id);
956 cx.notify();
957 }),
958 ],
959 copilot: None,
960 })
961 }
962
963 pub fn remove_project(&mut self, project: &Entity<Project>) {
964 self.projects.remove(&project.entity_id());
965 }
966
967 fn handle_excerpt_store_event(
968 &mut self,
969 project_entity_id: EntityId,
970 event: &RelatedExcerptStoreEvent,
971 ) {
972 if let Some(project_state) = self.projects.get(&project_entity_id) {
973 if let Some(debug_tx) = project_state.debug_tx.clone() {
974 match event {
975 RelatedExcerptStoreEvent::StartedRefresh => {
976 debug_tx
977 .unbounded_send(DebugEvent::ContextRetrievalStarted(
978 ContextRetrievalStartedDebugEvent {
979 project_entity_id: project_entity_id,
980 timestamp: Instant::now(),
981 search_prompt: String::new(),
982 },
983 ))
984 .ok();
985 }
986 RelatedExcerptStoreEvent::FinishedRefresh {
987 cache_hit_count,
988 cache_miss_count,
989 mean_definition_latency,
990 max_definition_latency,
991 } => {
992 debug_tx
993 .unbounded_send(DebugEvent::ContextRetrievalFinished(
994 ContextRetrievalFinishedDebugEvent {
995 project_entity_id: project_entity_id,
996 timestamp: Instant::now(),
997 metadata: vec![
998 (
999 "Cache Hits",
1000 format!(
1001 "{}/{}",
1002 cache_hit_count,
1003 cache_hit_count + cache_miss_count
1004 )
1005 .into(),
1006 ),
1007 (
1008 "Max LSP Time",
1009 format!("{} ms", max_definition_latency.as_millis())
1010 .into(),
1011 ),
1012 (
1013 "Mean LSP Time",
1014 format!("{} ms", mean_definition_latency.as_millis())
1015 .into(),
1016 ),
1017 ],
1018 },
1019 ))
1020 .ok();
1021 }
1022 }
1023 }
1024 }
1025 }
1026
1027 pub fn debug_info(
1028 &mut self,
1029 project: &Entity<Project>,
1030 cx: &mut Context<Self>,
1031 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1032 let project_state = self.get_or_init_project(project, cx);
1033 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1034 project_state.debug_tx = Some(debug_watch_tx);
1035 debug_watch_rx
1036 }
1037
1038 fn handle_project_event(
1039 &mut self,
1040 project: Entity<Project>,
1041 event: &project::Event,
1042 cx: &mut Context<Self>,
1043 ) {
1044 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1045 return;
1046 }
1047 // TODO [zeta2] init with recent paths
1048 match event {
1049 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1050 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1051 return;
1052 };
1053 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1054 if let Some(path) = path {
1055 if let Some(ix) = project_state
1056 .recent_paths
1057 .iter()
1058 .position(|probe| probe == &path)
1059 {
1060 project_state.recent_paths.remove(ix);
1061 }
1062 project_state.recent_paths.push_front(path);
1063 }
1064 }
1065 project::Event::DiagnosticsUpdated { .. } => {
1066 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1067 self.refresh_prediction_from_diagnostics(
1068 project,
1069 DiagnosticSearchScope::Global,
1070 cx,
1071 );
1072 }
1073 }
1074 _ => (),
1075 }
1076 }
1077
    /// Ensures `buffer` is registered in `project_state` and returns its `RegisteredBuffer`.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree when none exists yet; the entry
    /// is removed from `license_detection_watchers` when the worktree entity is released.
    ///
    /// On first registration the buffer's current text snapshot and file are recorded and two
    /// subscriptions are stored in `_subscriptions`: one that calls
    /// `report_changes_for_buffer` on every `BufferEvent::Edited`, and one that removes the
    /// registration when the buffer entity is released. If the buffer is already registered,
    /// the existing entry is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // At most one watcher per worktree: only created if no entry exists yet.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            // Already registered: hand back the existing entry as-is.
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Report edits as they happen. The project is held as a weak handle
                        // and only used if it can still be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1144
    /// Records the edits made to `buffer` since the snapshot last stored for it.
    ///
    /// The buffer is registered if needed. If its version is unchanged, nothing else happens.
    /// Otherwise the stored snapshot and file are replaced with the current ones, and, if any
    /// edits are found:
    /// - pending settled predictions whose editable range overlaps the edits get their
    ///   `last_edit_at` set to now;
    /// - a `UserActionRecord` describing the change is recorded on the project;
    /// - the change is either coalesced into the in-progress `last_event`, or the previous
    ///   `last_event` is finalized into `events` and a new `last_event` is started.
    ///
    /// `is_predicted` marks changes coming from an accepted prediction (in this file,
    /// `accept_current_prediction` passes `true`); changes whose `is_predicted` differs from
    /// the in-progress event are never coalesced with it.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the previous snapshot: counts and sizes, the overall
        // anchor range (first edit's start to last edit's end), and where the last edit ends.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart its quiet period.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify the change: when the inserted (or deleted) length equals the number of
        // edits, it is treated as one character per edit; otherwise as a selection-sized
        // change. Pure deletions are checked before mixed replacements.
        // NOTE(review): these lengths are `usize` offsets (presumably bytes), so a single
        // multi-byte character would be classified as a selection — confirm intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        // Record the action at the point where the last edit ends.
        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only when this change continues directly from the in-progress event's
            // latest snapshot of the same buffer, has the same `is_predicted` origin, and lies
            // within CHANGE_GROUPING_LINE_SPAN lines of that event's last edit.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit,
                // remember the snapshot as it was just before this pause ended.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: close out the previous event. `finalize` may return `None`, in
        // which case nothing is pushed.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // After the push, `events` holds at most EVENT_COUNT_MAX - 1 entries
                // (presumably leaving room for the in-progress `last_event`).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new in-progress event for this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1278
1279 fn prediction_at(
1280 &mut self,
1281 buffer: &Entity<Buffer>,
1282 position: Option<language::Anchor>,
1283 project: &Entity<Project>,
1284 cx: &App,
1285 ) -> Option<BufferEditPrediction<'_>> {
1286 let project_state = self.projects.get_mut(&project.entity_id())?;
1287 if let Some(position) = position
1288 && let Some(buffer) = project_state
1289 .registered_buffers
1290 .get_mut(&buffer.entity_id())
1291 {
1292 buffer.last_position = Some(position);
1293 }
1294
1295 let CurrentEditPrediction {
1296 requested_by,
1297 prediction,
1298 ..
1299 } = project_state.current_prediction.as_ref()?;
1300
1301 if prediction.targets_buffer(buffer.read(cx)) {
1302 Some(BufferEditPrediction::Local { prediction })
1303 } else {
1304 let show_jump = match requested_by {
1305 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1306 requested_by_buffer_id == &buffer.entity_id()
1307 }
1308 PredictionRequestedBy::DiagnosticsUpdate => true,
1309 };
1310
1311 if show_jump {
1312 Some(BufferEditPrediction::Jump { prediction })
1313 } else {
1314 None
1315 }
1316 }
1317 }
1318
1319 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1320 let Some(current_prediction) = self
1321 .projects
1322 .get_mut(&project.entity_id())
1323 .and_then(|project_state| project_state.current_prediction.take())
1324 else {
1325 return;
1326 };
1327
1328 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1329
1330 // can't hold &mut project_state ref across report_changes_for_buffer_call
1331 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1332 return;
1333 };
1334
1335 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1336 project_state.cancel_pending_prediction(pending_prediction, cx);
1337 }
1338
1339 match self.edit_prediction_model {
1340 EditPredictionModel::Sweep => {
1341 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1342 }
1343 EditPredictionModel::Mercury => {
1344 mercury::edit_prediction_accepted(
1345 current_prediction.prediction.id,
1346 self.client.http_client(),
1347 cx,
1348 );
1349 }
1350 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1351 let is_cloud = !matches!(
1352 all_language_settings(None, cx).edit_predictions.provider,
1353 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1354 );
1355 if is_cloud {
1356 zeta::edit_prediction_accepted(self, current_prediction, cx)
1357 }
1358 }
1359 EditPredictionModel::Fim { .. } => {}
1360 }
1361 }
1362
1363 async fn handle_rejected_predictions(
1364 rx: UnboundedReceiver<EditPredictionRejection>,
1365 client: Arc<Client>,
1366 llm_token: LlmApiToken,
1367 app_version: Version,
1368 background_executor: BackgroundExecutor,
1369 ) {
1370 let mut rx = std::pin::pin!(rx.peekable());
1371 let mut batched = Vec::new();
1372
1373 while let Some(rejection) = rx.next().await {
1374 batched.push(rejection);
1375
1376 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1377 select_biased! {
1378 next = rx.as_mut().peek().fuse() => {
1379 if next.is_some() {
1380 continue;
1381 }
1382 }
1383 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1384 }
1385 }
1386
1387 let url = client
1388 .http_client()
1389 .build_zed_llm_url("/predict_edits/reject", &[])
1390 .unwrap();
1391
1392 let flush_count = batched
1393 .len()
1394 // in case items have accumulated after failure
1395 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1396 let start = batched.len() - flush_count;
1397
1398 let body = RejectEditPredictionsBodyRef {
1399 rejections: &batched[start..],
1400 };
1401
1402 let result = Self::send_api_request::<()>(
1403 |builder| {
1404 let req = builder
1405 .uri(url.as_ref())
1406 .body(serde_json::to_string(&body)?.into());
1407 anyhow::Ok(req?)
1408 },
1409 client.clone(),
1410 llm_token.clone(),
1411 app_version.clone(),
1412 true,
1413 )
1414 .await;
1415
1416 if result.log_err().is_some() {
1417 batched.drain(start..);
1418 }
1419 }
1420 }
1421
    /// Background loop that reports pending settled predictions once they stop being edited.
    ///
    /// `rx` carries the enqueue times sent by `enqueue_settled_prediction` and is used only as
    /// a wake-up signal. On each pass, every registered buffer's pending predictions are
    /// scanned:
    /// - entries older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without being reported;
    /// - entries whose `last_edit_at` is at least `EDIT_PREDICTION_SETTLED_QUIESCENCE` in the
    ///   past emit `EDIT_PREDICTION_SETTLED_EVENT` with the current text of their editable
    ///   range, and are then removed.
    ///
    /// The loop exits when `rx` closes while idle or when the store entity is dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // A check is scheduled: sleep until it is due.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: wait for the next enqueue and schedule a check one quiescence period
                // after it. Any further queued wake-ups are discarded, because each scan looks
                // at every pending entry anyway.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among entries that remain pending after this scan.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Wake again when the least recently edited remaining entry could settle. `None`
            // (nothing pending) sends the loop back to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1505
1506 pub(crate) fn enqueue_settled_prediction(
1507 &mut self,
1508 request_id: EditPredictionId,
1509 project: &Entity<Project>,
1510 edited_buffer: &Entity<Buffer>,
1511 edited_buffer_snapshot: &BufferSnapshot,
1512 editable_offset_range: Range<usize>,
1513 cx: &mut Context<Self>,
1514 ) {
1515 let project_state = self.get_or_init_project(project, cx);
1516 if let Some(buffer) = project_state
1517 .registered_buffers
1518 .get_mut(&edited_buffer.entity_id())
1519 {
1520 let now = cx.background_executor().now();
1521 buffer.pending_predictions.push(PendingSettledPrediction {
1522 request_id,
1523 editable_anchor_range: edited_buffer_snapshot
1524 .anchor_range_around(editable_offset_range),
1525 enqueued_at: now,
1526 last_edit_at: now,
1527 });
1528 self.settled_predictions_tx.unbounded_send(now).ok();
1529 }
1530 }
1531
1532 fn reject_current_prediction(
1533 &mut self,
1534 reason: EditPredictionRejectReason,
1535 project: &Entity<Project>,
1536 cx: &App,
1537 ) {
1538 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1539 project_state.pending_predictions.clear();
1540 if let Some(prediction) = project_state.current_prediction.take() {
1541 let model_version = prediction.prediction.model_version.clone();
1542 self.reject_prediction(
1543 prediction.prediction.id,
1544 reason,
1545 prediction.was_shown,
1546 model_version,
1547 cx,
1548 );
1549 }
1550 };
1551 }
1552
1553 fn did_show_current_prediction(
1554 &mut self,
1555 project: &Entity<Project>,
1556 display_type: edit_prediction_types::SuggestionDisplayType,
1557 cx: &mut Context<Self>,
1558 ) {
1559 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1560 return;
1561 };
1562
1563 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1564 return;
1565 };
1566
1567 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1568 let previous_shown_with = current_prediction.shown_with;
1569
1570 if previous_shown_with.is_none() || !is_jump {
1571 current_prediction.shown_with = Some(display_type);
1572 }
1573
1574 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1575
1576 if is_first_non_jump_show {
1577 current_prediction.was_shown = true;
1578 }
1579
1580 let display_type_changed = previous_shown_with != Some(display_type);
1581
1582 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1583 sweep_ai::edit_prediction_shown(
1584 &self.sweep_ai,
1585 self.client.clone(),
1586 ¤t_prediction.prediction,
1587 display_type,
1588 cx,
1589 );
1590 }
1591
1592 if is_first_non_jump_show {
1593 self.shown_predictions
1594 .push_front(current_prediction.prediction.clone());
1595 if self.shown_predictions.len() > 50 {
1596 let completion = self.shown_predictions.pop_back().unwrap();
1597 self.rated_predictions.remove(&completion.id);
1598 }
1599 }
1600 }
1601
1602 fn reject_prediction(
1603 &mut self,
1604 prediction_id: EditPredictionId,
1605 reason: EditPredictionRejectReason,
1606 was_shown: bool,
1607 model_version: Option<String>,
1608 cx: &App,
1609 ) {
1610 match self.edit_prediction_model {
1611 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1612 let is_cloud = !matches!(
1613 all_language_settings(None, cx).edit_predictions.provider,
1614 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1615 );
1616 if is_cloud {
1617 self.reject_predictions_tx
1618 .unbounded_send(EditPredictionRejection {
1619 request_id: prediction_id.to_string(),
1620 reason,
1621 was_shown,
1622 model_version,
1623 })
1624 .log_err();
1625 }
1626 }
1627 EditPredictionModel::Mercury => {
1628 mercury::edit_prediction_rejected(
1629 prediction_id,
1630 was_shown,
1631 reason,
1632 self.client.http_client(),
1633 cx,
1634 );
1635 }
1636 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1637 }
1638 }
1639
1640 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1641 self.projects
1642 .get(&project.entity_id())
1643 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1644 }
1645
1646 pub fn refresh_prediction_from_buffer(
1647 &mut self,
1648 project: Entity<Project>,
1649 buffer: Entity<Buffer>,
1650 position: language::Anchor,
1651 cx: &mut Context<Self>,
1652 ) {
1653 self.queue_prediction_refresh(
1654 project.clone(),
1655 PredictEditsRequestTrigger::Other,
1656 buffer.entity_id(),
1657 cx,
1658 move |this, cx| {
1659 let Some(request_task) = this
1660 .update(cx, |this, cx| {
1661 this.request_prediction(
1662 &project,
1663 &buffer,
1664 position,
1665 PredictEditsRequestTrigger::Other,
1666 cx,
1667 )
1668 })
1669 .log_err()
1670 else {
1671 return Task::ready(anyhow::Ok(None));
1672 };
1673
1674 cx.spawn(async move |_cx| {
1675 request_task.await.map(|prediction_result| {
1676 prediction_result.map(|prediction_result| {
1677 (
1678 prediction_result,
1679 PredictionRequestedBy::Buffer(buffer.entity_id()),
1680 )
1681 })
1682 })
1683 })
1684 },
1685 )
1686 }
1687
1688 pub fn refresh_prediction_from_diagnostics(
1689 &mut self,
1690 project: Entity<Project>,
1691 scope: DiagnosticSearchScope,
1692 cx: &mut Context<Self>,
1693 ) {
1694 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1695 return;
1696 }
1697
1698 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1699 return;
1700 };
1701
1702 // Prefer predictions from buffer
1703 if project_state.current_prediction.is_some() {
1704 return;
1705 }
1706
1707 self.queue_prediction_refresh(
1708 project.clone(),
1709 PredictEditsRequestTrigger::Diagnostics,
1710 project.entity_id(),
1711 cx,
1712 move |this, cx| {
1713 let Some((active_buffer, snapshot, cursor_point)) = this
1714 .read_with(cx, |this, cx| {
1715 let project_state = this.projects.get(&project.entity_id())?;
1716 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1717 let snapshot = buffer.read(cx).snapshot();
1718
1719 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1720 return None;
1721 }
1722
1723 let cursor_point = position
1724 .map(|pos| pos.to_point(&snapshot))
1725 .unwrap_or_default();
1726
1727 Some((buffer, snapshot, cursor_point))
1728 })
1729 .log_err()
1730 .flatten()
1731 else {
1732 return Task::ready(anyhow::Ok(None));
1733 };
1734
1735 cx.spawn(async move |cx| {
1736 let diagnostic_search_range = match scope {
1737 DiagnosticSearchScope::Local => {
1738 let diagnostic_search_start =
1739 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1740 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1741 Point::new(diagnostic_search_start, 0)
1742 ..Point::new(diagnostic_search_end, 0)
1743 }
1744 DiagnosticSearchScope::Global => Default::default(),
1745 };
1746
1747 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1748 active_buffer,
1749 &snapshot,
1750 diagnostic_search_range,
1751 cursor_point,
1752 &project,
1753 cx,
1754 )
1755 .await?
1756 else {
1757 return anyhow::Ok(None);
1758 };
1759
1760 let Some(prediction_result) = this
1761 .update(cx, |this, cx| {
1762 this.request_prediction(
1763 &project,
1764 &jump_buffer,
1765 jump_position,
1766 PredictEditsRequestTrigger::Diagnostics,
1767 cx,
1768 )
1769 })?
1770 .await?
1771 else {
1772 return anyhow::Ok(None);
1773 };
1774
1775 this.update(cx, |this, cx| {
1776 Some((
1777 if this
1778 .get_or_init_project(&project, cx)
1779 .current_prediction
1780 .is_none()
1781 {
1782 prediction_result
1783 } else {
1784 EditPredictionResult {
1785 id: prediction_result.id,
1786 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1787 }
1788 },
1789 PredictionRequestedBy::DiagnosticsUpdate,
1790 ))
1791 })
1792 })
1793 },
1794 );
1795 }
1796
1797 fn predictions_enabled_at(
1798 snapshot: &BufferSnapshot,
1799 position: Option<language::Anchor>,
1800 cx: &App,
1801 ) -> bool {
1802 let file = snapshot.file();
1803 let all_settings = all_language_settings(file, cx);
1804 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1805 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1806 {
1807 return false;
1808 }
1809
1810 if let Some(last_position) = position {
1811 let settings = snapshot.settings_at(last_position, cx);
1812
1813 if !settings.edit_predictions_disabled_in.is_empty()
1814 && let Some(scope) = snapshot.language_scope_at(last_position)
1815 && let Some(scope_name) = scope.override_name()
1816 && settings
1817 .edit_predictions_disabled_in
1818 .iter()
1819 .any(|s| s == scope_name)
1820 {
1821 return false;
1822 }
1823 }
1824
1825 true
1826 }
1827
    /// Minimum spacing between consecutive prediction requests in the same throttle slot
    /// when they come from the same throttle entity (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1829}
1830
1831fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1832 match provider {
1833 EditPredictionProvider::Zed
1834 | EditPredictionProvider::Sweep
1835 | EditPredictionProvider::Mercury
1836 | EditPredictionProvider::Ollama
1837 | EditPredictionProvider::OpenAiCompatibleApi
1838 | EditPredictionProvider::Experimental(_) => true,
1839 EditPredictionProvider::None
1840 | EditPredictionProvider::Copilot
1841 | EditPredictionProvider::Supermaven
1842 | EditPredictionProvider::Codestral => false,
1843 }
1844}
1845
1846impl EditPredictionStore {
1847 fn queue_prediction_refresh(
1848 &mut self,
1849 project: Entity<Project>,
1850 request_trigger: PredictEditsRequestTrigger,
1851 throttle_entity: EntityId,
1852 cx: &mut Context<Self>,
1853 do_refresh: impl FnOnce(
1854 WeakEntity<Self>,
1855 &mut AsyncApp,
1856 )
1857 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1858 + 'static,
1859 ) {
1860 fn select_throttle(
1861 project_state: &mut ProjectState,
1862 request_trigger: PredictEditsRequestTrigger,
1863 ) -> &mut Option<(EntityId, Instant)> {
1864 match request_trigger {
1865 PredictEditsRequestTrigger::Diagnostics => {
1866 &mut project_state.last_jump_prediction_refresh
1867 }
1868 _ => &mut project_state.last_edit_prediction_refresh,
1869 }
1870 }
1871
1872 let (needs_acceptance_tracking, max_pending_predictions) =
1873 match all_language_settings(None, cx).edit_predictions.provider {
1874 EditPredictionProvider::Zed
1875 | EditPredictionProvider::Sweep
1876 | EditPredictionProvider::Mercury
1877 | EditPredictionProvider::Experimental(_) => (true, 2),
1878 EditPredictionProvider::Ollama => (false, 1),
1879 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1880 EditPredictionProvider::None
1881 | EditPredictionProvider::Copilot
1882 | EditPredictionProvider::Supermaven
1883 | EditPredictionProvider::Codestral => {
1884 log::error!("queue_prediction_refresh called with non-store provider");
1885 return;
1886 }
1887 };
1888
1889 let drop_on_cancel = !needs_acceptance_tracking;
1890 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1891 let project_state = self.get_or_init_project(&project, cx);
1892 let pending_prediction_id = project_state.next_pending_prediction_id;
1893 project_state.next_pending_prediction_id += 1;
1894 let last_request = *select_throttle(project_state, request_trigger);
1895
1896 let task = cx.spawn(async move |this, cx| {
1897 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1898 if throttle_entity != last_entity {
1899 return None;
1900 }
1901 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1902 }) {
1903 cx.background_executor().timer(timeout).await;
1904 }
1905
1906 // If this task was cancelled before the throttle timeout expired,
1907 // do not perform a request.
1908 let mut is_cancelled = true;
1909 this.update(cx, |this, cx| {
1910 let project_state = this.get_or_init_project(&project, cx);
1911 let was_cancelled = project_state
1912 .cancelled_predictions
1913 .remove(&pending_prediction_id);
1914 if !was_cancelled {
1915 let new_refresh = (throttle_entity, Instant::now());
1916 *select_throttle(project_state, request_trigger) = Some(new_refresh);
1917 is_cancelled = false;
1918 }
1919 })
1920 .ok();
1921 if is_cancelled {
1922 return None;
1923 }
1924
1925 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1926 let new_prediction_id = new_prediction_result
1927 .as_ref()
1928 .map(|(prediction, _)| prediction.id.clone());
1929
1930 // When a prediction completes, remove it from the pending list, and cancel
1931 // any pending predictions that were enqueued before it.
1932 this.update(cx, |this, cx| {
1933 let project_state = this.get_or_init_project(&project, cx);
1934
1935 let is_cancelled = project_state
1936 .cancelled_predictions
1937 .remove(&pending_prediction_id);
1938
1939 let new_current_prediction = if !is_cancelled
1940 && let Some((prediction_result, requested_by)) = new_prediction_result
1941 {
1942 match prediction_result.prediction {
1943 Ok(prediction) => {
1944 let new_prediction = CurrentEditPrediction {
1945 requested_by,
1946 prediction,
1947 was_shown: false,
1948 shown_with: None,
1949 };
1950
1951 if let Some(current_prediction) =
1952 project_state.current_prediction.as_ref()
1953 {
1954 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1955 {
1956 this.reject_current_prediction(
1957 EditPredictionRejectReason::Replaced,
1958 &project,
1959 cx,
1960 );
1961
1962 Some(new_prediction)
1963 } else {
1964 this.reject_prediction(
1965 new_prediction.prediction.id,
1966 EditPredictionRejectReason::CurrentPreferred,
1967 false,
1968 new_prediction.prediction.model_version,
1969 cx,
1970 );
1971 None
1972 }
1973 } else {
1974 Some(new_prediction)
1975 }
1976 }
1977 Err(reject_reason) => {
1978 this.reject_prediction(
1979 prediction_result.id,
1980 reject_reason,
1981 false,
1982 None,
1983 cx,
1984 );
1985 None
1986 }
1987 }
1988 } else {
1989 None
1990 };
1991
1992 let project_state = this.get_or_init_project(&project, cx);
1993
1994 if let Some(new_prediction) = new_current_prediction {
1995 project_state.current_prediction = Some(new_prediction);
1996 }
1997
1998 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1999 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2000 if pending_prediction.id == pending_prediction_id {
2001 pending_predictions.remove(ix);
2002 for pending_prediction in pending_predictions.drain(0..ix) {
2003 project_state.cancel_pending_prediction(pending_prediction, cx)
2004 }
2005 break;
2006 }
2007 }
2008 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2009 cx.notify();
2010 })
2011 .ok();
2012
2013 new_prediction_id
2014 });
2015
2016 if project_state.pending_predictions.len() < max_pending_predictions {
2017 project_state.pending_predictions.push(PendingPrediction {
2018 id: pending_prediction_id,
2019 task,
2020 drop_on_cancel,
2021 });
2022 } else {
2023 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2024 project_state.pending_predictions.push(PendingPrediction {
2025 id: pending_prediction_id,
2026 task,
2027 drop_on_cancel,
2028 });
2029 project_state.cancel_pending_prediction(pending_prediction, cx);
2030 }
2031 }
2032
2033 pub fn request_prediction(
2034 &mut self,
2035 project: &Entity<Project>,
2036 active_buffer: &Entity<Buffer>,
2037 position: language::Anchor,
2038 trigger: PredictEditsRequestTrigger,
2039 cx: &mut Context<Self>,
2040 ) -> Task<Result<Option<EditPredictionResult>>> {
2041 self.request_prediction_internal(
2042 project.clone(),
2043 active_buffer.clone(),
2044 position,
2045 trigger,
2046 cx.has_flag::<Zeta2FeatureFlag>(),
2047 cx,
2048 )
2049 }
2050
2051 fn request_prediction_internal(
2052 &mut self,
2053 project: Entity<Project>,
2054 active_buffer: Entity<Buffer>,
2055 position: language::Anchor,
2056 trigger: PredictEditsRequestTrigger,
2057 allow_jump: bool,
2058 cx: &mut Context<Self>,
2059 ) -> Task<Result<Option<EditPredictionResult>>> {
2060 self.get_or_init_project(&project, cx);
2061 let project_state = self.projects.get(&project.entity_id()).unwrap();
2062 let stored_events = project_state.events(cx);
2063 let has_events = !stored_events.is_empty();
2064 let events: Vec<Arc<zeta_prompt::Event>> =
2065 stored_events.iter().map(|e| e.event.clone()).collect();
2066 let debug_tx = project_state.debug_tx.clone();
2067
2068 let snapshot = active_buffer.read(cx).snapshot();
2069 let cursor_point = position.to_point(&snapshot);
2070 let current_offset = position.to_offset(&snapshot);
2071
2072 let mut user_actions: Vec<UserActionRecord> =
2073 project_state.user_actions.iter().cloned().collect();
2074
2075 if let Some(last_action) = user_actions.last() {
2076 if last_action.buffer_id == active_buffer.entity_id()
2077 && current_offset != last_action.offset
2078 {
2079 let timestamp_epoch_ms = SystemTime::now()
2080 .duration_since(UNIX_EPOCH)
2081 .map(|d| d.as_millis() as u64)
2082 .unwrap_or(0);
2083 user_actions.push(UserActionRecord {
2084 action_type: UserActionType::CursorMovement,
2085 buffer_id: active_buffer.entity_id(),
2086 line_number: cursor_point.row,
2087 offset: current_offset,
2088 timestamp_epoch_ms,
2089 });
2090 }
2091 }
2092 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2093 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2094 let diagnostic_search_range =
2095 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2096
2097 let related_files = self.context_for_project(&project, cx);
2098
2099 let is_open_source = snapshot
2100 .file()
2101 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2102 && events.iter().all(|event| event.in_open_source_repo())
2103 && related_files.iter().all(|file| file.in_open_source_repo);
2104
2105 let can_collect_data = !cfg!(test)
2106 && is_open_source
2107 && self.is_data_collection_enabled(cx)
2108 && matches!(
2109 self.edit_prediction_model,
2110 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2
2111 );
2112
2113 let inputs = EditPredictionModelInput {
2114 project: project.clone(),
2115 buffer: active_buffer.clone(),
2116 snapshot: snapshot,
2117 position,
2118 events,
2119 related_files,
2120 recent_paths: project_state.recent_paths.clone(),
2121 trigger,
2122 diagnostic_search_range: diagnostic_search_range,
2123 debug_tx,
2124 user_actions,
2125 can_collect_data,
2126 is_open_source,
2127 };
2128
2129 if can_collect_data && rand::random_ratio(1, 1000) {
2130 if let Some(task) = capture_example(
2131 project.clone(),
2132 active_buffer,
2133 position,
2134 stored_events,
2135 false,
2136 cx,
2137 ) {
2138 task.detach();
2139 }
2140 }
2141
2142 let task = match self.edit_prediction_model {
2143 EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
2144 self,
2145 inputs,
2146 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
2147 cx,
2148 ),
2149 EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
2150 self,
2151 inputs,
2152 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
2153 cx,
2154 ),
2155 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2156 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2157 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2158 };
2159
2160 cx.spawn(async move |this, cx| {
2161 let prediction = task.await?;
2162
2163 if prediction.is_none() && allow_jump && has_events {
2164 this.update(cx, |this, cx| {
2165 this.refresh_prediction_from_diagnostics(
2166 project,
2167 DiagnosticSearchScope::Local,
2168 cx,
2169 );
2170 })?;
2171 return anyhow::Ok(None);
2172 }
2173
2174 Ok(prediction)
2175 })
2176 }
2177
2178 pub(crate) async fn next_diagnostic_location(
2179 active_buffer: Entity<Buffer>,
2180 active_buffer_snapshot: &BufferSnapshot,
2181 active_buffer_diagnostic_search_range: Range<Point>,
2182 active_buffer_cursor_point: Point,
2183 project: &Entity<Project>,
2184 cx: &mut AsyncApp,
2185 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2186 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2187 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2188 .flat_map(|(_, _, _, selections)| {
2189 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2190 })
2191 .collect();
2192
2193 let mut jump_location = active_buffer_snapshot
2194 .diagnostic_groups(None)
2195 .into_iter()
2196 .filter_map(|(_, group)| {
2197 let range = &group.entries[group.primary_ix]
2198 .range
2199 .to_point(&active_buffer_snapshot);
2200 if range.overlaps(&active_buffer_diagnostic_search_range) {
2201 return None;
2202 }
2203 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2204 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2205 });
2206 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2207 <= DIAGNOSTIC_LINES_RANGE;
2208 if near_collaborator && !near_local {
2209 return None;
2210 }
2211 Some(range.start)
2212 })
2213 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2214 .map(|position| {
2215 (
2216 active_buffer.clone(),
2217 active_buffer_snapshot.anchor_before(position),
2218 )
2219 });
2220
2221 if jump_location.is_none() {
2222 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2223 let file = buffer.file()?;
2224
2225 Some(ProjectPath {
2226 worktree_id: file.worktree_id(cx),
2227 path: file.path().clone(),
2228 })
2229 });
2230
2231 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2232 project
2233 .diagnostic_summaries(false, cx)
2234 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2235 .map(|(path, _, _)| {
2236 let shared_prefix = path
2237 .path
2238 .components()
2239 .zip(
2240 active_buffer_path
2241 .as_ref()
2242 .map(|p| p.path.components())
2243 .unwrap_or_default(),
2244 )
2245 .take_while(|(a, b)| a == b)
2246 .count();
2247 (path, shared_prefix)
2248 })
2249 .collect()
2250 });
2251
2252 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2253
2254 for (path, _) in candidates {
2255 let candidate_buffer = project
2256 .update(cx, |project, cx| project.open_buffer(path, cx))
2257 .await?;
2258
2259 let (has_collaborators, diagnostic_position) =
2260 candidate_buffer.read_with(cx, |buffer, _cx| {
2261 let snapshot = buffer.snapshot();
2262 let has_collaborators = snapshot
2263 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2264 .next()
2265 .is_some();
2266 let position = buffer
2267 .buffer_diagnostics(None)
2268 .into_iter()
2269 .min_by_key(|entry| entry.diagnostic.severity)
2270 .map(|entry| entry.range.start);
2271 (has_collaborators, position)
2272 });
2273
2274 if has_collaborators {
2275 continue;
2276 }
2277
2278 if let Some(position) = diagnostic_position {
2279 jump_location = Some((candidate_buffer, position));
2280 break;
2281 }
2282 }
2283 }
2284
2285 anyhow::Ok(jump_location)
2286 }
2287
2288 async fn send_raw_llm_request(
2289 request: RawCompletionRequest,
2290 client: Arc<Client>,
2291 custom_url: Option<Arc<Url>>,
2292 llm_token: LlmApiToken,
2293 app_version: Version,
2294 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2295 let url = if let Some(custom_url) = custom_url {
2296 custom_url.as_ref().clone()
2297 } else {
2298 client
2299 .http_client()
2300 .build_zed_llm_url("/predict_edits/raw", &[])?
2301 };
2302
2303 Self::send_api_request(
2304 |builder| {
2305 let req = builder
2306 .uri(url.as_ref())
2307 .body(serde_json::to_string(&request)?.into());
2308 Ok(req?)
2309 },
2310 client,
2311 llm_token,
2312 app_version,
2313 true,
2314 )
2315 .await
2316 }
2317
2318 pub(crate) async fn send_v3_request(
2319 input: ZetaPromptInput,
2320 client: Arc<Client>,
2321 llm_token: LlmApiToken,
2322 app_version: Version,
2323 trigger: PredictEditsRequestTrigger,
2324 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2325 let url = client
2326 .http_client()
2327 .build_zed_llm_url("/predict_edits/v3", &[])?;
2328
2329 let request = PredictEditsV3Request { input, trigger };
2330
2331 let json_bytes = serde_json::to_vec(&request)?;
2332 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2333
2334 Self::send_api_request(
2335 |builder| {
2336 let req = builder
2337 .uri(url.as_ref())
2338 .header("Content-Encoding", "zstd")
2339 .body(compressed.clone().into());
2340 Ok(req?)
2341 },
2342 client,
2343 llm_token,
2344 app_version,
2345 true,
2346 )
2347 .await
2348 }
2349
2350 fn handle_api_response<T>(
2351 this: &WeakEntity<Self>,
2352 response: Result<(T, Option<EditPredictionUsage>)>,
2353 cx: &mut gpui::AsyncApp,
2354 ) -> Result<T> {
2355 match response {
2356 Ok((data, usage)) => {
2357 if let Some(usage) = usage {
2358 this.update(cx, |this, cx| {
2359 this.user_store.update(cx, |user_store, cx| {
2360 user_store.update_edit_prediction_usage(usage, cx);
2361 });
2362 })
2363 .ok();
2364 }
2365 Ok(data)
2366 }
2367 Err(err) => {
2368 if err.is::<ZedUpdateRequiredError>() {
2369 cx.update(|cx| {
2370 this.update(cx, |this, _cx| {
2371 this.update_required = true;
2372 })
2373 .ok();
2374
2375 let error_message: SharedString = err.to_string().into();
2376 show_app_notification(
2377 NotificationId::unique::<ZedUpdateRequiredError>(),
2378 cx,
2379 move |cx| {
2380 cx.new(|cx| {
2381 ErrorMessagePrompt::new(error_message.clone(), cx)
2382 .with_link_button("Update Zed", "https://zed.dev/releases")
2383 })
2384 },
2385 );
2386 });
2387 }
2388 Err(err)
2389 }
2390 }
2391 }
2392
2393 async fn send_api_request<Res>(
2394 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2395 client: Arc<Client>,
2396 llm_token: LlmApiToken,
2397 app_version: Version,
2398 require_auth: bool,
2399 ) -> Result<(Res, Option<EditPredictionUsage>)>
2400 where
2401 Res: DeserializeOwned,
2402 {
2403 let http_client = client.http_client();
2404
2405 let mut token = if require_auth {
2406 Some(llm_token.acquire(&client).await?)
2407 } else {
2408 llm_token.acquire(&client).await.ok()
2409 };
2410 let mut did_retry = false;
2411
2412 loop {
2413 let request_builder = http_client::Request::builder().method(Method::POST);
2414
2415 let mut request_builder = request_builder
2416 .header("Content-Type", "application/json")
2417 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2418
2419 // Only add Authorization header if we have a token
2420 if let Some(ref token_value) = token {
2421 request_builder =
2422 request_builder.header("Authorization", format!("Bearer {}", token_value));
2423 }
2424
2425 let request = build(request_builder)?;
2426
2427 let mut response = http_client.send(request).await?;
2428
2429 if let Some(minimum_required_version) = response
2430 .headers()
2431 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2432 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2433 {
2434 anyhow::ensure!(
2435 app_version >= minimum_required_version,
2436 ZedUpdateRequiredError {
2437 minimum_version: minimum_required_version
2438 }
2439 );
2440 }
2441
2442 if response.status().is_success() {
2443 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2444
2445 let mut body = Vec::new();
2446 response.body_mut().read_to_end(&mut body).await?;
2447 return Ok((serde_json::from_slice(&body)?, usage));
2448 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2449 did_retry = true;
2450 token = Some(llm_token.refresh(&client).await?);
2451 } else {
2452 let mut body = String::new();
2453 response.body_mut().read_to_string(&mut body).await?;
2454 anyhow::bail!(
2455 "Request failed with status: {:?}\nBody: {}",
2456 response.status(),
2457 body
2458 );
2459 }
2460 }
2461 }
2462
2463 pub fn refresh_context(
2464 &mut self,
2465 project: &Entity<Project>,
2466 buffer: &Entity<language::Buffer>,
2467 cursor_position: language::Anchor,
2468 cx: &mut Context<Self>,
2469 ) {
2470 self.get_or_init_project(project, cx)
2471 .context
2472 .update(cx, |store, cx| {
2473 store.refresh(buffer.clone(), cursor_position, cx);
2474 });
2475 }
2476
2477 #[cfg(feature = "cli-support")]
2478 pub fn set_context_for_buffer(
2479 &mut self,
2480 project: &Entity<Project>,
2481 related_files: Vec<RelatedFile>,
2482 cx: &mut Context<Self>,
2483 ) {
2484 self.get_or_init_project(project, cx)
2485 .context
2486 .update(cx, |store, cx| {
2487 store.set_related_files(related_files, cx);
2488 });
2489 }
2490
2491 #[cfg(feature = "cli-support")]
2492 pub fn set_recent_paths_for_project(
2493 &mut self,
2494 project: &Entity<Project>,
2495 paths: impl IntoIterator<Item = project::ProjectPath>,
2496 cx: &mut Context<Self>,
2497 ) {
2498 let project_state = self.get_or_init_project(project, cx);
2499 project_state.recent_paths = paths.into_iter().collect();
2500 }
2501
2502 fn is_file_open_source(
2503 &self,
2504 project: &Entity<Project>,
2505 file: &Arc<dyn File>,
2506 cx: &App,
2507 ) -> bool {
2508 if !file.is_local() || file.is_private() {
2509 return false;
2510 }
2511 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2512 return false;
2513 };
2514 project_state
2515 .license_detection_watchers
2516 .get(&file.worktree_id(cx))
2517 .as_ref()
2518 .is_some_and(|watcher| watcher.is_project_open_source())
2519 }
2520
2521 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2522 self.data_collection_choice.is_enabled(cx)
2523 }
2524
2525 fn load_data_collection_choice() -> DataCollectionChoice {
2526 let choice = KEY_VALUE_STORE
2527 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2528 .log_err()
2529 .flatten();
2530
2531 match choice.as_deref() {
2532 Some("true") => DataCollectionChoice::Enabled,
2533 Some("false") => DataCollectionChoice::Disabled,
2534 Some(_) => {
2535 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2536 DataCollectionChoice::NotAnswered
2537 }
2538 None => DataCollectionChoice::NotAnswered,
2539 }
2540 }
2541
2542 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2543 self.data_collection_choice = self.data_collection_choice.toggle();
2544 let new_choice = self.data_collection_choice;
2545 let is_enabled = new_choice.is_enabled(cx);
2546 db::write_and_log(cx, move || {
2547 KEY_VALUE_STORE.write_kvp(
2548 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2549 is_enabled.to_string(),
2550 )
2551 });
2552 }
2553
2554 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2555 self.shown_predictions.iter()
2556 }
2557
2558 pub fn shown_completions_len(&self) -> usize {
2559 self.shown_predictions.len()
2560 }
2561
2562 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2563 self.rated_predictions.contains(id)
2564 }
2565
2566 pub fn rate_prediction(
2567 &mut self,
2568 prediction: &EditPrediction,
2569 rating: EditPredictionRating,
2570 feedback: String,
2571 cx: &mut Context<Self>,
2572 ) {
2573 let organization = self.user_store.read(cx).current_organization();
2574
2575 self.rated_predictions.insert(prediction.id.clone());
2576
2577 cx.background_spawn({
2578 let client = self.client.clone();
2579 let prediction_id = prediction.id.to_string();
2580 let inputs = serde_json::to_value(&prediction.inputs);
2581 let output = prediction
2582 .edit_preview
2583 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2584 async move {
2585 client
2586 .cloud_client()
2587 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2588 organization_id: organization.map(|organization| organization.id.clone()),
2589 request_id: prediction_id,
2590 rating: match rating {
2591 EditPredictionRating::Positive => "positive".to_string(),
2592 EditPredictionRating::Negative => "negative".to_string(),
2593 },
2594 inputs: inputs?,
2595 output,
2596 feedback,
2597 })
2598 .await?;
2599
2600 anyhow::Ok(())
2601 }
2602 })
2603 .detach_and_log_err(cx);
2604
2605 cx.notify();
2606 }
2607}
2608
/// Collapses the newest run of mergeable events in `events` into a single event.
///
/// Walking backwards from the newest event, each older event is included while it `can_merge`
/// with the event after it, given `latest_snapshot` and `latest_edit_range`.
///
/// When that run holds at least two events, they are replaced by one `BufferChange` event that:
/// - keeps the oldest event's paths and open-source flag;
/// - diffs the oldest event's `old_snapshot` against `end_snapshot`;
/// - is marked `predicted` only if every merged event was;
/// - has its `edit_range` anchored in `end_snapshot`.
///
/// `events` is left unchanged when:
/// - the newest event's `old_snapshot` has a different `remote_id` than `latest_snapshot`;
/// - fewer than two events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge a trailing run that belongs to the same buffer as `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the trailing events that form a chain of pairwise-mergeable neighbours, newest first.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    // Cannot panic: `mergeable_count >= 2` here, so the range is non-empty. `peek` leaves the
    // oldest event in the iterator, so the `all` below also inspects it.
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // Single-arm match: `zeta_prompt::Event` presumably has only this variant.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every part of it was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // The merged run sits at the tail: drop it and append the combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2675
2676pub(crate) fn filter_redundant_excerpts(
2677 mut related_files: Vec<RelatedFile>,
2678 cursor_path: &Path,
2679 cursor_row_range: Range<u32>,
2680) -> Vec<RelatedFile> {
2681 for file in &mut related_files {
2682 if file.path.as_ref() == cursor_path {
2683 file.excerpts.retain(|excerpt| {
2684 excerpt.row_range.start < cursor_row_range.start
2685 || excerpt.row_range.end > cursor_row_range.end
2686 });
2687 }
2688 }
2689 related_files.retain(|file| !file.excerpts.is_empty());
2690 related_files
2691}
2692
2693#[derive(Error, Debug)]
2694#[error(
2695 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2696)]
2697pub struct ZedUpdateRequiredError {
2698 minimum_version: Version,
2699}
2700
/// The user's persisted choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2707
2708impl DataCollectionChoice {
2709 pub fn is_enabled(self, cx: &App) -> bool {
2710 if cx.is_staff() {
2711 return true;
2712 }
2713 match self {
2714 Self::Enabled => true,
2715 Self::NotAnswered | Self::Disabled => false,
2716 }
2717 }
2718
2719 #[must_use]
2720 pub fn toggle(&self) -> DataCollectionChoice {
2721 match self {
2722 Self::Enabled => Self::Disabled,
2723 Self::Disabled => Self::Enabled,
2724 Self::NotAnswered => Self::Enabled,
2725 }
2726 }
2727}
2728
2729impl From<bool> for DataCollectionChoice {
2730 fn from(value: bool) -> Self {
2731 match value {
2732 true => DataCollectionChoice::Enabled,
2733 false => DataCollectionChoice::Disabled,
2734 }
2735 }
2736}
2737
2738struct ZedPredictUpsell;
2739
2740impl Dismissable for ZedPredictUpsell {
2741 const KEY: &'static str = "dismissed-edit-predict-upsell";
2742
2743 fn dismissed() -> bool {
2744 // To make this backwards compatible with older versions of Zed, we
2745 // check if the user has seen the previous Edit Prediction Onboarding
2746 // before, by checking the data collection choice which was written to
2747 // the database once the user clicked on "Accept and Enable"
2748 if KEY_VALUE_STORE
2749 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2750 .log_err()
2751 .is_some_and(|s| s.is_some())
2752 {
2753 return true;
2754 }
2755
2756 KEY_VALUE_STORE
2757 .read_kvp(Self::KEY)
2758 .log_err()
2759 .is_some_and(|s| s.is_some())
2760 }
2761}
2762
2763pub fn should_show_upsell_modal() -> bool {
2764 !ZedPredictUpsell::dismissed()
2765}
2766
2767pub fn init(cx: &mut App) {
2768 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2769 workspace.register_action(
2770 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2771 ZedPredictModal::toggle(
2772 workspace,
2773 workspace.user_store().clone(),
2774 workspace.client().clone(),
2775 window,
2776 cx,
2777 )
2778 },
2779 );
2780
2781 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2782 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2783 settings
2784 .project
2785 .all_languages
2786 .edit_predictions
2787 .get_or_insert_default()
2788 .provider = Some(EditPredictionProvider::None)
2789 });
2790 });
2791 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2792 EditPredictionStore::try_global(cx).and_then(|store| {
2793 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2794 })
2795 }
2796
2797 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2798 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2799 copilot_ui::initiate_sign_in(copilot, window, cx);
2800 }
2801 });
2802 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2803 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2804 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2805 }
2806 });
2807 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2808 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2809 copilot_ui::initiate_sign_out(copilot, window, cx);
2810 }
2811 });
2812 })
2813 .detach();
2814}