1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
7 PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
8};
9use cloud_llm_client::{
10 EditPredictionRejectReason, EditPredictionRejection,
11 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
12 PREFERRED_EXPERIMENT_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
13 ZED_VERSION_HEADER_NAME,
14};
15use collections::{HashMap, HashSet};
16use copilot::{Copilot, Reinstall, SignIn, SignOut};
17use credentials_provider::CredentialsProvider;
18use db::kvp::{Dismissable, KeyValueStore};
19use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PresenceFlag, register_feature_flag};
21use futures::{
22 AsyncReadExt as _, FutureExt as _, StreamExt as _,
23 channel::mpsc::{self, UnboundedReceiver},
24 select_biased,
25};
26use gpui::BackgroundExecutor;
27use gpui::http_client::Url;
28use gpui::{
29 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
30 http_client::{self, AsyncBody, Method},
31 prelude::*,
32};
33use heapless::Vec as ArrayVec;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
36 TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
37};
38use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
39use release_channel::AppVersion;
40use semver::Version;
41use serde::de::DeserializeOwned;
42use settings::{
43 EditPredictionDataCollectionChoice, EditPredictionPromptFormat, EditPredictionProvider,
44 Settings as _, update_settings_file,
45};
46use std::collections::{VecDeque, hash_map};
47use std::env;
48use text::{AnchorRangeExt, Edit};
49use workspace::{AppState, Workspace};
50use zeta_prompt::{ZetaFormat, ZetaPromptInput};
51
52use std::mem;
53use std::ops::Range;
54use std::path::Path;
55use std::rc::Rc;
56use std::str::FromStr as _;
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59
60use thiserror::Error;
61use util::{RangeExt as _, ResultExt as _};
62
63pub mod cursor_excerpt;
64pub mod example_spec;
65pub mod fim;
66mod license_detection;
67pub mod mercury;
68pub mod metrics;
69pub mod ollama;
70mod onboarding_modal;
71pub mod open_ai_response;
72mod prediction;
73
74pub mod udiff;
75
76mod capture_example;
77pub mod open_ai_compatible;
78mod zed_edit_prediction_delegate;
79pub mod zeta;
80
81#[cfg(test)]
82mod edit_prediction_tests;
83
84use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
85use crate::example_spec::ExampleSpec;
86use crate::license_detection::LicenseDetectionWatcher;
87use crate::mercury::Mercury;
88pub use crate::metrics::{KeptRateResult, compute_kept_rate};
89use crate::onboarding_modal::ZedPredictModal;
90pub use crate::prediction::EditPrediction;
91pub use crate::prediction::EditPredictionId;
92use crate::prediction::EditPredictionResult;
93pub use capture_example::capture_example;
94pub use language_model::ApiKeyState;
95pub use telemetry_events::EditPredictionRating;
96pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
97
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
107
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Edits farther apart than this many lines are considered distant and are
/// not grouped into a single change (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): read site is outside this chunk; presumably a token budget for
// context surrounding collaborator edits — confirm at the usage.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): read site is outside this chunk; name suggests edits within
// this window are coalesced into the same change group — confirm at the usage.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value store key persisting the legacy data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce window before flushing batched prediction-rejection reports
/// (consumed by `handle_rejected_predictions`; implementation outside chunk).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Telemetry event name reported when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// TTL and quiescence windows for the settled-predictions worker; read sites
// are outside this chunk.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
119
/// Feature flag gating "jump" predictions (predictions at a location other
/// than the cursor).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
    type Value = PresenceFlag;
}
register_feature_flag!(EditPredictionJumpsFeatureFlag);
127
/// Newtype wrapper installing the app-wide [`EditPredictionStore`] as a gpui
/// global (see [`EditPredictionStore::global`]).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
132
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// Populated from `ZED_ZETA_*` environment variables in
/// `EditPredictionStore::zeta2_raw_config_from_env`, or set explicitly via
/// `set_zeta2_raw_config`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    pub model_id: Option<String>,
    pub environment: Option<String>,
    pub format: ZetaFormat,
}
142
/// App-wide store coordinating edit predictions: per-project state, the active
/// model, experiment selection, and the channels feeding the background
/// workers spawned in [`EditPredictionStore::new`].
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Held so the experiment-fetching task started in `new` keeps running.
    _fetch_experiments_task: Task<()>,
    // Keyed by `Project` entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): write site not in this chunk; presumably set when the
    // server demands a newer client — confirm where it is assigned.
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    // Raw-endpoint override; see `zeta2_raw_config_from_env`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    legacy_data_collection_enabled: bool,
    // Feeds `handle_rejected_predictions` (spawned in `new`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    // Feeds `run_settled_predictions_worker` (spawned in `new`).
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    // Ids of predictions that have already been rated.
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
164
/// Message sent over `reject_predictions_tx` describing one rejected
/// prediction and the organization (if any) it should be attributed to.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
169
/// Which backend produces edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's hosted Zeta model.
    Zeta,
    /// A fill-in-the-middle model with a configurable prompt format.
    Fim { format: EditPredictionPromptFormat },
    /// Inception's Mercury model.
    Mercury,
}
176
/// Everything a model needs to produce one prediction: the buffer/position,
/// recent edit events, related files, and request metadata.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    // Recent edit history, oldest first — TODO confirm ordering at the producer.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    mode: PredictEditsMode,
    trigger: PredictEditsRequestTrigger,
    // Row range to scan for diagnostics to include in the request.
    diagnostic_search_range: Range<Point>,
    // When present, debug events are emitted over this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
192
/// Events emitted over the optional `debug_tx` channels while retrieving
/// context and producing predictions; consumed by debug tooling.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Prompt used to search for related context.
    pub search_prompt: String,
}

/// Emitted when context retrieval completes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value pairs describing the retrieval outcome.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Full prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer before the event's edits were applied.
    pub old_snapshot: TextBufferSnapshot,
    /// Version vector of the buffer after the event's edits were applied.
    pub new_snapshot_version: clock::Global,
    /// Range (in post-edit coordinates) spanning all edits in this event.
    pub total_edit_range: Range<Anchor>,
}
237
impl StoredEvent {
    /// Reports whether `self` and the event that follows it
    /// (`next_old_event`) can be merged into one event.
    ///
    /// Merging requires that both events belong to the same buffer, that
    /// they are contiguous (this event's final version is exactly the next
    /// event's starting version, and `latest_snapshot` has observed the next
    /// event's final version), that exactly one of the two was predicted,
    /// and that the events are near each other but not near the latest edit.
    fn can_merge(
        &self,
        next_old_event: &StoredEvent,
        latest_snapshot: &TextBufferSnapshot,
        latest_edit_range: &Range<Anchor>,
    ) -> bool {
        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
            return false;
        }
        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return false;
        }
        // The next event must start exactly where this event ended.
        if self.new_snapshot_version != next_old_event.old_snapshot.version {
            return false;
        }
        // The latest snapshot must include everything the next event produced;
        // otherwise its anchors can't be resolved against `latest_snapshot`.
        if !latest_snapshot
            .version
            .observed_all(&next_old_event.new_snapshot_version)
        {
            return false;
        }

        let a_is_predicted = matches!(
            self.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );
        let b_is_predicted = matches!(
            next_old_event.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );

        // If events come from the same source (both predicted or both manual) then
        // we would have coalesced them already.
        if a_is_predicted == b_is_predicted {
            return false;
        }

        let left_range = self.total_edit_range.to_point(latest_snapshot);
        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
        let latest_range = latest_edit_range.to_point(latest_snapshot);

        // Events near to the latest edit are not merged if their sources differ.
        if lines_between_ranges(&left_range, &latest_range)
            .min(lines_between_ranges(&right_range, &latest_range))
            <= CHANGE_GROUPING_LINE_SPAN
        {
            return false;
        }

        // Events that are distant from each other are not merged.
        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
            return false;
        }

        true
    }
}
303
304fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
305 if left.start > right.end {
306 return left.start.row - right.end.row;
307 }
308 if right.start > left.end {
309 return right.start.row - left.end.row;
310 }
311 0
312}
313
/// Per-project state tracked by [`EditPredictionStore`].
struct ProjectState {
    /// Finalized edit-history events, oldest first.
    events: VecDeque<StoredEvent>,
    /// Edit still being accumulated; finalized lazily in `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked for prediction, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most 2 in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    // When present, debug events are emitted over this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree watchers used to decide open-source status.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started on demand via `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
331
332impl ProjectState {
333 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
334 self.events
335 .iter()
336 .cloned()
337 .chain(self.last_event.as_ref().iter().flat_map(|event| {
338 let (one, two) = event.split_by_pause();
339 let one = one.finalize(&self.license_detection_watchers, cx);
340 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
341 one.into_iter().chain(two)
342 }))
343 .collect()
344 }
345
346 fn cancel_pending_prediction(
347 &mut self,
348 pending_prediction: PendingPrediction,
349 cx: &mut Context<EditPredictionStore>,
350 ) {
351 self.cancelled_predictions.insert(pending_prediction.id);
352
353 if pending_prediction.drop_on_cancel {
354 drop(pending_prediction.task);
355 } else {
356 cx.spawn(async move |this, cx| {
357 let Some((prediction_id, model_version)) = pending_prediction.task.await else {
358 return;
359 };
360
361 this.update(cx, |this, cx| {
362 this.reject_prediction(
363 prediction_id,
364 EditPredictionRejectReason::Canceled,
365 false,
366 model_version,
367 None,
368 cx,
369 );
370 })
371 .ok();
372 })
373 .detach()
374 }
375 }
376
377 fn active_buffer(
378 &self,
379 project: &Entity<Project>,
380 cx: &App,
381 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
382 let project = project.read(cx);
383 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
384 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
385 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
386 Some((active_buffer, registered_buffer.last_position))
387 }
388}
389
/// The prediction currently held for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been presented to the user.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency from request to availability.
    pub e2e_latency: std::time::Duration,
}
398
399impl CurrentEditPrediction {
400 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
401 let Some(new_edits) = self
402 .prediction
403 .interpolate(&self.prediction.buffer.read(cx))
404 else {
405 return false;
406 };
407
408 if self.prediction.buffer != old_prediction.prediction.buffer {
409 return true;
410 }
411
412 let Some(old_edits) = old_prediction
413 .prediction
414 .interpolate(&old_prediction.prediction.buffer.read(cx))
415 else {
416 return true;
417 };
418
419 let requested_by_buffer_id = self.requested_by.buffer_id();
420
421 // This reduces the occurrence of UI thrash from replacing edits
422 //
423 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
424 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
425 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
426 && old_edits.len() == 1
427 && new_edits.len() == 1
428 {
429 let (old_range, old_text) = &old_edits[0];
430 let (new_range, new_text) = &new_edits[0];
431 new_range == old_range && new_text.starts_with(old_text.as_ref())
432 } else {
433 true
434 }
435 }
436}
437
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than a specific buffer edit.
    DiagnosticsUpdate,
    /// Triggered by activity in the buffer with this entity id.
    Buffer(EntityId),
}
443
444impl PredictionRequestedBy {
445 pub fn buffer_id(&self) -> Option<EntityId> {
446 match self {
447 PredictionRequestedBy::DiagnosticsUpdate => None,
448 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
449 }
450 }
451}
452
// NOTE(review): read site is outside this chunk; presumably the number of
// lines around the cursor searched for diagnostics — confirm at the usage.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
460
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Id used to record cancellation in `ProjectState::cancelled_predictions`.
    id: usize,
    /// Resolves to the prediction id and optional model version, or `None`.
    task: Task<Option<(EditPredictionId, Option<String>)>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
469
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// Prediction that applies within this buffer.
    Local { prediction: &'a EditPrediction },
    /// Prediction that would jump the user elsewhere.
    Jump { prediction: &'a EditPrediction },
}

// Test-only convenience: deref either variant to the underlying prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
488
/// A prediction queued for the settled-predictions worker, capturing the
/// editable region before and after the prediction so the worker can compare
/// what the user actually kept.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    // TypeScript error counts around the prediction — NOTE(review): producer
    // is outside this chunk; confirm how these are computed.
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
502
/// State tracked for each buffer registered for predictions.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Last snapshot observed for this buffer.
    snapshot: TextBufferSnapshot,
    /// Predictions awaiting settlement in this buffer.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position, if any.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
510
/// The edit event currently being accumulated for a buffer, before it is
/// finalized into a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Buffer state before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range of the most recent edit (in `new_snapshot` coordinates).
    latest_edit_range: Range<Anchor>,
    /// Range spanning all edits in this event.
    total_edit_range: Range<Anchor>,
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from an accepted prediction.
    predicted: bool,
    /// Snapshot taken at the last editing pause; used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
524
impl LastEvent {
    /// Converts the accumulated edit into a [`StoredEvent`]: computes a
    /// unified diff over `total_edit_range`, tags the event with open-source
    /// status from the license watchers, and returns `None` when nothing
    /// changed (same path and empty diff) or when the diff can't be computed
    /// (e.g. it exceeds the size limit).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open source only if both the old and new file belong to worktrees
        // whose license watcher reports an open-source project.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
            &self.old_snapshot,
            &self.new_snapshot,
            &self.total_edit_range,
        )?;

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                old_snapshot: self.old_snapshot.clone(),
                new_snapshot_version: self.new_snapshot.version.clone(),
                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
            })
        }
    }

    /// Splits this event at the last editing pause, producing the event up to
    /// the pause boundary and (when a pause occurred and edits followed it) a
    /// second event covering the edits after the pause.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        // No recorded pause: the event stays whole.
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let total_edit_range_before_pause = self
            .total_edit_range_at_last_pause_boundary
            .clone()
            .unwrap_or_else(|| self.total_edit_range.clone());

        // If no edits happened after the boundary, don't split.
        let Some(total_edit_range_after_pause) =
            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
        else {
            return (self.clone(), None);
        };

        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();

        // First half: old snapshot up to the pause boundary.
        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_before_pause,
            total_edit_range: total_edit_range_before_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        // Second half: pause boundary up to the latest snapshot.
        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_after_pause,
            total_edit_range: total_edit_range_after_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
618
619fn compute_total_edit_range_between_snapshots(
620 old_snapshot: &TextBufferSnapshot,
621 new_snapshot: &TextBufferSnapshot,
622) -> Option<Range<Anchor>> {
623 let edits: Vec<Edit<usize>> = new_snapshot
624 .edits_since::<usize>(&old_snapshot.version)
625 .collect();
626
627 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
628 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
629 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
630
631 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
632}
633
/// Maps `total_edit_range` (expressed against `new_snapshot`) back to the
/// corresponding point range in `old_snapshot`, by walking the edits between
/// the two snapshots and tracking the running length delta.
///
/// Returns `None` if offset translation underflows (`checked_add_signed`
/// fails), which indicates inconsistent edit data.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Cumulative length change (new - old) of all edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Position is before this edit: shift by the accumulated delta.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Position falls inside this edit: clamp to the edit's old start.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Positions after the final edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
679
/// Produces a unified diff between the two snapshots, restricted to
/// `total_edit_range` plus a few lines of context, along with the edited
/// point range in `new_snapshot` coordinates.
///
/// Returns `None` when the old range can't be computed or when either side of
/// the diff would exceed `EDIT_HISTORY_DIFF_SIZE_LIMIT` bytes.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    // Lines of unchanged context included on each side of the diff.
    const CONTEXT_LINES: u32 = 3;

    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Expand to whole lines, clamped to each buffer's end.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Skip diffs that would blow the edit-history prompt budget.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
727
728fn buffer_path_with_id_fallback(
729 file: Option<&Arc<dyn File>>,
730 snapshot: &TextBufferSnapshot,
731 cx: &App,
732) -> Arc<Path> {
733 if let Some(file) = file {
734 file.full_path(cx).into()
735 } else {
736 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
737 }
738}
739
740impl EditPredictionStore {
741 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
742 cx.try_global::<EditPredictionStoreGlobal>()
743 .map(|global| global.0.clone())
744 }
745
746 pub fn global(
747 client: &Arc<Client>,
748 user_store: &Entity<UserStore>,
749 cx: &mut App,
750 ) -> Entity<Self> {
751 cx.try_global::<EditPredictionStoreGlobal>()
752 .map(|global| global.0.clone())
753 .unwrap_or_else(|| {
754 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
755 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
756 ep_store
757 })
758 }
759
760 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
761 let llm_token = global_llm_token(cx);
762 let legacy_data_collection_enabled = Self::load_legacy_data_collection_enabled(cx);
763
764 let (reject_tx, reject_rx) = mpsc::unbounded();
765 cx.background_spawn({
766 let client = client.clone();
767 let llm_token = llm_token.clone();
768 let app_version = AppVersion::global(cx);
769 let background_executor = cx.background_executor().clone();
770 async move {
771 Self::handle_rejected_predictions(
772 reject_rx,
773 client,
774 llm_token,
775 app_version,
776 background_executor,
777 )
778 .await
779 }
780 })
781 .detach();
782
783 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
784 cx.spawn(async move |this, cx| {
785 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
786 })
787 .detach();
788
789 let mut current_user = user_store.read(cx).watch_current_user();
790 let fetch_experiments_task = cx.spawn(async move |this, cx| {
791 while current_user.borrow().is_none() {
792 current_user.next().await;
793 }
794
795 this.update(cx, |this, cx| {
796 if cx.is_staff() {
797 this.refresh_available_experiments(cx);
798 }
799 })
800 .log_err();
801 });
802
803 let credentials_provider = zed_credentials_provider::global(cx);
804
805 let this = Self {
806 projects: HashMap::default(),
807 client,
808 user_store,
809 llm_token,
810 _fetch_experiments_task: fetch_experiments_task,
811 update_required: false,
812 edit_prediction_model: EditPredictionModel::Zeta,
813 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
814 preferred_experiment: None,
815 available_experiments: Vec::new(),
816 mercury: Mercury::new(cx),
817 legacy_data_collection_enabled,
818
819 reject_predictions_tx: reject_tx,
820 settled_predictions_tx,
821 rated_predictions: Default::default(),
822 shown_predictions: Default::default(),
823 #[cfg(test)]
824 settled_event_callback: None,
825
826 credentials_provider,
827 };
828
829 this
830 }
831
832 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
833 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
834 let format = ZetaFormat::parse(&version_str).ok()?;
835 let model_id = env::var("ZED_ZETA_MODEL").ok();
836 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
837 Some(Zeta2RawConfig {
838 model_id,
839 environment,
840 format,
841 })
842 }
843
    /// Selects which backend model produces predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Overrides the raw Zeta2 endpoint configuration.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Current raw Zeta2 endpoint configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// Explicitly preferred experiment, if one has been set.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    /// Sets (or clears) the preferred experiment.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Experiments fetched by `refresh_available_experiments`.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }

    /// The experiment in effect: the explicitly preferred one, or else the
    /// experiment encoded in a shown prediction's model version
    /// (`zeta2:<experiment>`).
    pub fn active_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref().or_else(|| {
            self.shown_predictions
                .iter()
                .find_map(|p| p.model_version.as_ref())
                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
        })
    }
876
    /// Fetches the list of available experiments from the LLM service's
    /// `/edit_prediction_experiments` endpoint and stores it in
    /// `available_experiments`. Errors are logged, not surfaced.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    // Authenticate with an LLM token scoped to the user's org.
                    let token = client
                        .acquire_llm_token(&llm_token, organization_id.clone())
                        .await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        // Include the response body in the error for debugging.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
926
927 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
928 use ui::IconName;
929 match self.edit_prediction_model {
930 EditPredictionModel::Mercury => {
931 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
932 }
933 EditPredictionModel::Zeta => {
934 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
935 .with_disabled(IconName::ZedPredictDisabled)
936 .with_up(IconName::ZedPredictUp)
937 .with_down(IconName::ZedPredictDown)
938 .with_error(IconName::ZedPredictError)
939 }
940 EditPredictionModel::Fim { .. } => {
941 let settings = &all_language_settings(None, cx).edit_predictions;
942 match settings.provider {
943 EditPredictionProvider::Ollama => {
944 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
945 }
946 _ => {
947 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
948 }
949 }
950 }
951 }
952 }
953
    /// Whether a Mercury API token is available.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Whether the last Mercury request failed with a payment-required error.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
961
962 pub fn clear_history(&mut self) {
963 for project_state in self.projects.values_mut() {
964 project_state.events.clear();
965 project_state.last_event.take();
966 }
967 }
968
969 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
970 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
971 project_state.events.clear();
972 project_state.last_event.take();
973 }
974 }
975
976 pub fn edit_history_for_project(
977 &self,
978 project: &Entity<Project>,
979 cx: &App,
980 ) -> Vec<StoredEvent> {
981 self.projects
982 .get(&project.entity_id())
983 .map(|project_state| project_state.events(cx))
984 .unwrap_or_default()
985 }
986
987 pub fn context_for_project<'a>(
988 &'a self,
989 project: &Entity<Project>,
990 cx: &'a mut App,
991 ) -> Vec<RelatedFile> {
992 self.projects
993 .get(&project.entity_id())
994 .map(|project_state| {
995 project_state.context.update(cx, |context, cx| {
996 context
997 .related_files_with_buffers(cx)
998 .map(|(mut related_file, buffer)| {
999 related_file.in_open_source_repo = buffer
1000 .read(cx)
1001 .file()
1002 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1003 related_file
1004 })
1005 .collect()
1006 })
1007 })
1008 .unwrap_or_default()
1009 }
1010
1011 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1012 self.projects
1013 .get(&project.entity_id())
1014 .and_then(|project| project.copilot.clone())
1015 }
1016
1017 pub fn start_copilot_for_project(
1018 &mut self,
1019 project: &Entity<Project>,
1020 cx: &mut Context<Self>,
1021 ) -> Option<Entity<Copilot>> {
1022 if DisableAiSettings::get(None, cx).disable_ai {
1023 return None;
1024 }
1025 let state = self.get_or_init_project(project, cx);
1026
1027 if state.copilot.is_some() {
1028 return state.copilot.clone();
1029 }
1030 let _project = project.clone();
1031 let project = project.read(cx);
1032
1033 let node = project.node_runtime().cloned();
1034 if let Some(node) = node {
1035 let next_id = project.languages().next_language_server_id();
1036 let fs = project.fs().clone();
1037
1038 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1039 state.copilot = Some(copilot.clone());
1040 Some(copilot)
1041 } else {
1042 None
1043 }
1044 }
1045
1046 pub fn context_for_project_with_buffers<'a>(
1047 &'a self,
1048 project: &Entity<Project>,
1049 cx: &'a mut App,
1050 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1051 self.projects
1052 .get(&project.entity_id())
1053 .map(|project| {
1054 project.context.update(cx, |context, cx| {
1055 context.related_files_with_buffers(cx).collect()
1056 })
1057 })
1058 .unwrap_or_default()
1059 }
1060
1061 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1062 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1063 self.user_store.read(cx).edit_prediction_usage()
1064 } else {
1065 None
1066 }
1067 }
1068
    /// Ensures `project` is tracked by this store, creating its state and
    /// subscriptions on first registration.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1072
    /// Registers `buffer` for edit tracking within `project`, initializing
    /// the project's state first if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1082
    /// Returns the mutable [`ProjectState`] for `project`, creating it on
    /// first access.
    ///
    /// Initialization wires up:
    /// - a `RelatedExcerptStore` whose events are forwarded to
    ///   `handle_excerpt_store_event`,
    /// - a subscription to project events, and
    /// - a release observer that drops the state when the project entity is
    ///   released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Clean up all per-project state when the project entity
                    // is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1122
    /// Drops all tracked state for `project` (history, predictions,
    /// registered buffers, watchers).
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
1126
1127 fn handle_excerpt_store_event(
1128 &mut self,
1129 project_entity_id: EntityId,
1130 event: &RelatedExcerptStoreEvent,
1131 ) {
1132 if let Some(project_state) = self.projects.get(&project_entity_id) {
1133 if let Some(debug_tx) = project_state.debug_tx.clone() {
1134 match event {
1135 RelatedExcerptStoreEvent::StartedRefresh => {
1136 debug_tx
1137 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1138 ContextRetrievalStartedDebugEvent {
1139 project_entity_id: project_entity_id,
1140 timestamp: Instant::now(),
1141 search_prompt: String::new(),
1142 },
1143 ))
1144 .ok();
1145 }
1146 RelatedExcerptStoreEvent::FinishedRefresh {
1147 cache_hit_count,
1148 cache_miss_count,
1149 mean_definition_latency,
1150 max_definition_latency,
1151 } => {
1152 debug_tx
1153 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1154 ContextRetrievalFinishedDebugEvent {
1155 project_entity_id: project_entity_id,
1156 timestamp: Instant::now(),
1157 metadata: vec![
1158 (
1159 "Cache Hits",
1160 format!(
1161 "{}/{}",
1162 cache_hit_count,
1163 cache_hit_count + cache_miss_count
1164 )
1165 .into(),
1166 ),
1167 (
1168 "Max LSP Time",
1169 format!("{} ms", max_definition_latency.as_millis())
1170 .into(),
1171 ),
1172 (
1173 "Mean LSP Time",
1174 format!("{} ms", mean_definition_latency.as_millis())
1175 .into(),
1176 ),
1177 ],
1178 },
1179 ))
1180 .ok();
1181 }
1182 }
1183 }
1184 }
1185 }
1186
1187 pub fn debug_info(
1188 &mut self,
1189 project: &Entity<Project>,
1190 cx: &mut Context<Self>,
1191 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1192 let project_state = self.get_or_init_project(project, cx);
1193 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1194 project_state.debug_tx = Some(debug_watch_tx);
1195 debug_watch_rx
1196 }
1197
1198 fn handle_project_event(
1199 &mut self,
1200 project: Entity<Project>,
1201 event: &project::Event,
1202 cx: &mut Context<Self>,
1203 ) {
1204 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1205 return;
1206 }
1207 // TODO [zeta2] init with recent paths
1208 match event {
1209 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1210 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1211 return;
1212 };
1213 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1214 if let Some(path) = path {
1215 if let Some(ix) = project_state
1216 .recent_paths
1217 .iter()
1218 .position(|probe| probe == &path)
1219 {
1220 project_state.recent_paths.remove(ix);
1221 }
1222 project_state.recent_paths.push_front(path);
1223 }
1224 }
1225 project::Event::DiagnosticsUpdated { .. } => {
1226 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1227 self.refresh_prediction_from_diagnostics(
1228 project,
1229 DiagnosticSearchScope::Global,
1230 cx,
1231 );
1232 }
1233 }
1234 _ => (),
1235 }
1236 }
1237
    /// Ensures `buffer` is tracked in `project_state`, returning its
    /// [`RegisteredBuffer`] entry.
    ///
    /// Also lazily creates a license-detection watcher for the buffer's
    /// worktree (removed again when the worktree is released). A newly
    /// registered buffer gets subscriptions that forward local/remote edits
    /// to `report_changes_for_buffer` and unregister it on release.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license-detection watcher per worktree, created on demand.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when the worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Report every edit (local and remote) to history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1306
    /// Records an edit to `buffer` into the project's edit history, updating
    /// prediction bookkeeping along the way.
    ///
    /// `is_predicted` marks the change as coming from an accepted prediction;
    /// `is_local` marks it as made by the local user (vs. a collaborator).
    /// Consecutive nearby edits in the same buffer are coalesced into a
    /// single history event (`project_state.last_event`) until they diverge.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Union all edits since the previous snapshot into one anchor range.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits inside a pending settled prediction's editable region reset
        // its quiescence timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Collaborator edits only enter history when they overlap the local
        // user's locality region.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // An edit with no recordable diff still finalizes any in-progress
        // event so later edits start a fresh one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Bounded history: drop the oldest event when full.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only when this edit directly continues the last
            // event's snapshot of the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a long enough typing pause, remember the pre-pause
                // state so it can serve as a boundary within the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event before starting a new
        // one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1439
1440 fn prediction_at(
1441 &mut self,
1442 buffer: &Entity<Buffer>,
1443 position: Option<language::Anchor>,
1444 project: &Entity<Project>,
1445 cx: &App,
1446 ) -> Option<BufferEditPrediction<'_>> {
1447 let project_state = self.projects.get_mut(&project.entity_id())?;
1448 if let Some(position) = position
1449 && let Some(buffer) = project_state
1450 .registered_buffers
1451 .get_mut(&buffer.entity_id())
1452 {
1453 buffer.last_position = Some(position);
1454 }
1455
1456 let CurrentEditPrediction {
1457 requested_by,
1458 prediction,
1459 ..
1460 } = project_state.current_prediction.as_ref()?;
1461
1462 if prediction.targets_buffer(buffer.read(cx)) {
1463 Some(BufferEditPrediction::Local { prediction })
1464 } else {
1465 let show_jump = match requested_by {
1466 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1467 requested_by_buffer_id == &buffer.entity_id()
1468 }
1469 PredictionRequestedBy::DiagnosticsUpdate => true,
1470 };
1471
1472 if show_jump {
1473 Some(BufferEditPrediction::Jump { prediction })
1474 } else {
1475 None
1476 }
1477 }
1478 }
1479
    /// Accepts the current prediction for `project`: records its buffer
    /// changes in history (marked as predicted and local), cancels any
    /// pending prediction requests, and notifies the backing model provider.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // can't hold &mut project_state ref across report_changes_for_buffer_call
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Local providers (Ollama / OpenAI-compatible) have no
                // acceptance endpoint to notify.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1526
    /// Worker that batches rejected-prediction reports and submits them to
    /// the `/predict_edits/reject` endpoint.
    ///
    /// Accumulates until half the per-request maximum is reached, more items
    /// stop arriving (debounce timer), or the channel closes. On a failed
    /// submission the batch is kept and retried with the next flush; each
    /// request sends at most the most recent
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        // More items already queued: keep accumulating.
                        if next.is_some() {
                            continue;
                        }
                    }
                    // Debounce expired with nothing new: flush what we have.
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the submitted tail only on success; otherwise keep it for
            // the next retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1590
    /// Worker that waits for queued predictions to "settle" and reports how
    /// much of each accepted prediction the user kept.
    ///
    /// A pending prediction settles once its editable region has been quiet
    /// for `EDIT_PREDICTION_SETTLED_QUIESCENCE`; entries older than
    /// `EDIT_PREDICTION_SETTLED_TTL` are dropped without reporting. Wake-ups
    /// come either from `rx` (new enqueues) or from a timer derived from the
    /// oldest still-pending entry. Exits when the channel closes or the
    /// store entity is dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // No scheduled wake: block until a new prediction is
                // enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Coalesce any additional enqueues already waiting.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();
            let mut oldest_edited_at = None;
            let mut ready_predictions = Vec::new();

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        let mut pending_index = 0;
                        while pending_index < registered_buffer.pending_predictions.len() {
                            let pending_prediction =
                                &registered_buffer.pending_predictions[pending_index];
                            // Expired: drop without reporting.
                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
                            if age >= EDIT_PREDICTION_SETTLED_TTL {
                                registered_buffer.pending_predictions.remove(pending_index);
                                continue;
                            }

                            // Quiet long enough: capture the settled text and
                            // move it to the ready list.
                            let quiet_for =
                                now.saturating_duration_since(pending_prediction.last_edit_at);
                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                let pending_prediction =
                                    registered_buffer.pending_predictions.remove(pending_index);
                                let settled_editable_region = registered_buffer
                                    .snapshot
                                    .text_for_range(
                                        pending_prediction.editable_anchor_range.clone(),
                                    )
                                    .collect::<String>();
                                ready_predictions
                                    .push((pending_prediction, settled_editable_region));
                                continue;
                            }

                            // Still waiting: remember the oldest edit time so
                            // the next wake-up can be scheduled from it.
                            if oldest_edited_at
                                .is_none_or(|time| pending_prediction.last_edit_at < time)
                            {
                                oldest_edited_at = Some(pending_prediction.last_edit_at);
                            }
                            pending_index += 1;
                        }
                    }
                }
            });

            for (pending_prediction, settled_editable_region) in ready_predictions {
                let PendingSettledPrediction {
                    request_id,
                    editable_region_before_prediction,
                    predicted_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    example,
                    e2e_latency,
                    ..
                } = pending_prediction;
                let settled_editable_region_for_metrics = settled_editable_region.clone();
                // Diffing three region texts can be expensive; run it off the
                // main thread.
                let kept_rate_result = cx
                    .background_spawn(async move {
                        compute_kept_rate(
                            &editable_region_before_prediction,
                            &predicted_editable_region,
                            &settled_editable_region_for_metrics,
                        )
                    })
                    .await;

                #[cfg(test)]
                {
                    let request_id = request_id.clone();
                    let settled_editable_region = settled_editable_region.clone();
                    this.update(cx, |this, _| {
                        if let Some(callback) = &this.settled_event_callback {
                            callback(request_id, settled_editable_region);
                        }
                    });
                }

                telemetry::event!(
                    EDIT_PREDICTION_SETTLED_EVENT,
                    request_id = request_id.0.clone(),
                    settled_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
                    edit_bytes_reference_new = kept_rate_result.reference_new_chars,
                    edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
                    edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
                    edit_bytes_kept = kept_rate_result.kept_chars,
                    edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
                    edit_bytes_discarded = kept_rate_result.discarded_chars,
                    edit_bytes_context = kept_rate_result.context_chars,
                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
                    edit_bytes_recall_rate = kept_rate_result.recall_rate,
                    example,
                    e2e_latency = e2e_latency.as_millis(),
                );
            }

            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1717
1718 pub(crate) fn enqueue_settled_prediction(
1719 &mut self,
1720 request_id: EditPredictionId,
1721 project: &Entity<Project>,
1722 edited_buffer: &Entity<Buffer>,
1723 edited_buffer_snapshot: &BufferSnapshot,
1724 editable_offset_range: Range<usize>,
1725 edit_preview: &EditPreview,
1726 example: Option<ExampleSpec>,
1727 e2e_latency: std::time::Duration,
1728 cx: &mut Context<Self>,
1729 ) {
1730 let this = &mut *self;
1731 let project_state = this.get_or_init_project(project, cx);
1732 let Some(registered_buffer) = project_state
1733 .registered_buffers
1734 .get_mut(&edited_buffer.entity_id())
1735 else {
1736 return;
1737 };
1738
1739 let editable_region_before_prediction = edited_buffer_snapshot
1740 .text_for_range(editable_offset_range.clone())
1741 .collect::<String>();
1742 let editable_anchor_range_for_result =
1743 edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1744 let predicted_editable_region = edit_preview
1745 .result_text_snapshot()
1746 .text_for_range(editable_anchor_range_for_result.clone())
1747 .collect();
1748 let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1749 edited_buffer_snapshot
1750 .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1751 );
1752 let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1753 edit_preview.result_syntax_snapshot().layers_for_range(
1754 editable_anchor_range_for_result,
1755 edit_preview.result_text_snapshot(),
1756 true,
1757 ),
1758 );
1759 let editable_anchor_range =
1760 edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1761 let now = cx.background_executor().now();
1762 registered_buffer
1763 .pending_predictions
1764 .push(PendingSettledPrediction {
1765 request_id,
1766 editable_anchor_range,
1767 editable_region_before_prediction,
1768 predicted_editable_region,
1769 ts_error_count_before_prediction,
1770 ts_error_count_after_prediction,
1771 example,
1772 e2e_latency,
1773 enqueued_at: now,
1774 last_edit_at: now,
1775 });
1776 this.settled_predictions_tx.unbounded_send(now).ok();
1777 }
1778
1779 fn reject_current_prediction(
1780 &mut self,
1781 reason: EditPredictionRejectReason,
1782 project: &Entity<Project>,
1783 cx: &App,
1784 ) {
1785 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1786 project_state.pending_predictions.clear();
1787 if let Some(prediction) = project_state.current_prediction.take() {
1788 let model_version = prediction.prediction.model_version.clone();
1789 self.reject_prediction(
1790 prediction.prediction.id,
1791 reason,
1792 prediction.was_shown,
1793 model_version,
1794 Some(prediction.e2e_latency),
1795 cx,
1796 );
1797 }
1798 };
1799 }
1800
1801 fn did_show_current_prediction(
1802 &mut self,
1803 project: &Entity<Project>,
1804 display_type: edit_prediction_types::SuggestionDisplayType,
1805 _cx: &mut Context<Self>,
1806 ) {
1807 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1808 return;
1809 };
1810
1811 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1812 return;
1813 };
1814
1815 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1816 let previous_shown_with = current_prediction.shown_with;
1817
1818 if previous_shown_with.is_none() || !is_jump {
1819 current_prediction.shown_with = Some(display_type);
1820 }
1821
1822 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1823
1824 if is_first_non_jump_show {
1825 current_prediction.was_shown = true;
1826 }
1827
1828 if is_first_non_jump_show {
1829 self.shown_predictions
1830 .push_front(current_prediction.prediction.clone());
1831 if self.shown_predictions.len() > 50 {
1832 let completion = self.shown_predictions.pop_back().unwrap();
1833 self.rated_predictions.remove(&completion.id);
1834 }
1835 }
1836 }
1837
1838 fn reject_prediction(
1839 &mut self,
1840 prediction_id: EditPredictionId,
1841 reason: EditPredictionRejectReason,
1842 was_shown: bool,
1843 model_version: Option<String>,
1844 e2e_latency: Option<std::time::Duration>,
1845 cx: &App,
1846 ) {
1847 match self.edit_prediction_model {
1848 EditPredictionModel::Zeta => {
1849 let is_cloud = !matches!(
1850 all_language_settings(None, cx).edit_predictions.provider,
1851 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1852 );
1853
1854 if is_cloud {
1855 let organization_id = self
1856 .user_store
1857 .read(cx)
1858 .current_organization()
1859 .map(|organization| organization.id.clone());
1860
1861 self.reject_predictions_tx
1862 .unbounded_send(EditPredictionRejectionPayload {
1863 rejection: EditPredictionRejection {
1864 request_id: prediction_id.to_string(),
1865 reason,
1866 was_shown,
1867 model_version,
1868 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1869 },
1870 organization_id,
1871 })
1872 .log_err();
1873 }
1874 }
1875 EditPredictionModel::Mercury => {
1876 mercury::edit_prediction_rejected(
1877 prediction_id,
1878 was_shown,
1879 reason,
1880 self.client.http_client(),
1881 cx,
1882 );
1883 }
1884 EditPredictionModel::Fim { .. } => {}
1885 }
1886 }
1887
1888 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1889 self.projects
1890 .get(&project.entity_id())
1891 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1892 }
1893
    /// Queues a throttled prediction refresh anchored at `position` in
    /// `buffer`, tagging the result as requested by that buffer.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    // Store was dropped before the refresh ran.
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1935
    /// Queues a prediction refresh triggered by a diagnostics update,
    /// requesting an edit at the next relevant diagnostic location.
    ///
    /// Skipped entirely for providers not backed by this store, while
    /// following a collaborator, or when a current prediction already exists
    /// (buffer-triggered predictions are preferred over diagnostic jumps).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor, bailing out when
                // predictions are disabled at that location.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope searches a window of lines around the
                    // cursor; Global searches from the default (empty) range.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    this.update(cx, |this, cx| {
                        Some((
                            // A buffer-triggered prediction may have arrived
                            // meanwhile; if so, reject this one in its favor.
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    model_version: prediction_result.model_version,
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2053
2054 fn predictions_enabled_at(
2055 snapshot: &BufferSnapshot,
2056 position: Option<language::Anchor>,
2057 cx: &App,
2058 ) -> bool {
2059 let file = snapshot.file();
2060 let all_settings = all_language_settings(file, cx);
2061 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2062 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2063 {
2064 return false;
2065 }
2066
2067 if let Some(last_position) = position {
2068 let settings = snapshot.settings_at(last_position, cx);
2069
2070 if !settings.edit_predictions_disabled_in.is_empty()
2071 && let Some(scope) = snapshot.language_scope_at(last_position)
2072 && let Some(scope_name) = scope.override_name()
2073 && settings
2074 .edit_predictions_disabled_in
2075 .iter()
2076 .any(|s| s == scope_name)
2077 {
2078 return false;
2079 }
2080 }
2081
2082 true
2083 }
2084
    /// Throttle window applied when queueing prediction refreshes
    /// (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2086}
2087
2088fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2089 let Some(app_state) = AppState::try_global(cx) else {
2090 return false;
2091 };
2092
2093 app_state
2094 .workspace_store
2095 .read(cx)
2096 .workspaces()
2097 .filter_map(|workspace| workspace.upgrade())
2098 .any(|workspace| {
2099 workspace.read(cx).project().entity_id() == project.entity_id()
2100 && workspace
2101 .read(cx)
2102 .leader_for_pane(workspace.read(cx).active_pane())
2103 .is_some()
2104 })
2105}
2106
2107fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2108 match provider {
2109 EditPredictionProvider::Zed
2110 | EditPredictionProvider::Mercury
2111 | EditPredictionProvider::Ollama
2112 | EditPredictionProvider::OpenAiCompatibleApi
2113 | EditPredictionProvider::Experimental(_) => true,
2114 EditPredictionProvider::None
2115 | EditPredictionProvider::Copilot
2116 | EditPredictionProvider::Codestral => false,
2117 }
2118}
2119
2120impl EditPredictionStore {
2121 fn queue_prediction_refresh(
2122 &mut self,
2123 project: Entity<Project>,
2124 request_trigger: PredictEditsRequestTrigger,
2125 throttle_entity: EntityId,
2126 cx: &mut Context<Self>,
2127 do_refresh: impl FnOnce(
2128 WeakEntity<Self>,
2129 &mut AsyncApp,
2130 )
2131 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2132 + 'static,
2133 ) {
2134 fn select_throttle(
2135 project_state: &mut ProjectState,
2136 request_trigger: PredictEditsRequestTrigger,
2137 ) -> &mut Option<(EntityId, Instant)> {
2138 match request_trigger {
2139 PredictEditsRequestTrigger::Diagnostics => {
2140 &mut project_state.last_jump_prediction_refresh
2141 }
2142 _ => &mut project_state.last_edit_prediction_refresh,
2143 }
2144 }
2145
2146 let (needs_acceptance_tracking, max_pending_predictions) =
2147 match all_language_settings(None, cx).edit_predictions.provider {
2148 EditPredictionProvider::Zed
2149 | EditPredictionProvider::Mercury
2150 | EditPredictionProvider::Experimental(_) => (true, 2),
2151 EditPredictionProvider::Ollama => (false, 1),
2152 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2153 EditPredictionProvider::None
2154 | EditPredictionProvider::Copilot
2155 | EditPredictionProvider::Codestral => {
2156 log::error!("queue_prediction_refresh called with non-store provider");
2157 return;
2158 }
2159 };
2160
2161 let drop_on_cancel = !needs_acceptance_tracking;
2162 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2163 let project_state = self.get_or_init_project(&project, cx);
2164 let pending_prediction_id = project_state.next_pending_prediction_id;
2165 project_state.next_pending_prediction_id += 1;
2166 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2167
2168 let task = cx.spawn(async move |this, cx| {
2169 let throttle_wait = this
2170 .update(cx, |this, cx| {
2171 let project_state = this.get_or_init_project(&project, cx);
2172 let throttle = *select_throttle(project_state, request_trigger);
2173
2174 let now = cx.background_executor().now();
2175 throttle.and_then(|(last_entity, last_timestamp)| {
2176 if throttle_entity != last_entity {
2177 return None;
2178 }
2179 (last_timestamp + throttle_timeout).checked_duration_since(now)
2180 })
2181 })
2182 .ok()
2183 .flatten();
2184
2185 if let Some(timeout) = throttle_wait {
2186 cx.background_executor().timer(timeout).await;
2187 }
2188
2189 // If this task was cancelled before the throttle timeout expired,
2190 // do not perform a request. Also skip if another task already
2191 // proceeded since we were enqueued (duplicate).
2192 let mut is_cancelled = true;
2193 this.update(cx, |this, cx| {
2194 let project_state = this.get_or_init_project(&project, cx);
2195 let was_cancelled = project_state
2196 .cancelled_predictions
2197 .remove(&pending_prediction_id);
2198 if was_cancelled {
2199 return;
2200 }
2201
2202 // Another request has been already sent since this was enqueued
2203 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2204 return;
2205 }
2206
2207 let new_refresh = (throttle_entity, cx.background_executor().now());
2208 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2209 is_cancelled = false;
2210 })
2211 .ok();
2212 if is_cancelled {
2213 return None;
2214 }
2215
2216 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2217 let new_prediction_metadata = new_prediction_result
2218 .as_ref()
2219 .map(|(prediction, _)| (prediction.id.clone(), prediction.model_version.clone()));
2220
2221 // When a prediction completes, remove it from the pending list, and cancel
2222 // any pending predictions that were enqueued before it.
2223 this.update(cx, |this, cx| {
2224 let project_state = this.get_or_init_project(&project, cx);
2225
2226 let is_cancelled = project_state
2227 .cancelled_predictions
2228 .remove(&pending_prediction_id);
2229
2230 let new_current_prediction = if !is_cancelled
2231 && let Some((prediction_result, requested_by)) = new_prediction_result
2232 {
2233 match prediction_result.prediction {
2234 Ok(prediction) => {
2235 let new_prediction = CurrentEditPrediction {
2236 requested_by,
2237 prediction,
2238 was_shown: false,
2239 shown_with: None,
2240 e2e_latency: prediction_result.e2e_latency,
2241 };
2242
2243 if let Some(current_prediction) =
2244 project_state.current_prediction.as_ref()
2245 {
2246 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2247 {
2248 this.reject_current_prediction(
2249 EditPredictionRejectReason::Replaced,
2250 &project,
2251 cx,
2252 );
2253
2254 Some(new_prediction)
2255 } else {
2256 this.reject_prediction(
2257 new_prediction.prediction.id,
2258 EditPredictionRejectReason::CurrentPreferred,
2259 false,
2260 new_prediction.prediction.model_version,
2261 Some(new_prediction.e2e_latency),
2262 cx,
2263 );
2264 None
2265 }
2266 } else {
2267 Some(new_prediction)
2268 }
2269 }
2270 Err(reject_reason) => {
2271 this.reject_prediction(
2272 prediction_result.id,
2273 reject_reason,
2274 false,
2275 prediction_result.model_version,
2276 Some(prediction_result.e2e_latency),
2277 cx,
2278 );
2279 None
2280 }
2281 }
2282 } else {
2283 None
2284 };
2285
2286 let project_state = this.get_or_init_project(&project, cx);
2287
2288 if let Some(new_prediction) = new_current_prediction {
2289 project_state.current_prediction = Some(new_prediction);
2290 }
2291
2292 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2293 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2294 if pending_prediction.id == pending_prediction_id {
2295 pending_predictions.remove(ix);
2296 for pending_prediction in pending_predictions.drain(0..ix) {
2297 project_state.cancel_pending_prediction(pending_prediction, cx)
2298 }
2299 break;
2300 }
2301 }
2302 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2303 cx.notify();
2304 })
2305 .ok();
2306
2307 new_prediction_metadata
2308 });
2309
2310 if project_state.pending_predictions.len() < max_pending_predictions {
2311 project_state
2312 .pending_predictions
2313 .push(PendingPrediction {
2314 id: pending_prediction_id,
2315 task,
2316 drop_on_cancel,
2317 })
2318 .unwrap();
2319 } else {
2320 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2321 project_state
2322 .pending_predictions
2323 .push(PendingPrediction {
2324 id: pending_prediction_id,
2325 task,
2326 drop_on_cancel,
2327 })
2328 .unwrap();
2329 project_state.cancel_pending_prediction(pending_prediction, cx);
2330 }
2331 }
2332
2333 pub fn request_prediction(
2334 &mut self,
2335 project: &Entity<Project>,
2336 active_buffer: &Entity<Buffer>,
2337 position: language::Anchor,
2338 trigger: PredictEditsRequestTrigger,
2339 cx: &mut Context<Self>,
2340 ) -> Task<Result<Option<EditPredictionResult>>> {
2341 self.request_prediction_internal(
2342 project.clone(),
2343 active_buffer.clone(),
2344 position,
2345 trigger,
2346 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2347 cx,
2348 )
2349 }
2350
    /// Gathers model inputs (events, related files, diagnostics window) for
    /// a prediction at `position` in `active_buffer` and dispatches to the
    /// configured model backend.
    ///
    /// When the model returns no prediction and `allow_jump` is set, falls
    /// back to requesting a diagnostics-based ("jump") prediction.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.iter().map(|e| e.event.clone()).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        // Diagnostics are searched within a fixed window of rows around the
        // cursor.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);
        let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
            EditPredictionsMode::Eager => PredictEditsMode::Eager,
            EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
        };

        // Open-source status requires the active file, every recorded event,
        // and every related file to all be from open-source repositories.
        let is_open_source = snapshot
            .file()
            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
            && events.iter().all(|event| event.in_open_source_repo())
            && related_files.iter().all(|file| file.in_open_source_repo);

        // Data collection is only possible outside tests, for open-source
        // content, with the setting enabled, and only for the Zeta model.
        let can_collect_data = !cfg!(test)
            && is_open_source
            && self.is_data_collection_enabled(cx)
            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer,
            snapshot,
            position,
            events,
            related_files,
            mode,
            trigger,
            diagnostic_search_range,
            debug_tx,
            can_collect_data,
            is_open_source,
        };

        // Sample roughly 0.1% of eligible requests for data capture.
        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);

        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta => {
                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
            }
            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
            EditPredictionModel::Mercury => {
                self.mercury
                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
            }
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Only fall back to a diagnostics-based prediction when the
            // model had nothing to suggest for the buffer.
            if prediction.is_none()
                && allow_jump
                && has_events
                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
            {
                this.update(cx, |this, cx| {
                    this.refresh_prediction_from_diagnostics(
                        project,
                        DiagnosticSearchScope::Local,
                        cx,
                    );
                })?;
                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
2442
    /// Finds the nearest diagnostic to jump to for a diagnostics-triggered
    /// prediction.
    ///
    /// Prefers diagnostics in the active buffer that fall outside the range
    /// already searched, skipping ones that are near a collaborator's cursor
    /// but not near the local cursor. When the active buffer yields nothing,
    /// scans other project files with diagnostics, preferring paths that
    /// share the longest prefix with the active buffer's path and skipping
    /// buffers that have collaborator selections.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows where collaborators currently have cursors in the active buffer.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(
                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
                false,
            )
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Already covered by the range searched for the original
                // prediction; nothing new to jump to there.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Avoid jumping into a collaborator's work in progress.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Closest diagnostic (by row distance) to the local cursor wins.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Rank candidate files by how many leading path components they
            // share with the active buffer's path (more shared = closer).
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(
                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
                                false,
                            )
                            .next()
                            .is_some();
                        // Pick the most severe diagnostic (lower severity
                        // value = higher severity).
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2558
2559 async fn send_raw_llm_request(
2560 request: RawCompletionRequest,
2561 client: Arc<Client>,
2562 custom_url: Option<Arc<Url>>,
2563 llm_token: LlmApiToken,
2564 organization_id: Option<OrganizationId>,
2565 app_version: Version,
2566 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2567 let url = if let Some(custom_url) = custom_url {
2568 custom_url.as_ref().clone()
2569 } else {
2570 client
2571 .http_client()
2572 .build_zed_llm_url("/predict_edits/raw", &[])?
2573 };
2574
2575 Self::send_api_request(
2576 |builder| {
2577 let req = builder
2578 .uri(url.as_ref())
2579 .body(serde_json::to_string(&request)?.into());
2580 Ok(req?)
2581 },
2582 client,
2583 llm_token,
2584 organization_id,
2585 app_version,
2586 true,
2587 )
2588 .await
2589 }
2590
    /// Sends a zstd-compressed v3 prediction request to the Zed LLM
    /// `/predict_edits/v3` endpoint.
    ///
    /// The request builder closure may be invoked more than once (after an
    /// LLM token refresh), hence the clone of the compressed payload per
    /// attempt.
    pub(crate) async fn send_v3_request(
        input: ZetaPromptInput,
        preferred_experiment: Option<String>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        trigger: PredictEditsRequestTrigger,
        mode: PredictEditsMode,
    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/v3", &[])?;

        let request = PredictEditsV3Request { input, trigger };

        // Compress the JSON body with zstd (level 3) before sending.
        let json_bytes = serde_json::to_vec(&request)?;
        let compressed = zstd::encode_all(&json_bytes[..], 3)?;

        Self::send_api_request(
            |builder| {
                let builder = builder
                    .uri(url.as_ref())
                    .header("Content-Encoding", "zstd")
                    .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref());
                // The experiment header is only attached when one is set.
                let builder = if let Some(preferred_experiment) = preferred_experiment.as_deref() {
                    builder.header(PREFERRED_EXPERIMENT_HEADER_NAME, preferred_experiment)
                } else {
                    builder
                };
                let req = builder.body(compressed.clone().into());
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            true,
        )
        .await
    }
2632
    /// Sends a POST request built by `build` and deserializes the JSON
    /// response body into `Res`.
    ///
    /// When `require_auth` is true, failure to acquire an LLM token aborts;
    /// otherwise the request proceeds without an `Authorization` header.
    /// If the server signals that the token needs refreshing, the token is
    /// refreshed and the request retried exactly once. If the server
    /// advertises a minimum required Zed version newer than `app_version`,
    /// fails with [`ZedUpdateRequiredError`].
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = if require_auth {
            Some(
                client
                    .acquire_llm_token(&llm_token, organization_id.clone())
                    .await?,
            )
        } else {
            // Best effort: proceed unauthenticated if no token is available.
            client
                .acquire_llm_token(&llm_token, organization_id.clone())
                .await
                .ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Enforce the server's minimum supported client version.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Server indicated the token is stale; refresh and retry once.
                did_retry = true;
                token = Some(
                    client
                        .refresh_llm_token(&llm_token, organization_id.clone())
                        .await?,
                );
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2713
2714 pub fn refresh_context(
2715 &mut self,
2716 project: &Entity<Project>,
2717 buffer: &Entity<language::Buffer>,
2718 cursor_position: language::Anchor,
2719 cx: &mut Context<Self>,
2720 ) {
2721 self.get_or_init_project(project, cx)
2722 .context
2723 .update(cx, |store, cx| {
2724 store.refresh(buffer.clone(), cursor_position, cx);
2725 });
2726 }
2727
2728 #[cfg(feature = "cli-support")]
2729 pub fn set_context_for_buffer(
2730 &mut self,
2731 project: &Entity<Project>,
2732 related_files: Vec<RelatedFile>,
2733 cx: &mut Context<Self>,
2734 ) {
2735 self.get_or_init_project(project, cx)
2736 .context
2737 .update(cx, |store, cx| {
2738 store.set_related_files(related_files, cx);
2739 });
2740 }
2741
2742 #[cfg(feature = "cli-support")]
2743 pub fn set_recent_paths_for_project(
2744 &mut self,
2745 project: &Entity<Project>,
2746 paths: impl IntoIterator<Item = project::ProjectPath>,
2747 cx: &mut Context<Self>,
2748 ) {
2749 let project_state = self.get_or_init_project(project, cx);
2750 project_state.recent_paths = paths.into_iter().collect();
2751 }
2752
2753 fn is_file_open_source(
2754 &self,
2755 project: &Entity<Project>,
2756 file: &Arc<dyn File>,
2757 cx: &App,
2758 ) -> bool {
2759 if !file.is_local() || file.is_private() {
2760 return false;
2761 }
2762 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2763 return false;
2764 };
2765 project_state
2766 .license_detection_watchers
2767 .get(&file.worktree_id(cx))
2768 .as_ref()
2769 .is_some_and(|watcher| watcher.is_project_open_source())
2770 }
2771
2772 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2773 if !self.is_data_collection_allowed_by_organization(cx) {
2774 return false;
2775 }
2776
2777 if cx.is_staff() {
2778 return true;
2779 }
2780
2781 match all_language_settings(None, cx)
2782 .edit_predictions
2783 .allow_data_collection
2784 {
2785 EditPredictionDataCollectionChoice::Yes => true,
2786 EditPredictionDataCollectionChoice::No => false,
2787 // Fall back to the legacy KV entry captured when the store was
2788 // created, preserving existing users' choices without per-request
2789 // database reads.
2790 EditPredictionDataCollectionChoice::Default => self.legacy_data_collection_enabled,
2791 }
2792 }
2793
2794 fn load_legacy_data_collection_enabled(cx: &App) -> bool {
2795 KeyValueStore::global(cx)
2796 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2797 .log_err()
2798 .flatten()
2799 .as_deref()
2800 == Some("true")
2801 }
2802
2803 pub(crate) fn is_data_collection_allowed_by_organization(&self, cx: &App) -> bool {
2804 self.user_store
2805 .read(cx)
2806 .current_organization_configuration()
2807 .is_none_or(|organization_configuration| {
2808 organization_configuration
2809 .edit_prediction
2810 .is_feedback_enabled
2811 })
2812 }
2813
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2817
    /// Number of predictions shown to the user so far.
    // NOTE(review): named "completions" for historical reasons; it counts
    // the same `shown_predictions` collection as `shown_predictions()`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2821
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2825
    /// Submits a user rating plus free-form feedback for `prediction` to
    /// the cloud endpoint and records the prediction as rated locally.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        let organization = self.user_store.read(cx).current_organization();

        // Record the rating locally before the network call completes.
        self.rated_predictions.insert(prediction.id.clone());

        cx.background_spawn({
            let client = self.client.clone();
            let prediction_id = prediction.id.to_string();
            // Serialize inputs and render the edits as a unified diff on
            // this thread, so only owned data moves to the background task.
            let inputs = serde_json::to_value(&prediction.inputs);
            let output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
            async move {
                client
                    .cloud_client()
                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
                        organization_id: organization.map(|organization| organization.id.clone()),
                        request_id: prediction_id,
                        rating: match rating {
                            EditPredictionRating::Positive => "positive".to_string(),
                            EditPredictionRating::Negative => "negative".to_string(),
                        },
                        inputs: inputs?,
                        output,
                        feedback,
                    })
                    .await?;

                anyhow::Ok(())
            }
        })
        .detach_and_log_err(cx);

        cx.notify();
    }
2867}
2868
/// Returns whether `edit_range` (an edit made by a collaborator) overlaps
/// the "locality" region around the local user's cursor in `buffer`.
///
/// The locality region is the cursor position expanded syntactically and
/// then linewise, bounded by a token budget.
fn collaborator_edit_overlaps_locality_region(
    project_state: &ProjectState,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    snapshot: &BufferSnapshot,
    edit_range: &Range<Anchor>,
    cx: &App,
) -> bool {
    // Without an active buffer and cursor there is nothing to be local to.
    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
        return false;
    };

    // Only edits in the same buffer as the local cursor can overlap.
    if active_buffer.entity_id() != buffer.entity_id() {
        return false;
    }

    let locality_point_range = expand_context_syntactically_then_linewise(
        snapshot,
        (position..position).to_point(snapshot),
        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
    );
    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);

    edit_range.overlaps(&locality_anchor_range, snapshot)
}
2894
/// Merges the trailing run of mutually-mergeable buffer-change events in
/// `events` into a single event whose diff spans from the oldest merged
/// event's snapshot to `end_snapshot`.
///
/// Bails out unless the most recent event belongs to the same buffer as
/// `latest_snapshot` and the latest snapshot has observed all of that
/// event's edits.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards to count how many trailing events can merge into one.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Nothing to merge with fewer than two mergeable events.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    // Union of all edited ranges across the events being merged.
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // constituent event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2977
2978fn merge_anchor_ranges(
2979 left: &Range<Anchor>,
2980 right: &Range<Anchor>,
2981 snapshot: &TextBufferSnapshot,
2982) -> Range<Anchor> {
2983 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2984 left.start
2985 } else {
2986 right.start
2987 };
2988 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2989 left.end
2990 } else {
2991 right.end
2992 };
2993 start..end
2994}
2995
/// Error returned when the server requires a newer Zed version than the one
/// currently running (see the minimum-version check in `send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
3003
/// Marker type whose `Dismissable` impl tracks whether the edit-prediction
/// upsell has been dismissed.
struct ZedPredictUpsell;
3005
3006fn is_upsell_dismissed(cx: &App) -> bool {
3007 // To make this backwards compatible with older versions of Zed, we
3008 // check if the user has seen the previous Edit Prediction Onboarding
3009 // before, by checking the data collection choice which was written to
3010 // the database once the user clicked on "Accept and Enable"
3011 let kvp = KeyValueStore::global(cx);
3012 if kvp
3013 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3014 .log_err()
3015 .is_some_and(|s| s.is_some())
3016 {
3017 return true;
3018 }
3019
3020 kvp.read_kvp(ZedPredictUpsell::KEY)
3021 .log_err()
3022 .is_some_and(|s| s.is_some())
3023}
3024
impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    // Delegates to the shared helper so the legacy-onboarding KVP check is
    // applied consistently with `should_show_upsell_modal`.
    fn dismissed(cx: &App) -> bool {
        is_upsell_dismissed(cx)
    }
}
3032
/// Whether the edit-prediction upsell modal should be shown (i.e. the user
/// has neither dismissed it nor completed the legacy onboarding).
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !is_upsell_dismissed(cx)
}
3036
/// Registers edit-prediction-related workspace actions: the onboarding
/// modal, onboarding reset, and Copilot sign-in/sign-out/reinstall.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider back to None.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Resolves the Copilot instance for the workspace's project via the
        // global store (starting it through `start_copilot_for_project`).
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}