1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
7 PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
8};
9use cloud_llm_client::{
10 EditPredictionRejectReason, EditPredictionRejection,
11 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
12 PREFERRED_EXPERIMENT_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
13 ZED_VERSION_HEADER_NAME,
14};
15use collections::{HashMap, HashSet};
16use copilot::{Copilot, Reinstall, SignIn, SignOut};
17use credentials_provider::CredentialsProvider;
18use db::kvp::{Dismissable, KeyValueStore};
19use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PresenceFlag, register_feature_flag};
21use futures::{
22 AsyncReadExt as _, FutureExt as _, StreamExt as _,
23 channel::mpsc::{self, UnboundedReceiver},
24 select_biased,
25};
26use gpui::BackgroundExecutor;
27use gpui::http_client::Url;
28use gpui::{
29 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
30 http_client::{self, AsyncBody, Method},
31 prelude::*,
32};
33use heapless::Vec as ArrayVec;
34use language::{
35 Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
36 TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
37};
38use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
39use release_channel::AppVersion;
40use semver::Version;
41use serde::de::DeserializeOwned;
42use settings::{
43 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
44};
45use std::collections::{VecDeque, hash_map};
46use std::env;
47use text::{AnchorRangeExt, Edit};
48use workspace::{AppState, Workspace};
49use zeta_prompt::{ZetaFormat, ZetaPromptInput};
50
51use std::mem;
52use std::ops::Range;
53use std::path::Path;
54use std::rc::Rc;
55use std::str::FromStr as _;
56use std::sync::Arc;
57use std::time::{Duration, Instant};
58
59use thiserror::Error;
60use util::{RangeExt as _, ResultExt as _};
61
62pub mod cursor_excerpt;
63pub mod example_spec;
64pub mod fim;
65mod license_detection;
66pub mod mercury;
67pub mod metrics;
68pub mod ollama;
69mod onboarding_modal;
70pub mod open_ai_response;
71mod prediction;
72
73pub mod udiff;
74
75mod capture_example;
76pub mod open_ai_compatible;
77mod zed_edit_prediction_delegate;
78pub mod zeta;
79
80#[cfg(test)]
81mod edit_prediction_tests;
82
83use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
84use crate::example_spec::ExampleSpec;
85use crate::license_detection::LicenseDetectionWatcher;
86use crate::mercury::Mercury;
87pub use crate::metrics::{KeptRateResult, compute_kept_rate};
88use crate::onboarding_modal::ZedPredictModal;
89pub use crate::prediction::EditPrediction;
90pub use crate::prediction::EditPredictionId;
91use crate::prediction::EditPredictionResult;
92pub use capture_example::capture_example;
93pub use language_model::ApiKeyState;
94pub use telemetry_events::EditPredictionRating;
95pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
96
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
106
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line-distance threshold used when grouping/merging edit events
/// (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
/// Token budget for context around collaborator edits (usage not visible in this chunk).
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
/// Time window for grouping consecutive edits into one change (presumably consumed
/// by the edit-event bookkeeping; usage not visible in this chunk).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce applied to batched rejection reports (presumably in
/// `handle_rejected_predictions`; body not visible in this chunk).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Telemetry event name emitted when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Time-to-live for pending "settled" prediction tracking (assumed from the name;
/// worker body not visible here).
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Quiet period with no further edits before a prediction counts as settled
/// (assumed from the name; worker body not visible here).
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
118
/// Presence feature flag named `edit_prediction_jumps` (presumably gating
/// jump-style predictions; consumers not visible in this chunk).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
    type Value = PresenceFlag;
}
register_feature_flag!(EditPredictionJumpsFeatureFlag);

/// Global wrapper holding the app-wide [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
131
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Specific model id to request, if overridden.
    pub model_id: Option<String>,
    /// Deployment environment name, if overridden.
    pub environment: Option<String>,
    /// Prompt/response format the client should use.
    pub format: ZetaFormat,
}
141
/// App-wide store coordinating edit predictions: tracks per-project edit
/// history and state, and reports shown/rated/rejected/settled predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the experiment-fetching task alive for the store's lifetime.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ProjectState>,
    // Presumably set when the server requires a newer client (cf. the
    // MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — confirm in request handling.
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    /// When set, requests go to the raw Zeta2 endpoint (see [`Zeta2RawConfig`]).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background worker that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the worker that processes settled predictions.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}

/// A rejection queued for reporting, plus the organization it occurred under.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}

/// Which model backend serves edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle models (e.g. Ollama or OpenAI-compatible providers;
    /// see `icons`).
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
175
/// Inputs assembled for a single prediction request against a model backend.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is anchored at.
    position: Anchor,
    /// Recent edit-history events to include in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    mode: PredictEditsMode,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// When present, debug events are emitted on this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}

/// Debug events emitted while retrieving context and running predictions.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval for a project begins.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval for a project completes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer before this event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Version vector of the buffer after this event's edits.
    pub new_snapshot_version: clock::Global,
    /// Range (in post-event coordinates) spanning all of the event's edits.
    pub total_edit_range: Range<Anchor>,
}
236
impl StoredEvent {
    /// Whether this event can be merged with `next_old_event`, the event that
    /// follows it in history.
    ///
    /// Merging is only considered for events from *different* sources (one
    /// predicted, one manual) — same-source events would already have been
    /// coalesced. The events must belong to the same buffer, be contiguous
    /// across snapshots, lie near each other, and both be far from the latest
    /// edit.
    fn can_merge(
        &self,
        next_old_event: &StoredEvent,
        latest_snapshot: &TextBufferSnapshot,
        latest_edit_range: &Range<Anchor>,
    ) -> bool {
        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
            return false;
        }
        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return false;
        }
        if self.new_snapshot_version != next_old_event.old_snapshot.version {
            return false;
        }
        if !latest_snapshot
            .version
            .observed_all(&next_old_event.new_snapshot_version)
        {
            return false;
        }

        let a_is_predicted = matches!(
            self.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );
        let b_is_predicted = matches!(
            next_old_event.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );

        // If events come from the same source (both predicted or both manual) then
        // we would have coalesced them already.
        if a_is_predicted == b_is_predicted {
            return false;
        }

        // Resolve all ranges against the latest snapshot so distances are comparable.
        let left_range = self.total_edit_range.to_point(latest_snapshot);
        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
        let latest_range = latest_edit_range.to_point(latest_snapshot);

        // Events near to the latest edit are not merged if their sources differ.
        if lines_between_ranges(&left_range, &latest_range)
            .min(lines_between_ranges(&right_range, &latest_range))
            <= CHANGE_GROUPING_LINE_SPAN
        {
            return false;
        }

        // Events that are distant from each other are not merged.
        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
            return false;
        }

        true
    }
}
302
303fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
304 if left.start > right.end {
305 return left.start.row - right.end.row;
306 }
307 if right.start > left.end {
308 return right.start.row - left.end.row;
309 }
310 0
311}
312
/// Per-project state tracked by the store.
struct ProjectState {
    /// Finalized edit-history events.
    events: VecDeque<StoredEvent>,
    /// The still-open event accumulating the most recent edits.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most 2 at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of predictions cancelled while still in flight.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, used to detect open-source projects.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
330
impl ProjectState {
    /// Returns the full event history: the finalized events followed by the
    /// in-progress `last_event`, split at its last editing pause and finalized
    /// on the fly.
    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    /// Cancels a pending prediction: either drops its task immediately
    /// (aborting the request) or awaits it so the resulting prediction can be
    /// reported as rejected, per `drop_on_cancel`.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        if pending_prediction.drop_on_cancel {
            drop(pending_prediction.task);
        } else {
            cx.spawn(async move |this, cx| {
                let Some((prediction_id, model_version)) = pending_prediction.task.await else {
                    return;
                };

                this.update(cx, |this, cx| {
                    this.reject_prediction(
                        prediction_id,
                        EditPredictionRejectReason::Canceled,
                        false,
                        model_version,
                        None,
                        cx,
                    );
                })
                .ok();
            })
            .detach()
        }
    }

    /// Returns the buffer for the project's active entry along with the last
    /// cursor position recorded for it, if that buffer is registered here.
    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}
388
/// The prediction currently offered to the user, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been rendered to the user.
    pub was_shown: bool,
    /// How the prediction was displayed, when it was shown.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    // End-to-end latency from request to prediction — assumed from the name;
    // producer is not visible in this chunk.
    pub e2e_latency: std::time::Duration,
}
397
impl CurrentEditPrediction {
    /// Whether this prediction should replace `old_prediction` as the one
    /// currently displayed.
    ///
    /// A new prediction always wins when it targets a different buffer or the
    /// old one no longer interpolates against its buffer. When both are single
    /// edits over the same range in the actively edited buffer, the new one
    /// only wins if its text extends the old text.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // contents, it cannot replace anything.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
436
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than a buffer edit.
    DiagnosticsUpdate,
    /// Triggered from the buffer with the given entity id.
    Buffer(EntityId),
}
442
443impl PredictionRequestedBy {
444 pub fn buffer_id(&self) -> Option<EntityId> {
445 match self {
446 PredictionRequestedBy::DiagnosticsUpdate => None,
447 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
448 }
449 }
450}
451
// Line radius for diagnostic searches — assumed from the name; usage is not
// visible in this chunk.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search (variant meanings assumed from names:
/// near the cursor vs. project-wide — confirm at the call sites).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
459
/// A prediction request that is currently in flight.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    /// Resolves to the prediction id and model version once the request completes.
    task: Task<Option<(EditPredictionId, Option<String>)>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies to this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another location — presumably shown as a jump;
    /// confirm at the display sites.
    Jump { prediction: &'a EditPrediction },
}
475
// Test-only convenience: lets tests reach the underlying `EditPrediction`
// regardless of which variant wraps it.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
487
/// A shown prediction awaiting "settled" evaluation after edits quiesce.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Anchors delimiting the editable region the prediction applies to.
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    // Error counts before/after the prediction — presumably TypeScript
    // diagnostics given the `ts_` prefix; producers not visible in this chunk.
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}

/// Per-buffer state tracked while a buffer is registered with the store.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position in this buffer, if any.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-open edit event; accumulates edits until it is
/// finalized into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer snapshot before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer snapshot after the most recent edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range of the most recent edit.
    latest_edit_range: Range<Anchor>,
    /// Range spanning every edit in this event.
    total_edit_range: Range<Anchor>,
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether these edits came from applying a prediction.
    predicted: bool,
    /// Snapshot taken at the last editing pause; used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
523
impl LastEvent {
    /// Converts this in-progress event into a `StoredEvent` by diffing the old
    /// and new snapshots around the edited range.
    ///
    /// Returns `None` when the diff region exceeds the history size limit, the
    /// range cannot be mapped to the old snapshot, or nothing changed (same
    /// path and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open source only when both the old and new files belong to worktrees
        // whose license watcher says the project is open source.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
            &self.old_snapshot,
            &self.new_snapshot,
            &self.total_edit_range,
        )?;

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                old_snapshot: self.old_snapshot.clone(),
                new_snapshot_version: self.new_snapshot.version.clone(),
                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
            })
        }
    }

    /// Splits the event at the last editing pause, returning the portion
    /// before the pause and — when a pause boundary exists and edits occurred
    /// after it — the portion after.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        // Without a recorded pause boundary there is nothing to split.
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let total_edit_range_before_pause = self
            .total_edit_range_at_last_pause_boundary
            .clone()
            .unwrap_or_else(|| self.total_edit_range.clone());

        // No edits since the boundary means the whole event predates the pause.
        let Some(total_edit_range_after_pause) =
            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
        else {
            return (self.clone(), None);
        };

        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_before_pause,
            total_edit_range: total_edit_range_before_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_after_pause,
            total_edit_range: total_edit_range_after_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
617
618fn compute_total_edit_range_between_snapshots(
619 old_snapshot: &TextBufferSnapshot,
620 new_snapshot: &TextBufferSnapshot,
621) -> Option<Range<Anchor>> {
622 let edits: Vec<Edit<usize>> = new_snapshot
623 .edits_since::<usize>(&old_snapshot.version)
624 .collect();
625
626 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
627 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
628 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
629
630 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
631}
632
/// Maps `total_edit_range` (anchors resolved against `new_snapshot`) back to
/// the corresponding point range in `old_snapshot`.
///
/// Walks the edits between the snapshots while accumulating `delta`, the
/// running length difference between new and old text preceding the current
/// edit. Offsets outside any edit are shifted back by `delta`; offsets inside
/// an edit clamp to that edit's old-side boundary.
///
/// Returns `None` if undoing the shift would underflow an offset.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    let mut delta: isize = 0;

    for edit in &edits {
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            // Before this edit: undo the accumulated shift.
            // Inside this edit: snap to the edit's old start.
            old_start_offset = Some(if new_start_offset < edit.new.start {
                new_start_offset.checked_add_signed(-delta)?
            } else {
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Offsets past the last edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
678
/// Produces a unified diff between the two snapshots, restricted to
/// `total_edit_range` plus `CONTEXT_LINES` of surrounding context.
///
/// Returns the diff text together with the edited range in new-snapshot
/// points, or `None` when either side of the region exceeds
/// `EDIT_HISTORY_DIFF_SIZE_LIMIT` or the range cannot be mapped back to the
/// old snapshot.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Expand the region by CONTEXT_LINES on each side, clamped to the buffers.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // `+ 1` so the end row's full line (through its newline) is included.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Cap the diffed region so huge edits don't blow the history budget.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
726
727fn buffer_path_with_id_fallback(
728 file: Option<&Arc<dyn File>>,
729 snapshot: &TextBufferSnapshot,
730 cx: &App,
731) -> Arc<Path> {
732 if let Some(file) = file {
733 file.full_path(cx).into()
734 } else {
735 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
736 }
737}
738
739impl EditPredictionStore {
    /// Returns the global store, if one has already been created.
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
    }

    /// Returns the global store, creating and registering it on first use.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<EditPredictionStoreGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
                ep_store
            })
    }
758
759 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
760 let data_collection_choice = Self::load_data_collection_choice(cx);
761
762 let llm_token = global_llm_token(cx);
763
764 let (reject_tx, reject_rx) = mpsc::unbounded();
765 cx.background_spawn({
766 let client = client.clone();
767 let llm_token = llm_token.clone();
768 let app_version = AppVersion::global(cx);
769 let background_executor = cx.background_executor().clone();
770 async move {
771 Self::handle_rejected_predictions(
772 reject_rx,
773 client,
774 llm_token,
775 app_version,
776 background_executor,
777 )
778 .await
779 }
780 })
781 .detach();
782
783 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
784 cx.spawn(async move |this, cx| {
785 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
786 })
787 .detach();
788
789 let mut current_user = user_store.read(cx).watch_current_user();
790 let fetch_experiments_task = cx.spawn(async move |this, cx| {
791 while current_user.borrow().is_none() {
792 current_user.next().await;
793 }
794
795 this.update(cx, |this, cx| {
796 if cx.is_staff() {
797 this.refresh_available_experiments(cx);
798 }
799 })
800 .log_err();
801 });
802
803 let credentials_provider = zed_credentials_provider::global(cx);
804
805 let this = Self {
806 projects: HashMap::default(),
807 client,
808 user_store,
809 llm_token,
810 _fetch_experiments_task: fetch_experiments_task,
811 update_required: false,
812 edit_prediction_model: EditPredictionModel::Zeta,
813 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
814 preferred_experiment: None,
815 available_experiments: Vec::new(),
816 mercury: Mercury::new(cx),
817
818 data_collection_choice,
819 reject_predictions_tx: reject_tx,
820 settled_predictions_tx,
821 rated_predictions: Default::default(),
822 shown_predictions: Default::default(),
823 #[cfg(test)]
824 settled_event_callback: None,
825
826 credentials_provider,
827 };
828
829 this
830 }
831
832 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
833 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
834 let format = ZetaFormat::parse(&version_str).ok()?;
835 let model_id = env::var("ZED_ZETA_MODEL").ok();
836 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
837 Some(Zeta2RawConfig {
838 model_id,
839 environment,
840 format,
841 })
842 }
843
    /// Sets the model backend used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Overrides the raw Zeta2 endpoint configuration.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// The raw Zeta2 endpoint configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// The explicitly preferred experiment, if one is set.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    /// Sets (or clears) the preferred experiment.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Experiments reported as available by the server.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }

    /// The experiment currently in effect: the explicit preference, or else the
    /// experiment inferred from the first shown prediction whose model version
    /// carries a `zeta2:` prefix.
    pub fn active_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref().or_else(|| {
            self.shown_predictions
                .iter()
                .find_map(|p| p.model_version.as_ref())
                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
        })
    }
876
    /// Fetches the list of edit prediction experiments from the LLM service
    /// and stores it in `available_experiments`, notifying observers.
    ///
    /// Errors are logged rather than surfaced to the caller.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    // A valid LLM token is required to authorize the request.
                    let token = client
                        .acquire_llm_token(&llm_token, organization_id.clone())
                        .await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        // Include the response body in the error to aid debugging.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
926
    /// Returns the icon set matching the active model backend (and, for FIM,
    /// the configured provider).
    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
        use ui::IconName;
        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
            }
            EditPredictionModel::Zeta => {
                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
                    .with_disabled(IconName::ZedPredictDisabled)
                    .with_up(IconName::ZedPredictUp)
                    .with_down(IconName::ZedPredictDown)
                    .with_error(IconName::ZedPredictError)
            }
            EditPredictionModel::Fim { .. } => {
                // FIM icon depends on which provider is configured in settings.
                let settings = &all_language_settings(None, cx).edit_predictions;
                match settings.provider {
                    EditPredictionProvider::Ollama => {
                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
                    }
                    _ => {
                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
                    }
                }
            }
        }
    }
953
    /// Whether a Mercury API token is configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Whether the Mercury backend reports a payment-required error.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }

    /// Clears the edit history of every registered project.
    pub fn clear_history(&mut self) {
        for project_state in self.projects.values_mut() {
            project_state.events.clear();
            project_state.last_event.take();
        }
    }

    /// Clears the edit history of a single project, if it is registered.
    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.events.clear();
            project_state.last_event.take();
        }
    }

    /// Returns the project's edit history (finalized plus in-progress events),
    /// or an empty `Vec` for unregistered projects.
    pub fn edit_history_for_project(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Vec<StoredEvent> {
        self.projects
            .get(&project.entity_id())
            .map(|project_state| project_state.events(cx))
            .unwrap_or_default()
    }
986
987 pub fn context_for_project<'a>(
988 &'a self,
989 project: &Entity<Project>,
990 cx: &'a mut App,
991 ) -> Vec<RelatedFile> {
992 self.projects
993 .get(&project.entity_id())
994 .map(|project_state| {
995 project_state.context.update(cx, |context, cx| {
996 context
997 .related_files_with_buffers(cx)
998 .map(|(mut related_file, buffer)| {
999 related_file.in_open_source_repo = buffer
1000 .read(cx)
1001 .file()
1002 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1003 related_file
1004 })
1005 .collect()
1006 })
1007 })
1008 .unwrap_or_default()
1009 }
1010
1011 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1012 self.projects
1013 .get(&project.entity_id())
1014 .and_then(|project| project.copilot.clone())
1015 }
1016
1017 pub fn start_copilot_for_project(
1018 &mut self,
1019 project: &Entity<Project>,
1020 cx: &mut Context<Self>,
1021 ) -> Option<Entity<Copilot>> {
1022 if DisableAiSettings::get(None, cx).disable_ai {
1023 return None;
1024 }
1025 let state = self.get_or_init_project(project, cx);
1026
1027 if state.copilot.is_some() {
1028 return state.copilot.clone();
1029 }
1030 let _project = project.clone();
1031 let project = project.read(cx);
1032
1033 let node = project.node_runtime().cloned();
1034 if let Some(node) = node {
1035 let next_id = project.languages().next_language_server_id();
1036 let fs = project.fs().clone();
1037
1038 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1039 state.copilot = Some(copilot.clone());
1040 Some(copilot)
1041 } else {
1042 None
1043 }
1044 }
1045
1046 pub fn context_for_project_with_buffers<'a>(
1047 &'a self,
1048 project: &Entity<Project>,
1049 cx: &'a mut App,
1050 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1051 self.projects
1052 .get(&project.entity_id())
1053 .map(|project| {
1054 project.context.update(cx, |context, cx| {
1055 context.related_files_with_buffers(cx).collect()
1056 })
1057 })
1058 .unwrap_or_default()
1059 }
1060
1061 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1062 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1063 self.user_store.read(cx).edit_prediction_usage()
1064 } else {
1065 None
1066 }
1067 }
1068
    /// Ensures per-project state exists for `project` so its events are tracked.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1072
    /// Registers `buffer` under `project`, creating the project state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1082
    /// Returns the mutable per-project state, lazily creating it on first use.
    ///
    /// Creation wires up everything the store needs for a project:
    /// - a `RelatedExcerptStore` whose events are forwarded to
    ///   `handle_excerpt_store_event`, keyed by the project's entity id,
    /// - a subscription to project events, and
    /// - a release observer that removes this state when the project is dropped.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop the per-project state as soon as the project entity
                    // is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1122
    /// Drops all tracked state for `project` (history, buffers, subscriptions).
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
1126
1127 fn handle_excerpt_store_event(
1128 &mut self,
1129 project_entity_id: EntityId,
1130 event: &RelatedExcerptStoreEvent,
1131 ) {
1132 if let Some(project_state) = self.projects.get(&project_entity_id) {
1133 if let Some(debug_tx) = project_state.debug_tx.clone() {
1134 match event {
1135 RelatedExcerptStoreEvent::StartedRefresh => {
1136 debug_tx
1137 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1138 ContextRetrievalStartedDebugEvent {
1139 project_entity_id: project_entity_id,
1140 timestamp: Instant::now(),
1141 search_prompt: String::new(),
1142 },
1143 ))
1144 .ok();
1145 }
1146 RelatedExcerptStoreEvent::FinishedRefresh {
1147 cache_hit_count,
1148 cache_miss_count,
1149 mean_definition_latency,
1150 max_definition_latency,
1151 } => {
1152 debug_tx
1153 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1154 ContextRetrievalFinishedDebugEvent {
1155 project_entity_id: project_entity_id,
1156 timestamp: Instant::now(),
1157 metadata: vec![
1158 (
1159 "Cache Hits",
1160 format!(
1161 "{}/{}",
1162 cache_hit_count,
1163 cache_hit_count + cache_miss_count
1164 )
1165 .into(),
1166 ),
1167 (
1168 "Max LSP Time",
1169 format!("{} ms", max_definition_latency.as_millis())
1170 .into(),
1171 ),
1172 (
1173 "Mean LSP Time",
1174 format!("{} ms", mean_definition_latency.as_millis())
1175 .into(),
1176 ),
1177 ],
1178 },
1179 ))
1180 .ok();
1181 }
1182 }
1183 }
1184 }
1185 }
1186
1187 pub fn debug_info(
1188 &mut self,
1189 project: &Entity<Project>,
1190 cx: &mut Context<Self>,
1191 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1192 let project_state = self.get_or_init_project(project, cx);
1193 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1194 project_state.debug_tx = Some(debug_watch_tx);
1195 debug_watch_rx
1196 }
1197
1198 fn handle_project_event(
1199 &mut self,
1200 project: Entity<Project>,
1201 event: &project::Event,
1202 cx: &mut Context<Self>,
1203 ) {
1204 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1205 return;
1206 }
1207 // TODO [zeta2] init with recent paths
1208 match event {
1209 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1210 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1211 return;
1212 };
1213 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1214 if let Some(path) = path {
1215 if let Some(ix) = project_state
1216 .recent_paths
1217 .iter()
1218 .position(|probe| probe == &path)
1219 {
1220 project_state.recent_paths.remove(ix);
1221 }
1222 project_state.recent_paths.push_front(path);
1223 }
1224 }
1225 project::Event::DiagnosticsUpdated { .. } => {
1226 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1227 self.refresh_prediction_from_diagnostics(
1228 project,
1229 DiagnosticSearchScope::Global,
1230 cx,
1231 );
1232 }
1233 }
1234 _ => (),
1235 }
1236 }
1237
    /// Ensures `buffer` is tracked in `project_state`, returning its registration.
    ///
    /// Also lazily installs a `LicenseDetectionWatcher` for the buffer's
    /// worktree (removed again when the worktree is released). A new buffer
    /// registration subscribes to edit events and cleans itself up when the
    /// buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree, created on first buffer from it.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Feed buffer edits into the edit-history recorder.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1306
    /// Ingests an edit to `buffer`, updating the registered snapshot and the
    /// project's edit-history event stream.
    ///
    /// `is_predicted` marks edits produced by accepting a prediction;
    /// `is_local` distinguishes the user's own edits from collaborator edits.
    /// Consecutive nearby edits in the same buffer are coalesced into a single
    /// history event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Unchanged version means there is nothing to report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse all edits since the previous snapshot into one covering range.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits that touch a pending settled-prediction's editable region
        // reset that prediction's quiescence timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // A non-recordable edit still finalizes any in-progress event into the
        // bounded event queue.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce with the in-progress event when this edit continues it:
            // same buffer/version chain, same prediction provenance, and close
            // enough in lines.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a long enough typing pause, remember the pre-pause
                // state so the event can later be split at that boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event into the bounded queue...
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // ...and start a new in-progress event for this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1439
1440 fn prediction_at(
1441 &mut self,
1442 buffer: &Entity<Buffer>,
1443 position: Option<language::Anchor>,
1444 project: &Entity<Project>,
1445 cx: &App,
1446 ) -> Option<BufferEditPrediction<'_>> {
1447 let project_state = self.projects.get_mut(&project.entity_id())?;
1448 if let Some(position) = position
1449 && let Some(buffer) = project_state
1450 .registered_buffers
1451 .get_mut(&buffer.entity_id())
1452 {
1453 buffer.last_position = Some(position);
1454 }
1455
1456 let CurrentEditPrediction {
1457 requested_by,
1458 prediction,
1459 ..
1460 } = project_state.current_prediction.as_ref()?;
1461
1462 if prediction.targets_buffer(buffer.read(cx)) {
1463 Some(BufferEditPrediction::Local { prediction })
1464 } else {
1465 let show_jump = match requested_by {
1466 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1467 requested_by_buffer_id == &buffer.entity_id()
1468 }
1469 PredictionRequestedBy::DiagnosticsUpdate => true,
1470 };
1471
1472 if show_jump {
1473 Some(BufferEditPrediction::Jump { prediction })
1474 } else {
1475 None
1476 }
1477 }
1478 }
1479
    /// Marks the project's current prediction as accepted.
    ///
    /// Records the resulting buffer change as predicted edit history, cancels
    /// in-flight prediction requests, and reports the acceptance to the active
    /// model's backend where one exists.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        // Record the accepted edit as both predicted and local.
        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // can't hold &mut project_state ref across report_changes_for_buffer_call
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Local providers (Ollama / OpenAI-compatible) have no
                // acceptance endpoint to notify.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1526
    /// Background worker that batches rejection reports and posts them to the
    /// LLM service's `/predict_edits/reject` endpoint.
    ///
    /// Flushes once the batch reaches half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, or after
    /// `REJECT_REQUEST_DEBOUNCE` elapses with no further input. On a failed
    /// send the items stay in `batched` and are retried with the next flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Wait for either more input or the debounce timer before
                // flushing a small batch.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the newest `flush_count` items, respecting the
            // per-request cap.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop sent items only on success; on failure keep them for retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1590
    /// Background worker that emits "settled" telemetry for accepted predictions.
    ///
    /// A pending prediction is considered settled once its editable region has
    /// been quiet for `EDIT_PREDICTION_SETTLED_QUIESCENCE`; entries older than
    /// `EDIT_PREDICTION_SETTLED_TTL` are dropped without reporting. `rx`
    /// receives enqueue timestamps used to schedule wakeups.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // Sleep until the scheduled wakeup, then fall through to scan.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: wait for the next enqueue, schedule a scan after the
                // quiescence window, and drain any queued notifications.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();
            let mut oldest_edited_at = None;
            let mut ready_predictions = Vec::new();

            // Scan all pending predictions: expire stale ones, collect settled
            // ones, and track the oldest still-active edit for rescheduling.
            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        let mut pending_index = 0;
                        while pending_index < registered_buffer.pending_predictions.len() {
                            let pending_prediction =
                                &registered_buffer.pending_predictions[pending_index];
                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
                            if age >= EDIT_PREDICTION_SETTLED_TTL {
                                // Too old: drop without reporting.
                                registered_buffer.pending_predictions.remove(pending_index);
                                continue;
                            }

                            let quiet_for =
                                now.saturating_duration_since(pending_prediction.last_edit_at);
                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                // Settled: capture the region's current text.
                                let pending_prediction =
                                    registered_buffer.pending_predictions.remove(pending_index);
                                let settled_editable_region = registered_buffer
                                    .snapshot
                                    .text_for_range(
                                        pending_prediction.editable_anchor_range.clone(),
                                    )
                                    .collect::<String>();
                                ready_predictions
                                    .push((pending_prediction, settled_editable_region));
                                continue;
                            }

                            if oldest_edited_at
                                .is_none_or(|time| pending_prediction.last_edit_at < time)
                            {
                                oldest_edited_at = Some(pending_prediction.last_edit_at);
                            }
                            pending_index += 1;
                        }
                    }
                }
            });

            // Report each settled prediction, computing kept-rate metrics on
            // the background executor.
            for (pending_prediction, settled_editable_region) in ready_predictions {
                let PendingSettledPrediction {
                    request_id,
                    editable_region_before_prediction,
                    predicted_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    example,
                    e2e_latency,
                    ..
                } = pending_prediction;
                let settled_editable_region_for_metrics = settled_editable_region.clone();
                let kept_rate_result = cx
                    .background_spawn(async move {
                        compute_kept_rate(
                            &editable_region_before_prediction,
                            &predicted_editable_region,
                            &settled_editable_region_for_metrics,
                        )
                    })
                    .await;

                #[cfg(test)]
                {
                    let request_id = request_id.clone();
                    let settled_editable_region = settled_editable_region.clone();
                    this.update(cx, |this, _| {
                        if let Some(callback) = &this.settled_event_callback {
                            callback(request_id, settled_editable_region);
                        }
                    });
                }

                telemetry::event!(
                    EDIT_PREDICTION_SETTLED_EVENT,
                    request_id = request_id.0.clone(),
                    settled_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
                    edit_bytes_reference_new = kept_rate_result.reference_new_chars,
                    edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
                    edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
                    edit_bytes_kept = kept_rate_result.kept_chars,
                    edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
                    edit_bytes_discarded = kept_rate_result.discarded_chars,
                    edit_bytes_context = kept_rate_result.context_chars,
                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
                    edit_bytes_recall_rate = kept_rate_result.recall_rate,
                    example,
                    e2e_latency = e2e_latency.as_millis(),
                );
            }

            // Wake again when the oldest still-pending prediction could settle.
            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1717
    /// Queues an accepted prediction for "settled" telemetry tracking.
    ///
    /// Captures the editable region's text and tree-sitter error counts both
    /// before the prediction and as predicted, so the settled worker can later
    /// compare them against what the user actually kept. Also nudges the
    /// worker via `settled_predictions_tx`.
    pub(crate) fn enqueue_settled_prediction(
        &mut self,
        request_id: EditPredictionId,
        project: &Entity<Project>,
        edited_buffer: &Entity<Buffer>,
        edited_buffer_snapshot: &BufferSnapshot,
        editable_offset_range: Range<usize>,
        edit_preview: &EditPreview,
        example: Option<ExampleSpec>,
        e2e_latency: std::time::Duration,
        cx: &mut Context<Self>,
    ) {
        let this = &mut *self;
        let project_state = this.get_or_init_project(project, cx);
        let Some(registered_buffer) = project_state
            .registered_buffers
            .get_mut(&edited_buffer.entity_id())
        else {
            return;
        };

        // Editable region text before the prediction was applied.
        let editable_region_before_prediction = edited_buffer_snapshot
            .text_for_range(editable_offset_range.clone())
            .collect::<String>();
        let editable_anchor_range_for_result =
            edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
        // Editable region text as the prediction would leave it.
        let predicted_editable_region = edit_preview
            .result_text_snapshot()
            .text_for_range(editable_anchor_range_for_result.clone())
            .collect();
        let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
            edited_buffer_snapshot
                .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
        );
        let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
            edit_preview.result_syntax_snapshot().layers_for_range(
                editable_anchor_range_for_result,
                edit_preview.result_text_snapshot(),
                true,
            ),
        );
        let editable_anchor_range =
            edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
        let now = cx.background_executor().now();
        registered_buffer
            .pending_predictions
            .push(PendingSettledPrediction {
                request_id,
                editable_anchor_range,
                editable_region_before_prediction,
                predicted_editable_region,
                ts_error_count_before_prediction,
                ts_error_count_after_prediction,
                example,
                e2e_latency,
                enqueued_at: now,
                last_edit_at: now,
            });
        // Wake the settled-predictions worker.
        this.settled_predictions_tx.unbounded_send(now).ok();
    }
1778
1779 fn reject_current_prediction(
1780 &mut self,
1781 reason: EditPredictionRejectReason,
1782 project: &Entity<Project>,
1783 cx: &App,
1784 ) {
1785 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1786 project_state.pending_predictions.clear();
1787 if let Some(prediction) = project_state.current_prediction.take() {
1788 let model_version = prediction.prediction.model_version.clone();
1789 self.reject_prediction(
1790 prediction.prediction.id,
1791 reason,
1792 prediction.was_shown,
1793 model_version,
1794 Some(prediction.e2e_latency),
1795 cx,
1796 );
1797 }
1798 };
1799 }
1800
1801 fn did_show_current_prediction(
1802 &mut self,
1803 project: &Entity<Project>,
1804 display_type: edit_prediction_types::SuggestionDisplayType,
1805 _cx: &mut Context<Self>,
1806 ) {
1807 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1808 return;
1809 };
1810
1811 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1812 return;
1813 };
1814
1815 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1816 let previous_shown_with = current_prediction.shown_with;
1817
1818 if previous_shown_with.is_none() || !is_jump {
1819 current_prediction.shown_with = Some(display_type);
1820 }
1821
1822 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1823
1824 if is_first_non_jump_show {
1825 current_prediction.was_shown = true;
1826 }
1827
1828 if is_first_non_jump_show {
1829 self.shown_predictions
1830 .push_front(current_prediction.prediction.clone());
1831 if self.shown_predictions.len() > 50 {
1832 let completion = self.shown_predictions.pop_back().unwrap();
1833 self.rated_predictions.remove(&completion.id);
1834 }
1835 }
1836 }
1837
1838 fn reject_prediction(
1839 &mut self,
1840 prediction_id: EditPredictionId,
1841 reason: EditPredictionRejectReason,
1842 was_shown: bool,
1843 model_version: Option<String>,
1844 e2e_latency: Option<std::time::Duration>,
1845 cx: &App,
1846 ) {
1847 match self.edit_prediction_model {
1848 EditPredictionModel::Zeta => {
1849 let is_cloud = !matches!(
1850 all_language_settings(None, cx).edit_predictions.provider,
1851 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1852 );
1853
1854 if is_cloud {
1855 let organization_id = self
1856 .user_store
1857 .read(cx)
1858 .current_organization()
1859 .map(|organization| organization.id.clone());
1860
1861 self.reject_predictions_tx
1862 .unbounded_send(EditPredictionRejectionPayload {
1863 rejection: EditPredictionRejection {
1864 request_id: prediction_id.to_string(),
1865 reason,
1866 was_shown,
1867 model_version,
1868 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1869 },
1870 organization_id,
1871 })
1872 .log_err();
1873 }
1874 }
1875 EditPredictionModel::Mercury => {
1876 mercury::edit_prediction_rejected(
1877 prediction_id,
1878 was_shown,
1879 reason,
1880 self.client.http_client(),
1881 cx,
1882 );
1883 }
1884 EditPredictionModel::Fim { .. } => {}
1885 }
1886 }
1887
1888 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1889 self.projects
1890 .get(&project.entity_id())
1891 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1892 }
1893
    /// Requests a fresh edit prediction at `position` in `buffer`, going
    /// through the shared throttling/queueing in `queue_prediction_refresh`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |_cx| {
                    // Tag the result with the requesting buffer's id so jump
                    // display can be scoped to it.
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1935
    /// Requests a prediction at the next diagnostic location, used to offer
    /// "jump" predictions when diagnostics change.
    ///
    /// Skipped for non-store providers, while following a collaborator, or
    /// when a current prediction already exists (buffer-triggered predictions
    /// take precedence). `scope` bounds the diagnostic search to lines around
    /// the cursor or searches globally.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor, bailing out when
                // predictions are disabled at that position.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope: a band of lines around the cursor.
                    // Global scope: the default (empty) range sentinel.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    this.update(cx, |this, cx| {
                        Some((
                            // If a buffer-triggered prediction arrived in the
                            // meantime, keep it and reject this one instead.
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    model_version: prediction_result.model_version,
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2053
2054 fn predictions_enabled_at(
2055 snapshot: &BufferSnapshot,
2056 position: Option<language::Anchor>,
2057 cx: &App,
2058 ) -> bool {
2059 let file = snapshot.file();
2060 let all_settings = all_language_settings(file, cx);
2061 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2062 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2063 {
2064 return false;
2065 }
2066
2067 if let Some(last_position) = position {
2068 let settings = snapshot.settings_at(last_position, cx);
2069
2070 if !settings.edit_predictions_disabled_in.is_empty()
2071 && let Some(scope) = snapshot.language_scope_at(last_position)
2072 && let Some(scope_name) = scope.override_name()
2073 && settings
2074 .edit_predictions_disabled_in
2075 .iter()
2076 .any(|s| s == scope_name)
2077 {
2078 return false;
2079 }
2080 }
2081
2082 true
2083 }
2084
    /// Throttle window used by `queue_prediction_refresh` between refreshes.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2086}
2087
2088fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2089 let Some(app_state) = AppState::try_global(cx) else {
2090 return false;
2091 };
2092
2093 app_state
2094 .workspace_store
2095 .read(cx)
2096 .workspaces()
2097 .filter_map(|workspace| workspace.upgrade())
2098 .any(|workspace| {
2099 workspace.read(cx).project().entity_id() == project.entity_id()
2100 && workspace
2101 .read(cx)
2102 .leader_for_pane(workspace.read(cx).active_pane())
2103 .is_some()
2104 })
2105}
2106
2107fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2108 match provider {
2109 EditPredictionProvider::Zed
2110 | EditPredictionProvider::Mercury
2111 | EditPredictionProvider::Ollama
2112 | EditPredictionProvider::OpenAiCompatibleApi
2113 | EditPredictionProvider::Experimental(_) => true,
2114 EditPredictionProvider::None
2115 | EditPredictionProvider::Copilot
2116 | EditPredictionProvider::Codestral => false,
2117 }
2118}
2119
2120impl EditPredictionStore {
2121 fn queue_prediction_refresh(
2122 &mut self,
2123 project: Entity<Project>,
2124 request_trigger: PredictEditsRequestTrigger,
2125 throttle_entity: EntityId,
2126 cx: &mut Context<Self>,
2127 do_refresh: impl FnOnce(
2128 WeakEntity<Self>,
2129 &mut AsyncApp,
2130 )
2131 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2132 + 'static,
2133 ) {
2134 fn select_throttle(
2135 project_state: &mut ProjectState,
2136 request_trigger: PredictEditsRequestTrigger,
2137 ) -> &mut Option<(EntityId, Instant)> {
2138 match request_trigger {
2139 PredictEditsRequestTrigger::Diagnostics => {
2140 &mut project_state.last_jump_prediction_refresh
2141 }
2142 _ => &mut project_state.last_edit_prediction_refresh,
2143 }
2144 }
2145
2146 let (needs_acceptance_tracking, max_pending_predictions) =
2147 match all_language_settings(None, cx).edit_predictions.provider {
2148 EditPredictionProvider::Zed
2149 | EditPredictionProvider::Mercury
2150 | EditPredictionProvider::Experimental(_) => (true, 2),
2151 EditPredictionProvider::Ollama => (false, 1),
2152 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2153 EditPredictionProvider::None
2154 | EditPredictionProvider::Copilot
2155 | EditPredictionProvider::Codestral => {
2156 log::error!("queue_prediction_refresh called with non-store provider");
2157 return;
2158 }
2159 };
2160
2161 let drop_on_cancel = !needs_acceptance_tracking;
2162 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2163 let project_state = self.get_or_init_project(&project, cx);
2164 let pending_prediction_id = project_state.next_pending_prediction_id;
2165 project_state.next_pending_prediction_id += 1;
2166 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2167
2168 let task = cx.spawn(async move |this, cx| {
2169 let throttle_wait = this
2170 .update(cx, |this, cx| {
2171 let project_state = this.get_or_init_project(&project, cx);
2172 let throttle = *select_throttle(project_state, request_trigger);
2173
2174 let now = cx.background_executor().now();
2175 throttle.and_then(|(last_entity, last_timestamp)| {
2176 if throttle_entity != last_entity {
2177 return None;
2178 }
2179 (last_timestamp + throttle_timeout).checked_duration_since(now)
2180 })
2181 })
2182 .ok()
2183 .flatten();
2184
2185 if let Some(timeout) = throttle_wait {
2186 cx.background_executor().timer(timeout).await;
2187 }
2188
2189 // If this task was cancelled before the throttle timeout expired,
2190 // do not perform a request. Also skip if another task already
2191 // proceeded since we were enqueued (duplicate).
2192 let mut is_cancelled = true;
2193 this.update(cx, |this, cx| {
2194 let project_state = this.get_or_init_project(&project, cx);
2195 let was_cancelled = project_state
2196 .cancelled_predictions
2197 .remove(&pending_prediction_id);
2198 if was_cancelled {
2199 return;
2200 }
2201
2202 // Another request has been already sent since this was enqueued
2203 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2204 return;
2205 }
2206
2207 let new_refresh = (throttle_entity, cx.background_executor().now());
2208 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2209 is_cancelled = false;
2210 })
2211 .ok();
2212 if is_cancelled {
2213 return None;
2214 }
2215
2216 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2217 let new_prediction_metadata = new_prediction_result
2218 .as_ref()
2219 .map(|(prediction, _)| (prediction.id.clone(), prediction.model_version.clone()));
2220
2221 // When a prediction completes, remove it from the pending list, and cancel
2222 // any pending predictions that were enqueued before it.
2223 this.update(cx, |this, cx| {
2224 let project_state = this.get_or_init_project(&project, cx);
2225
2226 let is_cancelled = project_state
2227 .cancelled_predictions
2228 .remove(&pending_prediction_id);
2229
2230 let new_current_prediction = if !is_cancelled
2231 && let Some((prediction_result, requested_by)) = new_prediction_result
2232 {
2233 match prediction_result.prediction {
2234 Ok(prediction) => {
2235 let new_prediction = CurrentEditPrediction {
2236 requested_by,
2237 prediction,
2238 was_shown: false,
2239 shown_with: None,
2240 e2e_latency: prediction_result.e2e_latency,
2241 };
2242
2243 if let Some(current_prediction) =
2244 project_state.current_prediction.as_ref()
2245 {
2246 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2247 {
2248 this.reject_current_prediction(
2249 EditPredictionRejectReason::Replaced,
2250 &project,
2251 cx,
2252 );
2253
2254 Some(new_prediction)
2255 } else {
2256 this.reject_prediction(
2257 new_prediction.prediction.id,
2258 EditPredictionRejectReason::CurrentPreferred,
2259 false,
2260 new_prediction.prediction.model_version,
2261 Some(new_prediction.e2e_latency),
2262 cx,
2263 );
2264 None
2265 }
2266 } else {
2267 Some(new_prediction)
2268 }
2269 }
2270 Err(reject_reason) => {
2271 this.reject_prediction(
2272 prediction_result.id,
2273 reject_reason,
2274 false,
2275 prediction_result.model_version,
2276 Some(prediction_result.e2e_latency),
2277 cx,
2278 );
2279 None
2280 }
2281 }
2282 } else {
2283 None
2284 };
2285
2286 let project_state = this.get_or_init_project(&project, cx);
2287
2288 if let Some(new_prediction) = new_current_prediction {
2289 project_state.current_prediction = Some(new_prediction);
2290 }
2291
2292 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2293 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2294 if pending_prediction.id == pending_prediction_id {
2295 pending_predictions.remove(ix);
2296 for pending_prediction in pending_predictions.drain(0..ix) {
2297 project_state.cancel_pending_prediction(pending_prediction, cx)
2298 }
2299 break;
2300 }
2301 }
2302 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2303 cx.notify();
2304 })
2305 .ok();
2306
2307 new_prediction_metadata
2308 });
2309
2310 if project_state.pending_predictions.len() < max_pending_predictions {
2311 project_state
2312 .pending_predictions
2313 .push(PendingPrediction {
2314 id: pending_prediction_id,
2315 task,
2316 drop_on_cancel,
2317 })
2318 .unwrap();
2319 } else {
2320 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2321 project_state
2322 .pending_predictions
2323 .push(PendingPrediction {
2324 id: pending_prediction_id,
2325 task,
2326 drop_on_cancel,
2327 })
2328 .unwrap();
2329 project_state.cancel_pending_prediction(pending_prediction, cx);
2330 }
2331 }
2332
2333 pub fn request_prediction(
2334 &mut self,
2335 project: &Entity<Project>,
2336 active_buffer: &Entity<Buffer>,
2337 position: language::Anchor,
2338 trigger: PredictEditsRequestTrigger,
2339 cx: &mut Context<Self>,
2340 ) -> Task<Result<Option<EditPredictionResult>>> {
2341 self.request_prediction_internal(
2342 project.clone(),
2343 active_buffer.clone(),
2344 position,
2345 trigger,
2346 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2347 cx,
2348 )
2349 }
2350
    /// Builds the model input (recent events, related files, a diagnostics
    /// window around the cursor) and dispatches the prediction request to the
    /// configured model backend. When the model returns nothing and
    /// `allow_jump` is set, falls back to a diagnostics-based ("jump")
    /// prediction.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.iter().map(|e| e.event.clone()).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        // Diagnostics within this row window around the cursor are included in
        // the request.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);
        let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
            EditPredictionsMode::Eager => PredictEditsMode::Eager,
            EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
        };

        // Data collection requires that the file, every stored event, and
        // every related file all come from open-source repositories.
        let is_open_source = snapshot
            .file()
            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
            && events.iter().all(|event| event.in_open_source_repo())
            && related_files.iter().all(|file| file.in_open_source_repo);

        let can_collect_data = !cfg!(test)
            && is_open_source
            && self.is_data_collection_enabled(cx)
            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer,
            snapshot,
            position,
            events,
            related_files,
            mode,
            trigger,
            diagnostic_search_range,
            debug_tx,
            can_collect_data,
            is_open_source,
        };

        // Sample roughly 0.1% of collectable requests for event capture.
        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);

        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta => {
                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
            }
            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
            EditPredictionModel::Mercury => {
                self.mercury
                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
            }
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Only fall back to a diagnostics-based ("jump") prediction when
            // the model had nothing to suggest for this buffer, and the
            // request wasn't itself diagnostics-triggered.
            if prediction.is_none()
                && allow_jump
                && has_events
                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
            {
                this.update(cx, |this, cx| {
                    this.refresh_prediction_from_diagnostics(
                        project,
                        DiagnosticSearchScope::Local,
                        cx,
                    );
                })?;
                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
2442
    /// Finds a diagnostic location to "jump" a prediction to.
    ///
    /// Prefers the closest diagnostic in `active_buffer` outside the range
    /// already searched around the cursor, skipping diagnostics near a
    /// collaborator's cursor unless they are also near the local cursor.
    /// Falls back to other project files with diagnostics, preferring those
    /// whose path shares the longest prefix with the active buffer's path and
    /// skipping buffers where collaborators are present.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows where collaborators currently have cursors in the active buffer.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(
                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
                false,
            )
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Already covered by the original prediction request.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Avoid jumping onto a collaborator's work area unless the
                // local cursor is already nearby.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Rank other files with diagnostics by the number of leading path
            // components they share with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(
                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
                                false,
                            )
                            .next()
                            .is_some();
                        // Jump to the most severe diagnostic in the candidate.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2558
2559 async fn send_raw_llm_request(
2560 request: RawCompletionRequest,
2561 client: Arc<Client>,
2562 custom_url: Option<Arc<Url>>,
2563 llm_token: LlmApiToken,
2564 organization_id: Option<OrganizationId>,
2565 app_version: Version,
2566 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2567 let url = if let Some(custom_url) = custom_url {
2568 custom_url.as_ref().clone()
2569 } else {
2570 client
2571 .http_client()
2572 .build_zed_llm_url("/predict_edits/raw", &[])?
2573 };
2574
2575 Self::send_api_request(
2576 |builder| {
2577 let req = builder
2578 .uri(url.as_ref())
2579 .body(serde_json::to_string(&request)?.into());
2580 Ok(req?)
2581 },
2582 client,
2583 llm_token,
2584 organization_id,
2585 app_version,
2586 true,
2587 )
2588 .await
2589 }
2590
2591 pub(crate) async fn send_v3_request(
2592 input: ZetaPromptInput,
2593 preferred_experiment: Option<String>,
2594 client: Arc<Client>,
2595 llm_token: LlmApiToken,
2596 organization_id: Option<OrganizationId>,
2597 app_version: Version,
2598 trigger: PredictEditsRequestTrigger,
2599 mode: PredictEditsMode,
2600 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2601 let url = client
2602 .http_client()
2603 .build_zed_llm_url("/predict_edits/v3", &[])?;
2604
2605 let request = PredictEditsV3Request { input, trigger };
2606
2607 let json_bytes = serde_json::to_vec(&request)?;
2608 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2609
2610 Self::send_api_request(
2611 |builder| {
2612 let builder = builder
2613 .uri(url.as_ref())
2614 .header("Content-Encoding", "zstd")
2615 .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref());
2616 let builder = if let Some(preferred_experiment) = preferred_experiment.as_deref() {
2617 builder.header(PREFERRED_EXPERIMENT_HEADER_NAME, preferred_experiment)
2618 } else {
2619 builder
2620 };
2621 let req = builder.body(compressed.clone().into());
2622 Ok(req?)
2623 },
2624 client,
2625 llm_token,
2626 organization_id,
2627 app_version,
2628 true,
2629 )
2630 .await
2631 }
2632
    /// Sends an authenticated POST request to the LLM service and deserializes
    /// the JSON response body into `Res`.
    ///
    /// Acquires the LLM token up front (tolerating failure when `require_auth`
    /// is false), enforces the server's minimum-required-Zed-version header,
    /// retries once after refreshing an expired token, and returns the parsed
    /// body together with any usage info found in the response headers.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = if require_auth {
            Some(
                client
                    .acquire_llm_token(&llm_token, organization_id.clone())
                    .await?,
            )
        } else {
            // Best effort: an unauthenticated request is still allowed.
            client
                .acquire_llm_token(&llm_token, organization_id.clone())
                .await
                .ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server can demand a newer client; surface that as a
            // dedicated error so callers can prompt for an update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Token expired: refresh once and loop to retry the request.
                did_retry = true;
                token = Some(
                    client
                        .refresh_llm_token(&llm_token, organization_id.clone())
                        .await?,
                );
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2713
2714 pub fn refresh_context(
2715 &mut self,
2716 project: &Entity<Project>,
2717 buffer: &Entity<language::Buffer>,
2718 cursor_position: language::Anchor,
2719 cx: &mut Context<Self>,
2720 ) {
2721 self.get_or_init_project(project, cx)
2722 .context
2723 .update(cx, |store, cx| {
2724 store.refresh(buffer.clone(), cursor_position, cx);
2725 });
2726 }
2727
2728 #[cfg(feature = "cli-support")]
2729 pub fn set_context_for_buffer(
2730 &mut self,
2731 project: &Entity<Project>,
2732 related_files: Vec<RelatedFile>,
2733 cx: &mut Context<Self>,
2734 ) {
2735 self.get_or_init_project(project, cx)
2736 .context
2737 .update(cx, |store, cx| {
2738 store.set_related_files(related_files, cx);
2739 });
2740 }
2741
2742 #[cfg(feature = "cli-support")]
2743 pub fn set_recent_paths_for_project(
2744 &mut self,
2745 project: &Entity<Project>,
2746 paths: impl IntoIterator<Item = project::ProjectPath>,
2747 cx: &mut Context<Self>,
2748 ) {
2749 let project_state = self.get_or_init_project(project, cx);
2750 project_state.recent_paths = paths.into_iter().collect();
2751 }
2752
2753 fn is_file_open_source(
2754 &self,
2755 project: &Entity<Project>,
2756 file: &Arc<dyn File>,
2757 cx: &App,
2758 ) -> bool {
2759 if !file.is_local() || file.is_private() {
2760 return false;
2761 }
2762 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2763 return false;
2764 };
2765 project_state
2766 .license_detection_watchers
2767 .get(&file.worktree_id(cx))
2768 .as_ref()
2769 .is_some_and(|watcher| watcher.is_project_open_source())
2770 }
2771
    /// Whether prediction data collection is currently enabled for this user
    /// (delegates to `DataCollectionChoice::is_enabled`).
    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx)
    }
2775
2776 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2777 let choice = KeyValueStore::global(cx)
2778 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2779 .log_err()
2780 .flatten();
2781
2782 match choice.as_deref() {
2783 Some("true") => DataCollectionChoice::Enabled,
2784 Some("false") => DataCollectionChoice::Disabled,
2785 Some(_) => {
2786 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2787 DataCollectionChoice::NotAnswered
2788 }
2789 None => DataCollectionChoice::NotAnswered,
2790 }
2791 }
2792
2793 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2794 self.data_collection_choice = self.data_collection_choice.toggle();
2795 let new_choice = self.data_collection_choice;
2796 let is_enabled = new_choice.is_enabled(cx);
2797 let kvp = KeyValueStore::global(cx);
2798 db::write_and_log(cx, move || async move {
2799 kvp.write_kvp(
2800 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2801 is_enabled.to_string(),
2802 )
2803 .await
2804 });
2805 }
2806
    /// Iterates over predictions that have been shown to the user, in storage
    /// order (reversible).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2810
    /// Number of predictions shown so far.
    // NOTE(review): named "completions" presumably for legacy reasons; it
    // counts the same collection as `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2814
    /// Whether the user has already rated the prediction with the given id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2818
    /// Submits user feedback (rating plus free-form text) for a prediction to
    /// the cloud service, and marks the prediction as rated so the UI won't
    /// ask again.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        let organization = self.user_store.read(cx).current_organization();

        self.rated_predictions.insert(prediction.id.clone());

        cx.background_spawn({
            let client = self.client.clone();
            let prediction_id = prediction.id.to_string();
            let inputs = serde_json::to_value(&prediction.inputs);
            // Render the prediction's edits as a unified diff for the report.
            let output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
            async move {
                client
                    .cloud_client()
                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
                        organization_id: organization.map(|organization| organization.id.clone()),
                        request_id: prediction_id,
                        rating: match rating {
                            EditPredictionRating::Positive => "positive".to_string(),
                            EditPredictionRating::Negative => "negative".to_string(),
                        },
                        inputs: inputs?,
                        output,
                        feedback,
                    })
                    .await?;

                anyhow::Ok(())
            }
        })
        .detach_and_log_err(cx);

        cx.notify();
    }
2860}
2861
/// Whether a collaborator's edit (`edit_range` in `buffer`) falls within the
/// syntactically-then-linewise expanded "locality" region around the local
/// user's cursor in the project's active buffer.
fn collaborator_edit_overlaps_locality_region(
    project_state: &ProjectState,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    snapshot: &BufferSnapshot,
    edit_range: &Range<Anchor>,
    cx: &App,
) -> bool {
    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
        return false;
    };

    // Edits in other buffers can't be local to the cursor.
    if active_buffer.entity_id() != buffer.entity_id() {
        return false;
    }

    let locality_point_range = expand_context_syntactically_then_linewise(
        snapshot,
        (position..position).to_point(snapshot),
        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
    );
    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);

    edit_range.overlaps(&locality_anchor_range, snapshot)
}
2887
/// Collapses a run of mergeable trailing events in `events` into a single
/// `BufferChange` event spanning their combined edit range, computed by
/// diffing the oldest merged snapshot against `end_snapshot`.
///
/// Bails out when the last event belongs to a different buffer than
/// `latest_snapshot`, or when the latest snapshot has not yet observed the
/// last event's edits.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards to count how many trailing events can merge together.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Nothing to do unless at least two events are mergeable.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union the edit ranges of all events being merged.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // constituent event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2970
2971fn merge_anchor_ranges(
2972 left: &Range<Anchor>,
2973 right: &Range<Anchor>,
2974 snapshot: &TextBufferSnapshot,
2975) -> Range<Anchor> {
2976 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2977 left.start
2978 } else {
2979 right.start
2980 };
2981 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2982 left.end
2983 } else {
2984 right.end
2985 };
2986 start..end
2987}
2988
/// Returned when the server's minimum-required-version response header
/// indicates this Zed build is too old to use edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2996
/// The user's persisted choice about sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // Not yet asked; treated as disabled for non-staff users.
    NotAnswered,
    Enabled,
    Disabled,
}
3003
3004impl DataCollectionChoice {
3005 pub fn is_enabled(self, cx: &App) -> bool {
3006 if cx.is_staff() {
3007 return true;
3008 }
3009 match self {
3010 Self::Enabled => true,
3011 Self::NotAnswered | Self::Disabled => false,
3012 }
3013 }
3014
3015 #[must_use]
3016 pub fn toggle(&self) -> DataCollectionChoice {
3017 match self {
3018 Self::Enabled => Self::Disabled,
3019 Self::Disabled => Self::Enabled,
3020 Self::NotAnswered => Self::Enabled,
3021 }
3022 }
3023}
3024
3025impl From<bool> for DataCollectionChoice {
3026 fn from(value: bool) -> Self {
3027 match value {
3028 true => DataCollectionChoice::Enabled,
3029 false => DataCollectionChoice::Disabled,
3030 }
3031 }
3032}
3033
3034struct ZedPredictUpsell;
3035
3036impl Dismissable for ZedPredictUpsell {
3037 const KEY: &'static str = "dismissed-edit-predict-upsell";
3038
3039 fn dismissed(cx: &App) -> bool {
3040 // To make this backwards compatible with older versions of Zed, we
3041 // check if the user has seen the previous Edit Prediction Onboarding
3042 // before, by checking the data collection choice which was written to
3043 // the database once the user clicked on "Accept and Enable"
3044 let kvp = KeyValueStore::global(cx);
3045 if kvp
3046 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3047 .log_err()
3048 .is_some_and(|s| s.is_some())
3049 {
3050 return true;
3051 }
3052
3053 kvp.read_kvp(Self::KEY)
3054 .log_err()
3055 .is_some_and(|s| s.is_some())
3056 }
3057}
3058
/// Whether the edit-prediction upsell modal should be shown, i.e. the user has
/// not dismissed it (nor completed the legacy onboarding).
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !ZedPredictUpsell::dismissed(cx)
}
3062
/// Registers per-workspace actions: the edit-prediction onboarding modal,
/// onboarding reset, and the Copilot sign-in/sign-out/reinstall flows.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider so the upsell
        // flow can run again.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Lazily starts Copilot for the workspace's project when the global
        // store exists.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}