1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
// Declares the gpui actions of this crate, namespaced under `edit_prediction`.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line distance used by `StoredEvent::can_merge`: ranges at most this many rows apart are
/// treated as "near" each other, farther ones as distant.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in the visible part of this file; presumably the time window
// within which consecutive edits are grouped into one event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): key under which the data-collection choice is presumably persisted
// (KEY_VALUE_STORE is imported by this file) — confirm against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval for batching rejection reports sent to the
// server — confirm against `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the "settled" event (presumably a telemetry event name — confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a prediction is tracked for settling (5 minutes) — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably how long a region must go without edits before it counts as
// settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
109
/// Feature flag named `zeta2`.
pub struct Zeta2FeatureFlag;
/// Feature flag named `edit_prediction_jumps`. When enabled, project diagnostics updates
/// trigger a prediction refresh (see `EditPredictionStore::handle_project_event`).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
120
/// App-global holder for the shared `EditPredictionStore` entity.
///
/// Installed by `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` populates this from the `ZED_ZETA_FORMAT` (required, parsed
/// into `format`) and `ZED_ZETA_MODEL` (optional, `model_id`) environment variables.
///
/// NOTE(review): an earlier version of this comment said "the version is also used as the
/// Baseten environment name (lowercased)", but this struct has no version field. Confirm
/// whether `format` is the value now used for that purpose.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier override (from `ZED_ZETA_MODEL`).
    pub model_id: Option<String>,
    /// Prompt format to construct (from `ZED_ZETA_FORMAT`).
    pub format: ZetaFormat,
}
134
/// Store for edit predictions, shared app-wide through `EditPredictionStoreGlobal`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    /// Source of edit-prediction usage information (see `usage`).
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): initialized to `false` in `new`; not read in the visible part of this file.
    update_required: bool,
    /// Active prediction backend (`Zeta2` by default).
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; read from the environment in `new`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task that runs `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Feeds the task that runs `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    // NOTE(review): presumably recently shown predictions — usage not visible here.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — usage not visible here.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook (presumably invoked when a settled event is reported).
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
154
/// Backend used to produce edit predictions.
///
/// Only `Zeta1`/`Zeta2` report usage through the user store (see `EditPredictionStore::usage`).
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle model with the given prompt format. Its icon depends on the
    /// configured edit-prediction provider (Ollama vs. OpenAI-compatible).
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
163
/// Inputs collected for a single prediction request (presumably handed to the selected
/// model backend — confirm at the call sites).
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is requested at (presumably).
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    /// Most recently activated project paths, newest first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
178
/// Debug events streamed to the receiver returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when the related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when the related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs, e.g. cache hits and LSP definition latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts (emitted outside the visible part of this file).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes (emitted outside the visible part of this file).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
214
/// Maximum number of user actions retained per project; `ProjectState::record_user_action`
/// evicts the oldest record beyond this.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    // NOTE(review): presumably the cursor's row/offset at the time of the action — confirm.
    pub line_number: u32,
    pub offset: usize,
    /// Milliseconds since the Unix epoch (per the field name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action recorded in `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
234
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors (created in the post-change snapshot) spanning the edited region.
    pub edit_range: Range<Anchor>,
}
242
243impl StoredEvent {
244 fn can_merge(
245 &self,
246 next_old_event: &&&StoredEvent,
247 new_snapshot: &TextBufferSnapshot,
248 last_edit_range: &Range<Anchor>,
249 ) -> bool {
250 // Events must be for the same buffer
251 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
252 return false;
253 }
254 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
255 return false;
256 }
257
258 let a_is_predicted = matches!(
259 self.event.as_ref(),
260 zeta_prompt::Event::BufferChange {
261 predicted: true,
262 ..
263 }
264 );
265 let b_is_predicted = matches!(
266 next_old_event.event.as_ref(),
267 zeta_prompt::Event::BufferChange {
268 predicted: true,
269 ..
270 }
271 );
272
273 // If events come from the same source (both predicted or both manual) then
274 // we would have coalesced them already.
275 if a_is_predicted == b_is_predicted {
276 return false;
277 }
278
279 let left_range = self.edit_range.to_point(new_snapshot);
280 let right_range = next_old_event.edit_range.to_point(new_snapshot);
281 let latest_range = last_edit_range.to_point(&new_snapshot);
282
283 // Events near to the latest edit are not merged if their sources differ.
284 if lines_between_ranges(&left_range, &latest_range)
285 .min(lines_between_ranges(&right_range, &latest_range))
286 <= CHANGE_GROUPING_LINE_SPAN
287 {
288 return false;
289 }
290
291 // Events that are distant from each other are not merged.
292 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
293 return false;
294 }
295
296 true
297 }
298}
299
300fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
301 if left.start > right.end {
302 return left.start.row - right.end.row;
303 }
304 if right.start > left.end {
305 return right.start.row - left.end.row;
306 }
307 0
308}
309
/// Per-project edit prediction state, owned by `EditPredictionStore::projects`.
struct ProjectState {
    /// Finalized edit history events.
    events: VecDeque<StoredEvent>,
    /// In-progress event; finalized lazily by `ProjectState::events`.
    last_event: Option<LastEvent>,
    /// Recently activated project paths, most recent first and de-duplicated.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // NOTE(review): presumably the id assigned to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two, by the `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Set by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably (buffer id, time) of the last refresh, used for throttling — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context files for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; an entry is removed when its worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (`USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    /// Project-event subscription and project-release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance, created lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
328
329impl ProjectState {
330 fn record_user_action(&mut self, action: UserActionRecord) {
331 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
332 self.user_actions.pop_front();
333 }
334 self.user_actions.push_back(action);
335 }
336
337 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
338 self.events
339 .iter()
340 .cloned()
341 .chain(self.last_event.as_ref().iter().flat_map(|event| {
342 let (one, two) = event.split_by_pause();
343 let one = one.finalize(&self.license_detection_watchers, cx);
344 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
345 one.into_iter().chain(two)
346 }))
347 .collect()
348 }
349
350 fn cancel_pending_prediction(
351 &mut self,
352 pending_prediction: PendingPrediction,
353 cx: &mut Context<EditPredictionStore>,
354 ) {
355 self.cancelled_predictions.insert(pending_prediction.id);
356
357 if pending_prediction.drop_on_cancel {
358 drop(pending_prediction.task);
359 } else {
360 cx.spawn(async move |this, cx| {
361 let Some(prediction_id) = pending_prediction.task.await else {
362 return;
363 };
364
365 this.update(cx, |this, cx| {
366 this.reject_prediction(
367 prediction_id,
368 EditPredictionRejectReason::Canceled,
369 false,
370 None,
371 cx,
372 );
373 })
374 .ok();
375 })
376 .detach()
377 }
378 }
379
380 fn active_buffer(
381 &self,
382 project: &Entity<Project>,
383 cx: &App,
384 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
385 let project = project.read(cx);
386 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
387 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
388 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
389 Some((active_buffer, registered_buffer.last_position))
390 }
391}
392
/// The prediction currently offered for a project, plus what requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
400
401impl CurrentEditPrediction {
402 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
403 let Some(new_edits) = self
404 .prediction
405 .interpolate(&self.prediction.buffer.read(cx))
406 else {
407 return false;
408 };
409
410 if self.prediction.buffer != old_prediction.prediction.buffer {
411 return true;
412 }
413
414 let Some(old_edits) = old_prediction
415 .prediction
416 .interpolate(&old_prediction.prediction.buffer.read(cx))
417 else {
418 return true;
419 };
420
421 let requested_by_buffer_id = self.requested_by.buffer_id();
422
423 // This reduces the occurrence of UI thrash from replacing edits
424 //
425 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
426 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
427 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
428 && old_edits.len() == 1
429 && new_edits.len() == 1
430 {
431 let (old_range, old_text) = &old_edits[0];
432 let (new_range, new_text) = &new_edits[0];
433 new_range == old_range && new_text.starts_with(old_text.as_ref())
434 } else {
435 true
436 }
437 }
438}
439
440#[derive(Debug, Clone)]
441enum PredictionRequestedBy {
442 DiagnosticsUpdate,
443 Buffer(EntityId),
444}
445
446impl PredictionRequestedBy {
447 pub fn buffer_id(&self) -> Option<EntityId> {
448 match self {
449 PredictionRequestedBy::DiagnosticsUpdate => None,
450 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
451 }
452 }
453}
454
// NOTE(review): not referenced in the visible part of this file; presumably the number of
// lines around the cursor searched for diagnostics — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostics-driven prediction refresh. Project diagnostics updates use `Global`
/// (see `EditPredictionStore::handle_project_event`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
462
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
471
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` means the prediction edits this buffer and `Jump` means it
// lives elsewhere and needs a jump to reach — confirm with the code that constructs these.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
478
/// Test convenience: both variants expose the wrapped `EditPrediction` directly.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
490
/// A prediction waiting to "settle" in a registered buffer.
// NOTE(review): settling logic is outside the visible part of this file; field meanings
// below are inferred from names — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}

/// Per-buffer tracking state kept in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// An in-progress edit event, finalized into a `StoredEvent` by `LastEvent::finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot at the latest edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Not read by `finalize`, which recomputes the range from the snapshots.
    edit_range: Option<Range<Anchor>>,
    /// Whether the change came from an accepted prediction.
    predicted: bool,
    /// If set, `split_by_pause` splits the event at this snapshot.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
518
519impl LastEvent {
520 pub fn finalize(
521 &self,
522 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
523 cx: &App,
524 ) -> Option<StoredEvent> {
525 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
526 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
527
528 let in_open_source_repo =
529 [self.new_file.as_ref(), self.old_file.as_ref()]
530 .iter()
531 .all(|file| {
532 file.is_some_and(|file| {
533 license_detection_watchers
534 .get(&file.worktree_id(cx))
535 .is_some_and(|watcher| watcher.is_project_open_source())
536 })
537 });
538
539 let (diff, edit_range) =
540 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
541
542 if path == old_path && diff.is_empty() {
543 None
544 } else {
545 Some(StoredEvent {
546 event: Arc::new(zeta_prompt::Event::BufferChange {
547 old_path,
548 path,
549 diff,
550 in_open_source_repo,
551 predicted: self.predicted,
552 }),
553 edit_range: self.new_snapshot.anchor_before(edit_range.start)
554 ..self.new_snapshot.anchor_before(edit_range.end),
555 old_snapshot: self.old_snapshot.clone(),
556 })
557 }
558 }
559
560 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
561 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
562 return (self.clone(), None);
563 };
564
565 let before = LastEvent {
566 old_snapshot: self.old_snapshot.clone(),
567 new_snapshot: boundary_snapshot.clone(),
568 old_file: self.old_file.clone(),
569 new_file: self.new_file.clone(),
570 edit_range: None,
571 predicted: self.predicted,
572 snapshot_after_last_editing_pause: None,
573 last_edit_time: self.last_edit_time,
574 };
575
576 let after = LastEvent {
577 old_snapshot: boundary_snapshot.clone(),
578 new_snapshot: self.new_snapshot.clone(),
579 old_file: self.old_file.clone(),
580 new_file: self.new_file.clone(),
581 edit_range: None,
582 predicted: self.predicted,
583 snapshot_after_last_editing_pause: None,
584 last_edit_time: self.last_edit_time,
585 };
586
587 (before, Some(after))
588 }
589}
590
591pub(crate) fn compute_diff_between_snapshots(
592 old_snapshot: &TextBufferSnapshot,
593 new_snapshot: &TextBufferSnapshot,
594) -> Option<(String, Range<Point>)> {
595 let edits: Vec<Edit<usize>> = new_snapshot
596 .edits_since::<usize>(&old_snapshot.version)
597 .collect();
598
599 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
600
601 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
602 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
603 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
604 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
605
606 const CONTEXT_LINES: u32 = 3;
607
608 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
609 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
610 let old_context_end_row =
611 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
612 let new_context_end_row =
613 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
614
615 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
616 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
617 let old_end_line_offset = old_snapshot
618 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
619 let new_end_line_offset = new_snapshot
620 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
621 let old_edit_range = old_start_line_offset..old_end_line_offset;
622 let new_edit_range = new_start_line_offset..new_end_line_offset;
623
624 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
625 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
626
627 let diff = language::unified_diff_with_offsets(
628 &old_region_text,
629 &new_region_text,
630 old_context_start_row,
631 new_context_start_row,
632 );
633
634 Some((diff, new_start_point..new_end_point))
635}
636
637fn buffer_path_with_id_fallback(
638 file: Option<&Arc<dyn File>>,
639 snapshot: &TextBufferSnapshot,
640 cx: &App,
641) -> Arc<Path> {
642 if let Some(file) = file {
643 file.full_path(cx).into()
644 } else {
645 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
646 }
647}
648
649impl EditPredictionStore {
650 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
651 cx.try_global::<EditPredictionStoreGlobal>()
652 .map(|global| global.0.clone())
653 }
654
655 pub fn global(
656 client: &Arc<Client>,
657 user_store: &Entity<UserStore>,
658 cx: &mut App,
659 ) -> Entity<Self> {
660 cx.try_global::<EditPredictionStoreGlobal>()
661 .map(|global| global.0.clone())
662 .unwrap_or_else(|| {
663 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
664 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
665 ep_store
666 })
667 }
668
669 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
670 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
671 let data_collection_choice = Self::load_data_collection_choice();
672
673 let llm_token = LlmApiToken::default();
674
675 let (reject_tx, reject_rx) = mpsc::unbounded();
676 cx.background_spawn({
677 let client = client.clone();
678 let llm_token = llm_token.clone();
679 let app_version = AppVersion::global(cx);
680 let background_executor = cx.background_executor().clone();
681 async move {
682 Self::handle_rejected_predictions(
683 reject_rx,
684 client,
685 llm_token,
686 app_version,
687 background_executor,
688 )
689 .await
690 }
691 })
692 .detach();
693
694 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
695 cx.spawn(async move |this, cx| {
696 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
697 })
698 .detach();
699
700 let this = Self {
701 projects: HashMap::default(),
702 client,
703 user_store,
704 llm_token,
705 _llm_token_subscription: cx.subscribe(
706 &refresh_llm_token_listener,
707 |this, _listener, _event, cx| {
708 let client = this.client.clone();
709 let llm_token = this.llm_token.clone();
710 cx.spawn(async move |_this, _cx| {
711 llm_token.refresh(&client).await?;
712 anyhow::Ok(())
713 })
714 .detach_and_log_err(cx);
715 },
716 ),
717 update_required: false,
718 edit_prediction_model: EditPredictionModel::Zeta2,
719 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
720 sweep_ai: SweepAi::new(cx),
721 mercury: Mercury::new(cx),
722
723 data_collection_choice,
724 reject_predictions_tx: reject_tx,
725 settled_predictions_tx,
726 rated_predictions: Default::default(),
727 shown_predictions: Default::default(),
728 #[cfg(test)]
729 settled_event_callback: None,
730 };
731
732 this
733 }
734
735 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
736 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
737 let format = ZetaFormat::parse(&version_str).ok()?;
738 let model_id = env::var("ZED_ZETA_MODEL").ok();
739 Some(Zeta2RawConfig { model_id, format })
740 }
741
742 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
743 self.edit_prediction_model = model;
744 }
745
746 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
747 self.zeta2_raw_config = Some(config);
748 }
749
750 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
751 self.zeta2_raw_config.as_ref()
752 }
753
754 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
755 use ui::IconName;
756 match self.edit_prediction_model {
757 EditPredictionModel::Sweep => {
758 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
759 .with_disabled(IconName::SweepAiDisabled)
760 .with_up(IconName::SweepAiUp)
761 .with_down(IconName::SweepAiDown)
762 .with_error(IconName::SweepAiError)
763 }
764 EditPredictionModel::Mercury => {
765 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
766 }
767 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
768 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
769 .with_disabled(IconName::ZedPredictDisabled)
770 .with_up(IconName::ZedPredictUp)
771 .with_down(IconName::ZedPredictDown)
772 .with_error(IconName::ZedPredictError)
773 }
774 EditPredictionModel::Fim { .. } => {
775 let settings = &all_language_settings(None, cx).edit_predictions;
776 match settings.provider {
777 EditPredictionProvider::Ollama => {
778 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
779 }
780 _ => {
781 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
782 }
783 }
784 }
785 }
786 }
787
788 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
789 self.sweep_ai.api_token.read(cx).has_key()
790 }
791
792 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
793 self.mercury.api_token.read(cx).has_key()
794 }
795
796 pub fn clear_history(&mut self) {
797 for project_state in self.projects.values_mut() {
798 project_state.events.clear();
799 project_state.last_event.take();
800 }
801 }
802
803 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
804 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
805 project_state.events.clear();
806 project_state.last_event.take();
807 }
808 }
809
810 pub fn edit_history_for_project(
811 &self,
812 project: &Entity<Project>,
813 cx: &App,
814 ) -> Vec<StoredEvent> {
815 self.projects
816 .get(&project.entity_id())
817 .map(|project_state| project_state.events(cx))
818 .unwrap_or_default()
819 }
820
821 pub fn context_for_project<'a>(
822 &'a self,
823 project: &Entity<Project>,
824 cx: &'a mut App,
825 ) -> Vec<RelatedFile> {
826 self.projects
827 .get(&project.entity_id())
828 .map(|project_state| {
829 project_state.context.update(cx, |context, cx| {
830 context
831 .related_files_with_buffers(cx)
832 .map(|(mut related_file, buffer)| {
833 related_file.in_open_source_repo = buffer
834 .read(cx)
835 .file()
836 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
837 related_file
838 })
839 .collect()
840 })
841 })
842 .unwrap_or_default()
843 }
844
845 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
846 self.projects
847 .get(&project.entity_id())
848 .and_then(|project| project.copilot.clone())
849 }
850
851 pub fn start_copilot_for_project(
852 &mut self,
853 project: &Entity<Project>,
854 cx: &mut Context<Self>,
855 ) -> Option<Entity<Copilot>> {
856 if DisableAiSettings::get(None, cx).disable_ai {
857 return None;
858 }
859 let state = self.get_or_init_project(project, cx);
860
861 if state.copilot.is_some() {
862 return state.copilot.clone();
863 }
864 let _project = project.clone();
865 let project = project.read(cx);
866
867 let node = project.node_runtime().cloned();
868 if let Some(node) = node {
869 let next_id = project.languages().next_language_server_id();
870 let fs = project.fs().clone();
871
872 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
873 state.copilot = Some(copilot.clone());
874 Some(copilot)
875 } else {
876 None
877 }
878 }
879
880 pub fn context_for_project_with_buffers<'a>(
881 &'a self,
882 project: &Entity<Project>,
883 cx: &'a mut App,
884 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
885 self.projects
886 .get(&project.entity_id())
887 .map(|project| {
888 project.context.update(cx, |context, cx| {
889 context.related_files_with_buffers(cx).collect()
890 })
891 })
892 .unwrap_or_default()
893 }
894
895 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
896 if matches!(
897 self.edit_prediction_model,
898 EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
899 ) {
900 self.user_store.read(cx).edit_prediction_usage()
901 } else {
902 None
903 }
904 }
905
906 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
907 self.get_or_init_project(project, cx);
908 }
909
910 pub fn register_buffer(
911 &mut self,
912 buffer: &Entity<Buffer>,
913 project: &Entity<Project>,
914 cx: &mut Context<Self>,
915 ) {
916 let project_state = self.get_or_init_project(project, cx);
917 Self::register_buffer_impl(project_state, buffer, project, cx);
918 }
919
    /// Returns the per-project state for `project`, creating it on first use.
    ///
    /// Creating the state does three things:
    /// - builds a `RelatedExcerptStore` for the project, whose events are forwarded to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to the project's events (`handle_project_event`);
    /// - observes the project's release, removing this state from `projects` when the
    ///   project is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than kept in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }

    /// Removes all state tracked for `project`.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
964
965 fn handle_excerpt_store_event(
966 &mut self,
967 project_entity_id: EntityId,
968 event: &RelatedExcerptStoreEvent,
969 ) {
970 if let Some(project_state) = self.projects.get(&project_entity_id) {
971 if let Some(debug_tx) = project_state.debug_tx.clone() {
972 match event {
973 RelatedExcerptStoreEvent::StartedRefresh => {
974 debug_tx
975 .unbounded_send(DebugEvent::ContextRetrievalStarted(
976 ContextRetrievalStartedDebugEvent {
977 project_entity_id: project_entity_id,
978 timestamp: Instant::now(),
979 search_prompt: String::new(),
980 },
981 ))
982 .ok();
983 }
984 RelatedExcerptStoreEvent::FinishedRefresh {
985 cache_hit_count,
986 cache_miss_count,
987 mean_definition_latency,
988 max_definition_latency,
989 } => {
990 debug_tx
991 .unbounded_send(DebugEvent::ContextRetrievalFinished(
992 ContextRetrievalFinishedDebugEvent {
993 project_entity_id: project_entity_id,
994 timestamp: Instant::now(),
995 metadata: vec![
996 (
997 "Cache Hits",
998 format!(
999 "{}/{}",
1000 cache_hit_count,
1001 cache_hit_count + cache_miss_count
1002 )
1003 .into(),
1004 ),
1005 (
1006 "Max LSP Time",
1007 format!("{} ms", max_definition_latency.as_millis())
1008 .into(),
1009 ),
1010 (
1011 "Mean LSP Time",
1012 format!("{} ms", mean_definition_latency.as_millis())
1013 .into(),
1014 ),
1015 ],
1016 },
1017 ))
1018 .ok();
1019 }
1020 }
1021 }
1022 }
1023 }
1024
1025 pub fn debug_info(
1026 &mut self,
1027 project: &Entity<Project>,
1028 cx: &mut Context<Self>,
1029 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1030 let project_state = self.get_or_init_project(project, cx);
1031 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1032 project_state.debug_tx = Some(debug_watch_tx);
1033 debug_watch_rx
1034 }
1035
1036 fn handle_project_event(
1037 &mut self,
1038 project: Entity<Project>,
1039 event: &project::Event,
1040 cx: &mut Context<Self>,
1041 ) {
1042 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1043 return;
1044 }
1045 // TODO [zeta2] init with recent paths
1046 match event {
1047 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1048 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1049 return;
1050 };
1051 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1052 if let Some(path) = path {
1053 if let Some(ix) = project_state
1054 .recent_paths
1055 .iter()
1056 .position(|probe| probe == &path)
1057 {
1058 project_state.recent_paths.remove(ix);
1059 }
1060 project_state.recent_paths.push_front(path);
1061 }
1062 }
1063 project::Event::DiagnosticsUpdated { .. } => {
1064 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1065 self.refresh_prediction_from_diagnostics(
1066 project,
1067 DiagnosticSearchScope::Global,
1068 cx,
1069 );
1070 }
1071 }
1072 _ => (),
1073 }
1074 }
1075
1076 fn register_buffer_impl<'a>(
1077 project_state: &'a mut ProjectState,
1078 buffer: &Entity<Buffer>,
1079 project: &Entity<Project>,
1080 cx: &mut Context<Self>,
1081 ) -> &'a mut RegisteredBuffer {
1082 let buffer_id = buffer.entity_id();
1083
1084 if let Some(file) = buffer.read(cx).file() {
1085 let worktree_id = file.worktree_id(cx);
1086 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1087 project_state
1088 .license_detection_watchers
1089 .entry(worktree_id)
1090 .or_insert_with(|| {
1091 let project_entity_id = project.entity_id();
1092 cx.observe_release(&worktree, move |this, _worktree, _cx| {
1093 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1094 else {
1095 return;
1096 };
1097 project_state
1098 .license_detection_watchers
1099 .remove(&worktree_id);
1100 })
1101 .detach();
1102 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1103 });
1104 }
1105 }
1106
1107 match project_state.registered_buffers.entry(buffer_id) {
1108 hash_map::Entry::Occupied(entry) => entry.into_mut(),
1109 hash_map::Entry::Vacant(entry) => {
1110 let buf = buffer.read(cx);
1111 let snapshot = buf.text_snapshot();
1112 let file = buf.file().cloned();
1113 let project_entity_id = project.entity_id();
1114 entry.insert(RegisteredBuffer {
1115 snapshot,
1116 file,
1117 last_position: None,
1118 pending_predictions: Vec::new(),
1119 _subscriptions: [
1120 cx.subscribe(buffer, {
1121 let project = project.downgrade();
1122 move |this, buffer, event, cx| {
1123 if let language::BufferEvent::Edited = event
1124 && let Some(project) = project.upgrade()
1125 {
1126 this.report_changes_for_buffer(&buffer, &project, false, cx);
1127 }
1128 }
1129 }),
1130 cx.observe_release(buffer, move |this, _buffer, _cx| {
1131 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1132 else {
1133 return;
1134 };
1135 project_state.registered_buffers.remove(&buffer_id);
1136 }),
1137 ],
1138 })
1139 }
1140 }
1141 }
1142
    /// Diffs `buffer` against the snapshot recorded at its previous report and records the change.
    ///
    /// Registers `buffer` with the project if needed. Returns early if the buffer's version has
    /// not moved. Otherwise it updates the stored file and snapshot, and then:
    /// - refreshes `last_edit_at` on pending settled predictions whose editable range overlaps
    ///   the edited range;
    /// - records a `UserActionRecord` classified from the inserted/deleted totals;
    /// - either coalesces the edit into the project's `last_event`, or finalizes that event into
    ///   `events` and starts a new `last_event`.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction. An edit whose source
    /// differs from the current `last_event.predicted` is never coalesced into it.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // `new_file` / `new_snapshot` are cloned because they are also stored in `LastEvent` below.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Spans from the first edit's start anchor to the last edit's end anchor.
        let mut edit_range: Option<Range<Anchor>> = None;
        // End offset (in the new text) of the last edit; used as the recorded action position.
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit inside a pending prediction's editable region restarts that prediction's
        // quiescence clock (read by `run_settled_predictions_worker`).
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify the change from totals over all edits. A total equal to the number of edits
        // means one `usize` offset unit per edit (presumably one byte), treated as a single
        // character. Anything larger is treated as a selection-sized change.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock milliseconds since the Unix epoch; 0 if the clock is before the epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalescing requires all of the following:
            // - this edit continues directly from the in-progress event's last snapshot of the
            //   same buffer;
            // - it has the same source (typed vs. predicted);
            // - it lies within `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range
            //   (as measured by `lines_between_ranges`).
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused for at least `LAST_CHANGE_GROUPING_TIME` before this edit,
                // remember the snapshot from just before it as the last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the in-progress event (`finalize` may yield `None`, in which
        // case nothing is stored). The check below keeps at most `EVENT_COUNT_MAX - 1`
        // finalized events. NOTE(review): presumably this leaves room for `last_event`.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // Helper defined elsewhere in this file; presumably folds trailing events in `events`
        // related to this edit. Verify against its definition.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new in-progress event for this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1276
1277 fn prediction_at(
1278 &mut self,
1279 buffer: &Entity<Buffer>,
1280 position: Option<language::Anchor>,
1281 project: &Entity<Project>,
1282 cx: &App,
1283 ) -> Option<BufferEditPrediction<'_>> {
1284 let project_state = self.projects.get_mut(&project.entity_id())?;
1285 if let Some(position) = position
1286 && let Some(buffer) = project_state
1287 .registered_buffers
1288 .get_mut(&buffer.entity_id())
1289 {
1290 buffer.last_position = Some(position);
1291 }
1292
1293 let CurrentEditPrediction {
1294 requested_by,
1295 prediction,
1296 ..
1297 } = project_state.current_prediction.as_ref()?;
1298
1299 if prediction.targets_buffer(buffer.read(cx)) {
1300 Some(BufferEditPrediction::Local { prediction })
1301 } else {
1302 let show_jump = match requested_by {
1303 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1304 requested_by_buffer_id == &buffer.entity_id()
1305 }
1306 PredictionRequestedBy::DiagnosticsUpdate => true,
1307 };
1308
1309 if show_jump {
1310 Some(BufferEditPrediction::Jump { prediction })
1311 } else {
1312 None
1313 }
1314 }
1315 }
1316
1317 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1318 let Some(current_prediction) = self
1319 .projects
1320 .get_mut(&project.entity_id())
1321 .and_then(|project_state| project_state.current_prediction.take())
1322 else {
1323 return;
1324 };
1325
1326 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1327
1328 // can't hold &mut project_state ref across report_changes_for_buffer_call
1329 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1330 return;
1331 };
1332
1333 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1334 project_state.cancel_pending_prediction(pending_prediction, cx);
1335 }
1336
1337 match self.edit_prediction_model {
1338 EditPredictionModel::Sweep => {
1339 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1340 }
1341 EditPredictionModel::Mercury => {
1342 mercury::edit_prediction_accepted(
1343 current_prediction.prediction.id,
1344 self.client.http_client(),
1345 cx,
1346 );
1347 }
1348 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1349 let is_cloud = !matches!(
1350 all_language_settings(None, cx).edit_predictions.provider,
1351 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1352 );
1353 if is_cloud {
1354 zeta::edit_prediction_accepted(self, current_prediction, cx)
1355 }
1356 }
1357 EditPredictionModel::Fim { .. } => {}
1358 }
1359 }
1360
    /// Background loop that batches prediction rejections received on `rx` and uploads them to
    /// the `/predict_edits/reject` LLM endpoint.
    ///
    /// Batching: after each received rejection, while the batch holds fewer than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, the loop keeps collecting as long as another
    /// rejection shows up before `REJECT_REQUEST_DEBOUNCE` elapses.
    ///
    /// Sending: each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the
    /// newest batched rejections. They are removed from the batch only when the request
    /// succeeds; otherwise they stay queued for a later attempt.
    ///
    /// The loop ends when `rx` is closed.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // `peekable` lets the debounce below check for a waiting item without consuming it.
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Biased towards the channel: if another rejection arrives first, loop back to
                // receive it instead of flushing. If the channel has closed (`peek` yields
                // `None`), fall through and flush what is already batched.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics this task if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections, i.e. the tail of `batched`.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the sent rejections on success. On failure they stay in `batched` and
            // are retried after the next rejection is received.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1419
    /// Background loop that reports predictions whose editable region has settled.
    ///
    /// `rx` carries the enqueue time of each newly tracked prediction (sent by
    /// `enqueue_settled_prediction`). A pending prediction is handled in one of two ways:
    /// - once `EDIT_PREDICTION_SETTLED_QUIESCENCE` passes with no overlapping edit, it is reported
    ///   with `EDIT_PREDICTION_SETTLED_EVENT`, including the current text of its editable region,
    ///   and then dropped;
    /// - once it is `EDIT_PREDICTION_SETTLED_TTL` old, it is dropped without being reported.
    ///
    /// The loop exits when `rx` closes while idle or when the store entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        // When set, the next moment at which some pending prediction may have become quiescent.
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // `Instant::duration_since` saturates to zero if `wake_time` has already passed.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: wait for a prediction to be enqueued, then schedule a check for the
                // earliest moment it could be quiescent. Other enqueue times already queued are
                // discarded without blocking; their predictions are still covered by the scan
                // below, which walks every pending prediction.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions still pending after this scan.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: stop tracking without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    // Settled: report the region's text from the buffer's last
                                    // recorded snapshot, then stop tracking it.
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Wake again when the longest-quiet remaining prediction could become quiescent.
            // With nothing pending, go back to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1503
1504 pub(crate) fn enqueue_settled_prediction(
1505 &mut self,
1506 request_id: EditPredictionId,
1507 project: &Entity<Project>,
1508 edited_buffer: &Entity<Buffer>,
1509 edited_buffer_snapshot: &BufferSnapshot,
1510 editable_offset_range: Range<usize>,
1511 cx: &mut Context<Self>,
1512 ) {
1513 let project_state = self.get_or_init_project(project, cx);
1514 if let Some(buffer) = project_state
1515 .registered_buffers
1516 .get_mut(&edited_buffer.entity_id())
1517 {
1518 let now = cx.background_executor().now();
1519 buffer.pending_predictions.push(PendingSettledPrediction {
1520 request_id,
1521 editable_anchor_range: edited_buffer_snapshot
1522 .anchor_range_around(editable_offset_range),
1523 enqueued_at: now,
1524 last_edit_at: now,
1525 });
1526 self.settled_predictions_tx.unbounded_send(now).ok();
1527 }
1528 }
1529
1530 fn reject_current_prediction(
1531 &mut self,
1532 reason: EditPredictionRejectReason,
1533 project: &Entity<Project>,
1534 cx: &App,
1535 ) {
1536 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1537 project_state.pending_predictions.clear();
1538 if let Some(prediction) = project_state.current_prediction.take() {
1539 let model_version = prediction.prediction.model_version.clone();
1540 self.reject_prediction(
1541 prediction.prediction.id,
1542 reason,
1543 prediction.was_shown,
1544 model_version,
1545 cx,
1546 );
1547 }
1548 };
1549 }
1550
1551 fn did_show_current_prediction(
1552 &mut self,
1553 project: &Entity<Project>,
1554 display_type: edit_prediction_types::SuggestionDisplayType,
1555 cx: &mut Context<Self>,
1556 ) {
1557 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1558 return;
1559 };
1560
1561 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1562 return;
1563 };
1564
1565 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1566 let previous_shown_with = current_prediction.shown_with;
1567
1568 if previous_shown_with.is_none() || !is_jump {
1569 current_prediction.shown_with = Some(display_type);
1570 }
1571
1572 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1573
1574 if is_first_non_jump_show {
1575 current_prediction.was_shown = true;
1576 }
1577
1578 let display_type_changed = previous_shown_with != Some(display_type);
1579
1580 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1581 sweep_ai::edit_prediction_shown(
1582 &self.sweep_ai,
1583 self.client.clone(),
1584 ¤t_prediction.prediction,
1585 display_type,
1586 cx,
1587 );
1588 }
1589
1590 if is_first_non_jump_show {
1591 self.shown_predictions
1592 .push_front(current_prediction.prediction.clone());
1593 if self.shown_predictions.len() > 50 {
1594 let completion = self.shown_predictions.pop_back().unwrap();
1595 self.rated_predictions.remove(&completion.id);
1596 }
1597 }
1598 }
1599
1600 fn reject_prediction(
1601 &mut self,
1602 prediction_id: EditPredictionId,
1603 reason: EditPredictionRejectReason,
1604 was_shown: bool,
1605 model_version: Option<String>,
1606 cx: &App,
1607 ) {
1608 match self.edit_prediction_model {
1609 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1610 let is_cloud = !matches!(
1611 all_language_settings(None, cx).edit_predictions.provider,
1612 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1613 );
1614 if is_cloud {
1615 self.reject_predictions_tx
1616 .unbounded_send(EditPredictionRejection {
1617 request_id: prediction_id.to_string(),
1618 reason,
1619 was_shown,
1620 model_version,
1621 })
1622 .log_err();
1623 }
1624 }
1625 EditPredictionModel::Mercury => {
1626 mercury::edit_prediction_rejected(
1627 prediction_id,
1628 was_shown,
1629 reason,
1630 self.client.http_client(),
1631 cx,
1632 );
1633 }
1634 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1635 }
1636 }
1637
1638 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1639 self.projects
1640 .get(&project.entity_id())
1641 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1642 }
1643
1644 pub fn refresh_prediction_from_buffer(
1645 &mut self,
1646 project: Entity<Project>,
1647 buffer: Entity<Buffer>,
1648 position: language::Anchor,
1649 cx: &mut Context<Self>,
1650 ) {
1651 self.queue_prediction_refresh(
1652 project.clone(),
1653 PredictEditsRequestTrigger::Other,
1654 buffer.entity_id(),
1655 cx,
1656 move |this, cx| {
1657 let Some(request_task) = this
1658 .update(cx, |this, cx| {
1659 this.request_prediction(
1660 &project,
1661 &buffer,
1662 position,
1663 PredictEditsRequestTrigger::Other,
1664 cx,
1665 )
1666 })
1667 .log_err()
1668 else {
1669 return Task::ready(anyhow::Ok(None));
1670 };
1671
1672 cx.spawn(async move |_cx| {
1673 request_task.await.map(|prediction_result| {
1674 prediction_result.map(|prediction_result| {
1675 (
1676 prediction_result,
1677 PredictionRequestedBy::Buffer(buffer.entity_id()),
1678 )
1679 })
1680 })
1681 })
1682 },
1683 )
1684 }
1685
    /// Queues a prediction request at a diagnostic location chosen by `next_diagnostic_location`.
    ///
    /// Does nothing unless all of these hold:
    /// - an edit-prediction-store provider is configured;
    /// - the project is already tracked;
    /// - the project has no current prediction.
    ///
    /// Requests are throttled per project. When the request runs, it:
    /// 1. resolves the project's active buffer and cursor, bailing out if predictions are
    ///    disabled there;
    /// 2. asks `next_diagnostic_location` for a target;
    /// 3. requests a prediction at that target.
    ///
    /// If a current prediction appeared in the meantime, the new result is turned into a
    /// `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        }

        // The project's own entity id is the throttle key, so throttling is per project.
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor up front. Bail out when there is no
                // active buffer or predictions are disabled at the cursor.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known cursor position, fall back to the default point.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Diagnostics in the active buffer whose range overlaps this range are
                    // skipped by `next_diagnostic_location`.
                    // - `Local`: rows within `DIAGNOSTIC_LINES_RANGE` of the cursor.
                    // - `Global`: an empty default range (presumably excluding nothing).
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A current prediction may have arrived while this request was in flight.
                    // If so, keep it and turn this result into a `CurrentPreferred` rejection
                    // that retains the result's id.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1794
1795 fn predictions_enabled_at(
1796 snapshot: &BufferSnapshot,
1797 position: Option<language::Anchor>,
1798 cx: &App,
1799 ) -> bool {
1800 let file = snapshot.file();
1801 let all_settings = all_language_settings(file, cx);
1802 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1803 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1804 {
1805 return false;
1806 }
1807
1808 if let Some(last_position) = position {
1809 let settings = snapshot.settings_at(last_position, cx);
1810
1811 if !settings.edit_predictions_disabled_in.is_empty()
1812 && let Some(scope) = snapshot.language_scope_at(last_position)
1813 && let Some(scope_name) = scope.override_name()
1814 && settings
1815 .edit_predictions_disabled_in
1816 .iter()
1817 .any(|s| s == scope_name)
1818 {
1819 return false;
1820 }
1821 }
1822
1823 true
1824 }
1825
    /// Minimum spacing that `queue_prediction_refresh` enforces between two refreshes with the
    /// same throttle entity in the same throttle slot. There are two slots: diagnostic jumps and
    /// all other triggers. A request arriving sooner first waits out the remainder.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1827}
1828
1829fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1830 match provider {
1831 EditPredictionProvider::Zed
1832 | EditPredictionProvider::Sweep
1833 | EditPredictionProvider::Mercury
1834 | EditPredictionProvider::Ollama
1835 | EditPredictionProvider::OpenAiCompatibleApi
1836 | EditPredictionProvider::Experimental(_) => true,
1837 EditPredictionProvider::None
1838 | EditPredictionProvider::Copilot
1839 | EditPredictionProvider::Supermaven
1840 | EditPredictionProvider::Codestral => false,
1841 }
1842}
1843
1844impl EditPredictionStore {
    /// Schedules `do_refresh` as a pending prediction request for `project`.
    ///
    /// Throttling: refreshes are tracked in two slots, diagnostics-triggered jumps and all other
    /// triggers. If the slot's last refresh was for the same `throttle_entity` less than
    /// `THROTTLE_TIMEOUT` ago, the request first waits out the remainder.
    ///
    /// Cancellation:
    /// - a request cancelled (or whose store was dropped) while waiting never runs `do_refresh`;
    /// - a request cancelled while running has its result discarded.
    ///
    /// Completion:
    /// - a successful result becomes the current prediction, unless the existing current
    ///   prediction should be kept;
    /// - whichever prediction loses is reported as rejected (`Replaced` or `CurrentPreferred`);
    /// - an `Err` reject reason from `do_refresh` is reported too.
    ///
    /// The finished request is then removed from the pending list, and every request queued
    /// before it is cancelled.
    ///
    /// Capacity: at most 2 requests (1 for Ollama) are pending. When the list is full, the most
    /// recently queued pending request is cancelled and replaced by this one.
    ///
    /// The spawned task resolves to the id of the result `do_refresh` produced, if any.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
            + 'static,
    ) {
        // Picks the throttle slot for `request_trigger`. Diagnostics-triggered jump refreshes
        // are throttled independently of all other refreshes.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // `drop_on_cancel` is set only for providers that don't need acceptance tracking.
        // NOTE(review): presumably a cancelled task is then dropped instead of being allowed to
        // finish. See `cancel_pending_prediction` to confirm.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Supermaven
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the slot's last refresh, taken when the request is queued.
        let last_request = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Only a previous refresh for the same entity delays this one.
            // NOTE(review): throttling uses `Instant::now()`, while other timing in this file
            // uses `cx.background_executor().now()`.
            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
                if throttle_entity != last_entity {
                    return None;
                }
                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
            }) {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // `is_cancelled` also stays `true` if the store is gone and `update` fails.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if !was_cancelled {
                    let new_refresh = (throttle_entity, Instant::now());
                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Cancelled while `do_refresh` ran: discard the result without reporting it.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            // Decide between the existing current prediction and the new one,
                            // and report the loser as rejected.
                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Drop this request from the pending list and cancel every entry queued
                // before it.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Bounded queue. When full, the most recently queued entry (the last one) is cancelled
        // and replaced by this request; earlier entries are kept.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2030
2031 pub fn request_prediction(
2032 &mut self,
2033 project: &Entity<Project>,
2034 active_buffer: &Entity<Buffer>,
2035 position: language::Anchor,
2036 trigger: PredictEditsRequestTrigger,
2037 cx: &mut Context<Self>,
2038 ) -> Task<Result<Option<EditPredictionResult>>> {
2039 self.request_prediction_internal(
2040 project.clone(),
2041 active_buffer.clone(),
2042 position,
2043 trigger,
2044 cx.has_flag::<Zeta2FeatureFlag>(),
2045 cx,
2046 )
2047 }
2048
2049 fn request_prediction_internal(
2050 &mut self,
2051 project: Entity<Project>,
2052 active_buffer: Entity<Buffer>,
2053 position: language::Anchor,
2054 trigger: PredictEditsRequestTrigger,
2055 allow_jump: bool,
2056 cx: &mut Context<Self>,
2057 ) -> Task<Result<Option<EditPredictionResult>>> {
2058 self.get_or_init_project(&project, cx);
2059 let project_state = self.projects.get(&project.entity_id()).unwrap();
2060 let stored_events = project_state.events(cx);
2061 let has_events = !stored_events.is_empty();
2062 let events: Vec<Arc<zeta_prompt::Event>> =
2063 stored_events.into_iter().map(|e| e.event).collect();
2064 let debug_tx = project_state.debug_tx.clone();
2065
2066 let snapshot = active_buffer.read(cx).snapshot();
2067 let cursor_point = position.to_point(&snapshot);
2068 let current_offset = position.to_offset(&snapshot);
2069
2070 let mut user_actions: Vec<UserActionRecord> =
2071 project_state.user_actions.iter().cloned().collect();
2072
2073 if let Some(last_action) = user_actions.last() {
2074 if last_action.buffer_id == active_buffer.entity_id()
2075 && current_offset != last_action.offset
2076 {
2077 let timestamp_epoch_ms = SystemTime::now()
2078 .duration_since(UNIX_EPOCH)
2079 .map(|d| d.as_millis() as u64)
2080 .unwrap_or(0);
2081 user_actions.push(UserActionRecord {
2082 action_type: UserActionType::CursorMovement,
2083 buffer_id: active_buffer.entity_id(),
2084 line_number: cursor_point.row,
2085 offset: current_offset,
2086 timestamp_epoch_ms,
2087 });
2088 }
2089 }
2090 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2091 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2092 let diagnostic_search_range =
2093 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2094
2095 let related_files = self.context_for_project(&project, cx);
2096
2097 let inputs = EditPredictionModelInput {
2098 project: project.clone(),
2099 buffer: active_buffer,
2100 snapshot: snapshot,
2101 position,
2102 events,
2103 related_files,
2104 recent_paths: project_state.recent_paths.clone(),
2105 trigger,
2106 diagnostic_search_range: diagnostic_search_range,
2107 debug_tx,
2108 user_actions,
2109 };
2110
2111 let task = match self.edit_prediction_model {
2112 EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
2113 self,
2114 inputs,
2115 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
2116 cx,
2117 ),
2118 EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
2119 self,
2120 inputs,
2121 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
2122 cx,
2123 ),
2124 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2125 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2126 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2127 };
2128
2129 cx.spawn(async move |this, cx| {
2130 let prediction = task.await?;
2131
2132 if prediction.is_none() && allow_jump && has_events {
2133 this.update(cx, |this, cx| {
2134 this.refresh_prediction_from_diagnostics(
2135 project,
2136 DiagnosticSearchScope::Local,
2137 cx,
2138 );
2139 })?;
2140 return anyhow::Ok(None);
2141 }
2142
2143 Ok(prediction)
2144 })
2145 }
2146
    /// Finds a diagnostic location to jump to, preferring the active buffer.
    ///
    /// In the active buffer, the primary entry of each diagnostic group is considered.
    /// Groups that overlap `active_buffer_diagnostic_search_range` are skipped. Groups
    /// whose start row is within `DIAGNOSTIC_LINES_RANGE` rows of a collaborator's cursor
    /// are also skipped, unless they are that close to `active_buffer_cursor_point` too.
    /// Of the remaining candidates, the one whose start row is closest to the cursor row wins.
    ///
    /// If nothing in the active buffer qualifies, other paths reported by the project's
    /// diagnostic summaries are tried. They are ordered by how many leading path
    /// components they share with the active buffer's path, most shared first. A candidate
    /// buffer that has any collaborator selections is skipped. Otherwise the first buffer
    /// with any diagnostic yields the start of its entry with the smallest `severity` value.
    ///
    /// Returns `Ok(None)` when no location is found.
    ///
    /// # Errors
    ///
    /// Propagates a failure to open a candidate buffer. A single failure aborts the whole
    /// search rather than skipping that candidate.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported by `selections_in_range`.
        // NOTE(review): the `false` argument presumably excludes the local replica's
        // selections, so these are collaborator cursors only — confirm.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        // First pass: pick the nearest eligible diagnostic in the active buffer.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the local search range are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Leave diagnostics near a collaborator's cursor alone unless they are
                // also near our own cursor.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Second pass: fall back to other files that have diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path with the number of leading components it shares with
            // the active buffer's path (0 when the active buffer has no path).
            // NOTE(review): a path may appear more than once if `diagnostic_summaries`
            // yields one entry per language server — confirm; duplicates would only cause
            // a redundant open.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Closest paths (longest shared prefix) first; `sort_by` is stable, so ties
            // keep their original order.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Smallest `severity` value; with LSP numbering this is presumably
                        // the most severe entry — confirm.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Don't steer the user into a buffer someone else is working in.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2256
2257 async fn send_raw_llm_request(
2258 request: RawCompletionRequest,
2259 client: Arc<Client>,
2260 custom_url: Option<Arc<Url>>,
2261 llm_token: LlmApiToken,
2262 app_version: Version,
2263 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2264 let url = if let Some(custom_url) = custom_url {
2265 custom_url.as_ref().clone()
2266 } else {
2267 client
2268 .http_client()
2269 .build_zed_llm_url("/predict_edits/raw", &[])?
2270 };
2271
2272 Self::send_api_request(
2273 |builder| {
2274 let req = builder
2275 .uri(url.as_ref())
2276 .body(serde_json::to_string(&request)?.into());
2277 Ok(req?)
2278 },
2279 client,
2280 llm_token,
2281 app_version,
2282 true,
2283 )
2284 .await
2285 }
2286
2287 pub(crate) async fn send_v3_request(
2288 input: ZetaPromptInput,
2289 client: Arc<Client>,
2290 llm_token: LlmApiToken,
2291 app_version: Version,
2292 trigger: PredictEditsRequestTrigger,
2293 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2294 let url = client
2295 .http_client()
2296 .build_zed_llm_url("/predict_edits/v3", &[])?;
2297
2298 let request = PredictEditsV3Request { input, trigger };
2299
2300 let json_bytes = serde_json::to_vec(&request)?;
2301 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2302
2303 Self::send_api_request(
2304 |builder| {
2305 let req = builder
2306 .uri(url.as_ref())
2307 .header("Content-Encoding", "zstd")
2308 .body(compressed.clone().into());
2309 Ok(req?)
2310 },
2311 client,
2312 llm_token,
2313 app_version,
2314 true,
2315 )
2316 .await
2317 }
2318
2319 fn handle_api_response<T>(
2320 this: &WeakEntity<Self>,
2321 response: Result<(T, Option<EditPredictionUsage>)>,
2322 cx: &mut gpui::AsyncApp,
2323 ) -> Result<T> {
2324 match response {
2325 Ok((data, usage)) => {
2326 if let Some(usage) = usage {
2327 this.update(cx, |this, cx| {
2328 this.user_store.update(cx, |user_store, cx| {
2329 user_store.update_edit_prediction_usage(usage, cx);
2330 });
2331 })
2332 .ok();
2333 }
2334 Ok(data)
2335 }
2336 Err(err) => {
2337 if err.is::<ZedUpdateRequiredError>() {
2338 cx.update(|cx| {
2339 this.update(cx, |this, _cx| {
2340 this.update_required = true;
2341 })
2342 .ok();
2343
2344 let error_message: SharedString = err.to_string().into();
2345 show_app_notification(
2346 NotificationId::unique::<ZedUpdateRequiredError>(),
2347 cx,
2348 move |cx| {
2349 cx.new(|cx| {
2350 ErrorMessagePrompt::new(error_message.clone(), cx)
2351 .with_link_button("Update Zed", "https://zed.dev/releases")
2352 })
2353 },
2354 );
2355 });
2356 }
2357 Err(err)
2358 }
2359 }
2360 }
2361
2362 async fn send_api_request<Res>(
2363 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2364 client: Arc<Client>,
2365 llm_token: LlmApiToken,
2366 app_version: Version,
2367 require_auth: bool,
2368 ) -> Result<(Res, Option<EditPredictionUsage>)>
2369 where
2370 Res: DeserializeOwned,
2371 {
2372 let http_client = client.http_client();
2373
2374 let mut token = if require_auth {
2375 Some(llm_token.acquire(&client).await?)
2376 } else {
2377 llm_token.acquire(&client).await.ok()
2378 };
2379 let mut did_retry = false;
2380
2381 loop {
2382 let request_builder = http_client::Request::builder().method(Method::POST);
2383
2384 let mut request_builder = request_builder
2385 .header("Content-Type", "application/json")
2386 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2387
2388 // Only add Authorization header if we have a token
2389 if let Some(ref token_value) = token {
2390 request_builder =
2391 request_builder.header("Authorization", format!("Bearer {}", token_value));
2392 }
2393
2394 let request = build(request_builder)?;
2395
2396 let mut response = http_client.send(request).await?;
2397
2398 if let Some(minimum_required_version) = response
2399 .headers()
2400 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2401 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2402 {
2403 anyhow::ensure!(
2404 app_version >= minimum_required_version,
2405 ZedUpdateRequiredError {
2406 minimum_version: minimum_required_version
2407 }
2408 );
2409 }
2410
2411 if response.status().is_success() {
2412 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2413
2414 let mut body = Vec::new();
2415 response.body_mut().read_to_end(&mut body).await?;
2416 return Ok((serde_json::from_slice(&body)?, usage));
2417 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2418 did_retry = true;
2419 token = Some(llm_token.refresh(&client).await?);
2420 } else {
2421 let mut body = String::new();
2422 response.body_mut().read_to_string(&mut body).await?;
2423 anyhow::bail!(
2424 "Request failed with status: {:?}\nBody: {}",
2425 response.status(),
2426 body
2427 );
2428 }
2429 }
2430 }
2431
2432 pub fn refresh_context(
2433 &mut self,
2434 project: &Entity<Project>,
2435 buffer: &Entity<language::Buffer>,
2436 cursor_position: language::Anchor,
2437 cx: &mut Context<Self>,
2438 ) {
2439 self.get_or_init_project(project, cx)
2440 .context
2441 .update(cx, |store, cx| {
2442 store.refresh(buffer.clone(), cursor_position, cx);
2443 });
2444 }
2445
2446 #[cfg(feature = "cli-support")]
2447 pub fn set_context_for_buffer(
2448 &mut self,
2449 project: &Entity<Project>,
2450 related_files: Vec<RelatedFile>,
2451 cx: &mut Context<Self>,
2452 ) {
2453 self.get_or_init_project(project, cx)
2454 .context
2455 .update(cx, |store, cx| {
2456 store.set_related_files(related_files, cx);
2457 });
2458 }
2459
2460 #[cfg(feature = "cli-support")]
2461 pub fn set_recent_paths_for_project(
2462 &mut self,
2463 project: &Entity<Project>,
2464 paths: impl IntoIterator<Item = project::ProjectPath>,
2465 cx: &mut Context<Self>,
2466 ) {
2467 let project_state = self.get_or_init_project(project, cx);
2468 project_state.recent_paths = paths.into_iter().collect();
2469 }
2470
2471 fn is_file_open_source(
2472 &self,
2473 project: &Entity<Project>,
2474 file: &Arc<dyn File>,
2475 cx: &App,
2476 ) -> bool {
2477 if !file.is_local() || file.is_private() {
2478 return false;
2479 }
2480 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2481 return false;
2482 };
2483 project_state
2484 .license_detection_watchers
2485 .get(&file.worktree_id(cx))
2486 .as_ref()
2487 .is_some_and(|watcher| watcher.is_project_open_source())
2488 }
2489
2490 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2491 self.data_collection_choice.is_enabled(cx)
2492 }
2493
2494 fn load_data_collection_choice() -> DataCollectionChoice {
2495 let choice = KEY_VALUE_STORE
2496 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2497 .log_err()
2498 .flatten();
2499
2500 match choice.as_deref() {
2501 Some("true") => DataCollectionChoice::Enabled,
2502 Some("false") => DataCollectionChoice::Disabled,
2503 Some(_) => {
2504 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2505 DataCollectionChoice::NotAnswered
2506 }
2507 None => DataCollectionChoice::NotAnswered,
2508 }
2509 }
2510
2511 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2512 self.data_collection_choice = self.data_collection_choice.toggle();
2513 let new_choice = self.data_collection_choice;
2514 let is_enabled = new_choice.is_enabled(cx);
2515 db::write_and_log(cx, move || {
2516 KEY_VALUE_STORE.write_kvp(
2517 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2518 is_enabled.to_string(),
2519 )
2520 });
2521 }
2522
2523 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2524 self.shown_predictions.iter()
2525 }
2526
2527 pub fn shown_completions_len(&self) -> usize {
2528 self.shown_predictions.len()
2529 }
2530
2531 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2532 self.rated_predictions.contains(id)
2533 }
2534
2535 pub fn rate_prediction(
2536 &mut self,
2537 prediction: &EditPrediction,
2538 rating: EditPredictionRating,
2539 feedback: String,
2540 cx: &mut Context<Self>,
2541 ) {
2542 let organization = self.user_store.read(cx).current_organization();
2543
2544 self.rated_predictions.insert(prediction.id.clone());
2545
2546 cx.background_spawn({
2547 let client = self.client.clone();
2548 let prediction_id = prediction.id.to_string();
2549 let inputs = serde_json::to_value(&prediction.inputs);
2550 let output = prediction
2551 .edit_preview
2552 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2553 async move {
2554 client
2555 .cloud_client()
2556 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2557 organization_id: organization.map(|organization| organization.id.clone()),
2558 request_id: prediction_id,
2559 rating: match rating {
2560 EditPredictionRating::Positive => "positive".to_string(),
2561 EditPredictionRating::Negative => "negative".to_string(),
2562 },
2563 inputs: inputs?,
2564 output,
2565 feedback,
2566 })
2567 .await?;
2568
2569 anyhow::Ok(())
2570 }
2571 })
2572 .detach_and_log_err(cx);
2573
2574 cx.notify();
2575 }
2576}
2577
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Walking backwards from the newest event, events are counted as long as each older
/// event `can_merge` with the event right after it, given `latest_snapshot` and
/// `latest_edit_range`. If that run has at least two events, it is replaced by a single
/// `BufferChange`. Its diff is computed from the oldest event's `old_snapshot` to
/// `end_snapshot`.
///
/// `events` is left unchanged in three cases:
/// - the newest event's snapshot belongs to a different buffer than `latest_snapshot`;
/// - fewer than two events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the most recent event is for the same buffer as the latest snapshot.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the trailing run of mergeable events, newest first. Each older event is
    // checked against the event that immediately follows it in the queue.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one event: nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not consume, so the `all(..)` below still includes the oldest event.
    // The `unwrap` cannot fail because `mergeable_count >= 2` at this point.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // The paths and open-source flag come from the oldest event. The merged event is
        // `predicted` only if every merged event was, and its edit range is anchored in
        // `end_snapshot`.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Swap the whole mergeable run for the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2644
2645pub(crate) fn filter_redundant_excerpts(
2646 mut related_files: Vec<RelatedFile>,
2647 cursor_path: &Path,
2648 cursor_row_range: Range<u32>,
2649) -> Vec<RelatedFile> {
2650 for file in &mut related_files {
2651 if file.path.as_ref() == cursor_path {
2652 file.excerpts.retain(|excerpt| {
2653 excerpt.row_range.start < cursor_row_range.start
2654 || excerpt.row_range.end > cursor_row_range.end
2655 });
2656 }
2657 }
2658 related_files.retain(|file| !file.excerpts.is_empty());
2659 related_files
2660}
2661
2662#[derive(Error, Debug)]
2663#[error(
2664 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2665)]
2666pub struct ZedUpdateRequiredError {
2667 minimum_version: Version,
2668}
2669
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No valid choice has been stored.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2676
2677impl DataCollectionChoice {
2678 pub fn is_enabled(self, cx: &App) -> bool {
2679 if cx.is_staff() {
2680 return true;
2681 }
2682 match self {
2683 Self::Enabled => true,
2684 Self::NotAnswered | Self::Disabled => false,
2685 }
2686 }
2687
2688 #[must_use]
2689 pub fn toggle(&self) -> DataCollectionChoice {
2690 match self {
2691 Self::Enabled => Self::Disabled,
2692 Self::Disabled => Self::Enabled,
2693 Self::NotAnswered => Self::Enabled,
2694 }
2695 }
2696}
2697
2698impl From<bool> for DataCollectionChoice {
2699 fn from(value: bool) -> Self {
2700 match value {
2701 true => DataCollectionChoice::Enabled,
2702 false => DataCollectionChoice::Disabled,
2703 }
2704 }
2705}
2706
2707struct ZedPredictUpsell;
2708
2709impl Dismissable for ZedPredictUpsell {
2710 const KEY: &'static str = "dismissed-edit-predict-upsell";
2711
2712 fn dismissed() -> bool {
2713 // To make this backwards compatible with older versions of Zed, we
2714 // check if the user has seen the previous Edit Prediction Onboarding
2715 // before, by checking the data collection choice which was written to
2716 // the database once the user clicked on "Accept and Enable"
2717 if KEY_VALUE_STORE
2718 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2719 .log_err()
2720 .is_some_and(|s| s.is_some())
2721 {
2722 return true;
2723 }
2724
2725 KEY_VALUE_STORE
2726 .read_kvp(Self::KEY)
2727 .log_err()
2728 .is_some_and(|s| s.is_some())
2729 }
2730}
2731
2732pub fn should_show_upsell_modal() -> bool {
2733 !ZedPredictUpsell::dismissed()
2734}
2735
2736pub fn init(cx: &mut App) {
2737 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2738 workspace.register_action(
2739 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2740 ZedPredictModal::toggle(
2741 workspace,
2742 workspace.user_store().clone(),
2743 workspace.client().clone(),
2744 window,
2745 cx,
2746 )
2747 },
2748 );
2749
2750 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2751 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2752 settings
2753 .project
2754 .all_languages
2755 .edit_predictions
2756 .get_or_insert_default()
2757 .provider = Some(EditPredictionProvider::None)
2758 });
2759 });
2760 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2761 EditPredictionStore::try_global(cx).and_then(|store| {
2762 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2763 })
2764 }
2765
2766 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2767 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2768 copilot_ui::initiate_sign_in(copilot, window, cx);
2769 }
2770 });
2771 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2772 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2773 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2774 }
2775 });
2776 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2777 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2778 copilot_ui::initiate_sign_out(copilot, window, cx);
2779 }
2780 });
2781 })
2782 .detach();
2783}