1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
7 PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
8};
9use cloud_llm_client::{
10 EditPredictionRejectReason, EditPredictionRejection,
11 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
12 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
13};
14use collections::{HashMap, HashSet};
15use copilot::{Copilot, Reinstall, SignIn, SignOut};
16use credentials_provider::CredentialsProvider;
17use db::kvp::{Dismissable, KeyValueStore};
18use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
19use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
20use futures::{
21 AsyncReadExt as _, FutureExt as _, StreamExt as _,
22 channel::mpsc::{self, UnboundedReceiver},
23 select_biased,
24};
25use gpui::BackgroundExecutor;
26use gpui::http_client::Url;
27use gpui::{
28 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
29 http_client::{self, AsyncBody, Method},
30 prelude::*,
31};
32use heapless::Vec as ArrayVec;
33use language::{
34 Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
35 TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
36};
37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
38use release_channel::AppVersion;
39use semver::Version;
40use serde::de::DeserializeOwned;
41use settings::{
42 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
43};
44use std::collections::{VecDeque, hash_map};
45use std::env;
46use text::{AnchorRangeExt, Edit};
47use workspace::{AppState, Workspace};
48use zeta_prompt::{ZetaFormat, ZetaPromptInput};
49
50use std::mem;
51use std::ops::Range;
52use std::path::Path;
53use std::rc::Rc;
54use std::str::FromStr as _;
55use std::sync::Arc;
56use std::time::{Duration, Instant};
57
58use thiserror::Error;
59use util::{RangeExt as _, ResultExt as _};
60
61pub mod cursor_excerpt;
62pub mod example_spec;
63pub mod fim;
64mod license_detection;
65pub mod mercury;
66pub mod metrics;
67pub mod ollama;
68mod onboarding_modal;
69pub mod open_ai_response;
70mod prediction;
71
72pub mod udiff;
73
74mod capture_example;
75pub mod open_ai_compatible;
76mod zed_edit_prediction_delegate;
77pub mod zeta;
78
79#[cfg(test)]
80mod edit_prediction_tests;
81
82use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
83use crate::example_spec::ExampleSpec;
84use crate::license_detection::LicenseDetectionWatcher;
85use crate::mercury::Mercury;
86pub use crate::metrics::{KeptRateResult, compute_kept_rate};
87use crate::onboarding_modal::ZedPredictModal;
88pub use crate::prediction::EditPrediction;
89pub use crate::prediction::EditPredictionId;
90use crate::prediction::EditPredictionResult;
91pub use capture_example::capture_example;
92pub use language_model::ApiKeyState;
93pub use telemetry_events::EditPredictionRating;
94pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
95
// Actions declared under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
105
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold used by `StoredEvent::can_merge` when deciding whether
/// edit ranges are "near" each other.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length (in buffer offsets) of the old/new regions that are
/// diffed for an edit-history event; larger regions yield no diff.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): the constants below are not used in the visible part of this
// file; their meaning is inferred from their names only — confirm at use sites.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
117
/// Marker type for the `edit_prediction_jumps` feature flag.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
123
/// Wrapper that stores the shared `EditPredictionStore` entity as a gpui global
/// (see `EditPredictionStore::try_global` / `EditPredictionStore::global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
128
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): no field is named "version"; the sentence above may predate the
// `environment` field — confirm and update.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Read from `ZED_ZETA_MODEL` when the config is built from the environment.
    pub model_id: Option<String>,
    /// Read from `ZED_ZETA_ENVIRONMENT` when the config is built from the environment.
    pub environment: Option<String>,
    /// Parsed from `ZED_ZETA_FORMAT` when the config is built from the environment.
    pub format: ZetaFormat,
}
138
/// Central state for edit predictions: per-project state, model selection,
/// experiment bookkeeping, and the channels feeding background workers.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Task started in `new` that waits for a signed-in user and then refreshes
    /// the experiment list for staff users.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    /// Initialised from the `ZED_ZETA_*` environment variables in `new`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Filled in by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender half of the channel consumed by the background task running
    /// `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender half of the channel consumed by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
160
/// Message type carried by `EditPredictionStore::reject_predictions_tx`: a
/// rejection plus the organization (if any) it is associated with.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
165
/// Which backend the store uses to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle model using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
172
/// Bundle of inputs handed to a prediction model for a single request.
// NOTE(review): the consumers of this struct are outside the visible part of the
// file; field semantics below are taken from names and types only.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    mode: PredictEditsMode,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
188
/// Debug notifications sent over the optional `debug_tx` channels; each variant
/// wraps the payload struct of the same name.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
196
/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
203
/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}
210
/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
217
/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
224
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Version of the buffer snapshot after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Anchor range covering the event's edits.
    pub total_edit_range: Range<Anchor>,
}
233
234impl StoredEvent {
235 fn can_merge(
236 &self,
237 next_old_event: &StoredEvent,
238 latest_snapshot: &TextBufferSnapshot,
239 latest_edit_range: &Range<Anchor>,
240 ) -> bool {
241 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
242 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
243 return false;
244 }
245 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
246 return false;
247 }
248 if self.new_snapshot_version != next_old_event.old_snapshot.version {
249 return false;
250 }
251 if !latest_snapshot
252 .version
253 .observed_all(&next_old_event.new_snapshot_version)
254 {
255 return false;
256 }
257
258 let a_is_predicted = matches!(
259 self.event.as_ref(),
260 zeta_prompt::Event::BufferChange {
261 predicted: true,
262 ..
263 }
264 );
265 let b_is_predicted = matches!(
266 next_old_event.event.as_ref(),
267 zeta_prompt::Event::BufferChange {
268 predicted: true,
269 ..
270 }
271 );
272
273 // If events come from the same source (both predicted or both manual) then
274 // we would have coalesced them already.
275 if a_is_predicted == b_is_predicted {
276 return false;
277 }
278
279 let left_range = self.total_edit_range.to_point(latest_snapshot);
280 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
281 let latest_range = latest_edit_range.to_point(latest_snapshot);
282
283 // Events near to the latest edit are not merged if their sources differ.
284 if lines_between_ranges(&left_range, &latest_range)
285 .min(lines_between_ranges(&right_range, &latest_range))
286 <= CHANGE_GROUPING_LINE_SPAN
287 {
288 return false;
289 }
290
291 // Events that are distant from each other are not merged.
292 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
293 return false;
294 }
295
296 true
297 }
298}
299
300fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
301 if left.start > right.end {
302 return left.start.row - right.end.row;
303 }
304 if right.start > left.end {
305 return right.start.row - left.end.row;
306 }
307 0
308}
309
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit-history events.
    events: VecDeque<StoredEvent>,
    /// The still-open most recent event; it is finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with the store, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests (fixed-capacity vector).
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled via
    /// `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
327
328impl ProjectState {
329 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
330 self.events
331 .iter()
332 .cloned()
333 .chain(self.last_event.as_ref().iter().flat_map(|event| {
334 let (one, two) = event.split_by_pause();
335 let one = one.finalize(&self.license_detection_watchers, cx);
336 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
337 one.into_iter().chain(two)
338 }))
339 .collect()
340 }
341
342 fn cancel_pending_prediction(
343 &mut self,
344 pending_prediction: PendingPrediction,
345 cx: &mut Context<EditPredictionStore>,
346 ) {
347 self.cancelled_predictions.insert(pending_prediction.id);
348
349 if pending_prediction.drop_on_cancel {
350 drop(pending_prediction.task);
351 } else {
352 cx.spawn(async move |this, cx| {
353 let Some(prediction_id) = pending_prediction.task.await else {
354 return;
355 };
356
357 this.update(cx, |this, cx| {
358 this.reject_prediction(
359 prediction_id,
360 EditPredictionRejectReason::Canceled,
361 false,
362 None,
363 None,
364 cx,
365 );
366 })
367 .ok();
368 })
369 .detach()
370 }
371 }
372
373 fn active_buffer(
374 &self,
375 project: &Entity<Project>,
376 cx: &App,
377 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
378 let project = project.read(cx);
379 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
380 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
381 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
382 Some((active_buffer, registered_buffer.last_position))
383 }
384}
385
/// The prediction currently held for a project, plus bookkeeping about how it
/// was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    pub e2e_latency: std::time::Duration,
}
394
395impl CurrentEditPrediction {
396 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
397 let Some(new_edits) = self
398 .prediction
399 .interpolate(&self.prediction.buffer.read(cx))
400 else {
401 return false;
402 };
403
404 if self.prediction.buffer != old_prediction.prediction.buffer {
405 return true;
406 }
407
408 let Some(old_edits) = old_prediction
409 .prediction
410 .interpolate(&old_prediction.prediction.buffer.read(cx))
411 else {
412 return true;
413 };
414
415 let requested_by_buffer_id = self.requested_by.buffer_id();
416
417 // This reduces the occurrence of UI thrash from replacing edits
418 //
419 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
420 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
421 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
422 && old_edits.len() == 1
423 && new_edits.len() == 1
424 {
425 let (old_range, old_text) = &old_edits[0];
426 let (new_range, new_text) = &new_edits[0];
427 new_range == old_range && new_text.starts_with(old_text.as_ref())
428 } else {
429 true
430 }
431 }
432}
433
434#[derive(Debug, Clone)]
435enum PredictionRequestedBy {
436 DiagnosticsUpdate,
437 Buffer(EntityId),
438}
439
440impl PredictionRequestedBy {
441 pub fn buffer_id(&self) -> Option<EntityId> {
442 match self {
443 PredictionRequestedBy::DiagnosticsUpdate => None,
444 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
445 }
446 }
447}
448
// NOTE(review): not used in the visible part of this file; presumably the number
// of lines around the cursor searched for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
450
/// Scope selector for diagnostic searches.
// NOTE(review): the exact meaning of each variant is defined by code outside the
// visible part of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
456
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
465
466/// A prediction from the perspective of a buffer.
467#[derive(Debug)]
468enum BufferEditPrediction<'a> {
469 Local { prediction: &'a EditPrediction },
470 Jump { prediction: &'a EditPrediction },
471}
472
473#[cfg(test)]
474impl std::ops::Deref for BufferEditPrediction<'_> {
475 type Target = EditPrediction;
476
477 fn deref(&self) -> &Self::Target {
478 match self {
479 BufferEditPrediction::Local { prediction } => prediction,
480 BufferEditPrediction::Jump { prediction } => prediction,
481 }
482 }
483}
484
/// Bookkeeping for a prediction that is waiting to be reported as settled.
// NOTE(review): the worker that consumes these entries is outside the visible
// part of the file; field meanings are taken from names and types only.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
498
/// State kept for each buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Returned alongside the buffer by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
506
/// The most recent, still-open edit event for a project. It is turned into a
/// `StoredEvent` by `finalize`, optionally after `split_by_pause`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    total_edit_range: Range<Anchor>,
    /// Value used by `split_by_pause` as the edit range of the pre-pause half.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Copied into the `predicted` flag of the resulting buffer-change event.
    predicted: bool,
    /// Snapshot at which `split_by_pause` divides the event, when present.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
520
521impl LastEvent {
522 pub fn finalize(
523 &self,
524 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
525 cx: &App,
526 ) -> Option<StoredEvent> {
527 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
528 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
529
530 let in_open_source_repo =
531 [self.new_file.as_ref(), self.old_file.as_ref()]
532 .iter()
533 .all(|file| {
534 file.is_some_and(|file| {
535 license_detection_watchers
536 .get(&file.worktree_id(cx))
537 .is_some_and(|watcher| watcher.is_project_open_source())
538 })
539 });
540
541 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
542 &self.old_snapshot,
543 &self.new_snapshot,
544 &self.total_edit_range,
545 )?;
546
547 if path == old_path && diff.is_empty() {
548 None
549 } else {
550 Some(StoredEvent {
551 event: Arc::new(zeta_prompt::Event::BufferChange {
552 old_path,
553 path,
554 diff,
555 in_open_source_repo,
556 predicted: self.predicted,
557 }),
558 old_snapshot: self.old_snapshot.clone(),
559 new_snapshot_version: self.new_snapshot.version.clone(),
560 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
561 ..self.new_snapshot.anchor_before(edit_range.end),
562 })
563 }
564 }
565
566 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
567 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
568 return (self.clone(), None);
569 };
570
571 let total_edit_range_before_pause = self
572 .total_edit_range_at_last_pause_boundary
573 .clone()
574 .unwrap_or_else(|| self.total_edit_range.clone());
575
576 let Some(total_edit_range_after_pause) =
577 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
578 else {
579 return (self.clone(), None);
580 };
581
582 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
583 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
584
585 let before = LastEvent {
586 old_snapshot: self.old_snapshot.clone(),
587 new_snapshot: boundary_snapshot.clone(),
588 old_file: self.old_file.clone(),
589 new_file: self.new_file.clone(),
590 latest_edit_range: latest_edit_range_before_pause,
591 total_edit_range: total_edit_range_before_pause,
592 total_edit_range_at_last_pause_boundary: None,
593 predicted: self.predicted,
594 snapshot_after_last_editing_pause: None,
595 last_edit_time: self.last_edit_time,
596 };
597
598 let after = LastEvent {
599 old_snapshot: boundary_snapshot.clone(),
600 new_snapshot: self.new_snapshot.clone(),
601 old_file: self.old_file.clone(),
602 new_file: self.new_file.clone(),
603 latest_edit_range: latest_edit_range_after_pause,
604 total_edit_range: total_edit_range_after_pause,
605 total_edit_range_at_last_pause_boundary: None,
606 predicted: self.predicted,
607 snapshot_after_last_editing_pause: None,
608 last_edit_time: self.last_edit_time,
609 };
610
611 (before, Some(after))
612 }
613}
614
615fn compute_total_edit_range_between_snapshots(
616 old_snapshot: &TextBufferSnapshot,
617 new_snapshot: &TextBufferSnapshot,
618) -> Option<Range<Anchor>> {
619 let edits: Vec<Edit<usize>> = new_snapshot
620 .edits_since::<usize>(&old_snapshot.version)
621 .collect();
622
623 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
624 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
625 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
626
627 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
628}
629
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to a point range in
/// `old_snapshot` by walking the edits made between the two snapshots.
///
/// Returns `None` if shifting an offset by the accumulated length delta would
/// overflow or underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net length change (new minus old) of all edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // First edit whose new range ends at or after the start offset: the
        // start either lies before this edit (shift it back by `delta`) or
        // inside it (snap to the edit's old start).
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                new_start_offset.checked_add_signed(-delta)?
            } else {
                edit.old.start
            });
        }

        // Same for the end offset, snapping to the edit's old end when the end
        // does not lie strictly before this edit.
        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Offsets past every edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
675
676fn compute_diff_between_snapshots_in_range(
677 old_snapshot: &TextBufferSnapshot,
678 new_snapshot: &TextBufferSnapshot,
679 total_edit_range: &Range<Anchor>,
680) -> Option<(String, Range<Point>)> {
681 let new_start_point = total_edit_range.start.to_point(new_snapshot);
682 let new_end_point = total_edit_range.end.to_point(new_snapshot);
683 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
684 let old_start_point = old_range.start;
685 let old_end_point = old_range.end;
686
687 const CONTEXT_LINES: u32 = 3;
688
689 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
690 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
691 let old_context_end_row =
692 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
693 let new_context_end_row =
694 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
695
696 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
697 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
698 let old_end_line_offset = old_snapshot
699 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
700 let new_end_line_offset = new_snapshot
701 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
702 let old_edit_range = old_start_line_offset..old_end_line_offset;
703 let new_edit_range = new_start_line_offset..new_end_line_offset;
704
705 if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
706 || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
707 {
708 return None;
709 }
710
711 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
712 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
713
714 let diff = language::unified_diff_with_offsets(
715 &old_region_text,
716 &new_region_text,
717 old_context_start_row,
718 new_context_start_row,
719 );
720
721 Some((diff, new_start_point..new_end_point))
722}
723
724fn buffer_path_with_id_fallback(
725 file: Option<&Arc<dyn File>>,
726 snapshot: &TextBufferSnapshot,
727 cx: &App,
728) -> Arc<Path> {
729 if let Some(file) = file {
730 file.full_path(cx).into()
731 } else {
732 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
733 }
734}
735
736impl EditPredictionStore {
737 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
738 cx.try_global::<EditPredictionStoreGlobal>()
739 .map(|global| global.0.clone())
740 }
741
742 pub fn global(
743 client: &Arc<Client>,
744 user_store: &Entity<UserStore>,
745 cx: &mut App,
746 ) -> Entity<Self> {
747 cx.try_global::<EditPredictionStoreGlobal>()
748 .map(|global| global.0.clone())
749 .unwrap_or_else(|| {
750 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
751 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
752 ep_store
753 })
754 }
755
756 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
757 let data_collection_choice = Self::load_data_collection_choice(cx);
758
759 let llm_token = global_llm_token(cx);
760
761 let (reject_tx, reject_rx) = mpsc::unbounded();
762 cx.background_spawn({
763 let client = client.clone();
764 let llm_token = llm_token.clone();
765 let app_version = AppVersion::global(cx);
766 let background_executor = cx.background_executor().clone();
767 async move {
768 Self::handle_rejected_predictions(
769 reject_rx,
770 client,
771 llm_token,
772 app_version,
773 background_executor,
774 )
775 .await
776 }
777 })
778 .detach();
779
780 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
781 cx.spawn(async move |this, cx| {
782 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
783 })
784 .detach();
785
786 let mut current_user = user_store.read(cx).watch_current_user();
787 let fetch_experiments_task = cx.spawn(async move |this, cx| {
788 while current_user.borrow().is_none() {
789 current_user.next().await;
790 }
791
792 this.update(cx, |this, cx| {
793 if cx.is_staff() {
794 this.refresh_available_experiments(cx);
795 }
796 })
797 .log_err();
798 });
799
800 let credentials_provider = zed_credentials_provider::global(cx);
801
802 let this = Self {
803 projects: HashMap::default(),
804 client,
805 user_store,
806 llm_token,
807 _fetch_experiments_task: fetch_experiments_task,
808 update_required: false,
809 edit_prediction_model: EditPredictionModel::Zeta,
810 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
811 preferred_experiment: None,
812 available_experiments: Vec::new(),
813 mercury: Mercury::new(cx),
814
815 data_collection_choice,
816 reject_predictions_tx: reject_tx,
817 settled_predictions_tx,
818 rated_predictions: Default::default(),
819 shown_predictions: Default::default(),
820 #[cfg(test)]
821 settled_event_callback: None,
822
823 credentials_provider,
824 };
825
826 this
827 }
828
829 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
830 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
831 let format = ZetaFormat::parse(&version_str).ok()?;
832 let model_id = env::var("ZED_ZETA_MODEL").ok();
833 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
834 Some(Zeta2RawConfig {
835 model_id,
836 environment,
837 format,
838 })
839 }
840
    /// Selects which model backend the store uses.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
844
    /// Replaces the raw Zeta2 configuration, overriding any value read from the
    /// environment at construction time.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
848
    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
852
    /// Returns the explicitly preferred experiment name, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
856
    /// Sets or clears (`None`) the preferred experiment name.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
860
    /// Returns the experiment names last stored by
    /// `refresh_available_experiments` (empty until a refresh succeeds).
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
864
865 pub fn active_experiment(&self) -> Option<&str> {
866 self.preferred_experiment.as_deref().or_else(|| {
867 self.shown_predictions
868 .iter()
869 .find_map(|p| p.model_version.as_ref())
870 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
871 })
872 }
873
874 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
875 let client = self.client.clone();
876 let llm_token = self.llm_token.clone();
877 let app_version = AppVersion::global(cx);
878 let organization_id = self
879 .user_store
880 .read(cx)
881 .current_organization()
882 .map(|organization| organization.id.clone());
883
884 cx.spawn(async move |this, cx| {
885 let experiments = cx
886 .background_spawn(async move {
887 let http_client = client.http_client();
888 let token = client
889 .acquire_llm_token(&llm_token, organization_id.clone())
890 .await?;
891 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
892 let request = http_client::Request::builder()
893 .method(Method::GET)
894 .uri(url.as_ref())
895 .header("Authorization", format!("Bearer {}", token))
896 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
897 .body(Default::default())?;
898 let mut response = http_client.send(request).await?;
899 if response.status().is_success() {
900 let mut body = Vec::new();
901 response.body_mut().read_to_end(&mut body).await?;
902 let experiments: Vec<String> = serde_json::from_slice(&body)?;
903 Ok(experiments)
904 } else {
905 let mut body = String::new();
906 response.body_mut().read_to_string(&mut body).await?;
907 anyhow::bail!(
908 "Failed to fetch experiments: {:?}\nBody: {}",
909 response.status(),
910 body
911 );
912 }
913 })
914 .await?;
915 this.update(cx, |this, cx| {
916 this.available_experiments = experiments;
917 cx.notify();
918 })?;
919 anyhow::Ok(())
920 })
921 .detach_and_log_err(cx);
922 }
923
924 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
925 use ui::IconName;
926 match self.edit_prediction_model {
927 EditPredictionModel::Mercury => {
928 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
929 }
930 EditPredictionModel::Zeta => {
931 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
932 .with_disabled(IconName::ZedPredictDisabled)
933 .with_up(IconName::ZedPredictUp)
934 .with_down(IconName::ZedPredictDown)
935 .with_error(IconName::ZedPredictError)
936 }
937 EditPredictionModel::Fim { .. } => {
938 let settings = &all_language_settings(None, cx).edit_predictions;
939 match settings.provider {
940 EditPredictionProvider::Ollama => {
941 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
942 }
943 _ => {
944 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
945 }
946 }
947 }
948 }
949 }
950
    /// Reports whether the Mercury API token state currently holds a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
954
    /// Forwards to `Mercury::has_payment_required_error`.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
958
959 pub fn clear_history(&mut self) {
960 for project_state in self.projects.values_mut() {
961 project_state.events.clear();
962 project_state.last_event.take();
963 }
964 }
965
966 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
967 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
968 project_state.events.clear();
969 project_state.last_event.take();
970 }
971 }
972
973 pub fn edit_history_for_project(
974 &self,
975 project: &Entity<Project>,
976 cx: &App,
977 ) -> Vec<StoredEvent> {
978 self.projects
979 .get(&project.entity_id())
980 .map(|project_state| project_state.events(cx))
981 .unwrap_or_default()
982 }
983
984 pub fn context_for_project<'a>(
985 &'a self,
986 project: &Entity<Project>,
987 cx: &'a mut App,
988 ) -> Vec<RelatedFile> {
989 self.projects
990 .get(&project.entity_id())
991 .map(|project_state| {
992 project_state.context.update(cx, |context, cx| {
993 context
994 .related_files_with_buffers(cx)
995 .map(|(mut related_file, buffer)| {
996 related_file.in_open_source_repo = buffer
997 .read(cx)
998 .file()
999 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1000 related_file
1001 })
1002 .collect()
1003 })
1004 })
1005 .unwrap_or_default()
1006 }
1007
1008 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1009 self.projects
1010 .get(&project.entity_id())
1011 .and_then(|project| project.copilot.clone())
1012 }
1013
1014 pub fn start_copilot_for_project(
1015 &mut self,
1016 project: &Entity<Project>,
1017 cx: &mut Context<Self>,
1018 ) -> Option<Entity<Copilot>> {
1019 if DisableAiSettings::get(None, cx).disable_ai {
1020 return None;
1021 }
1022 let state = self.get_or_init_project(project, cx);
1023
1024 if state.copilot.is_some() {
1025 return state.copilot.clone();
1026 }
1027 let _project = project.clone();
1028 let project = project.read(cx);
1029
1030 let node = project.node_runtime().cloned();
1031 if let Some(node) = node {
1032 let next_id = project.languages().next_language_server_id();
1033 let fs = project.fs().clone();
1034
1035 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1036 state.copilot = Some(copilot.clone());
1037 Some(copilot)
1038 } else {
1039 None
1040 }
1041 }
1042
1043 pub fn context_for_project_with_buffers<'a>(
1044 &'a self,
1045 project: &Entity<Project>,
1046 cx: &'a mut App,
1047 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1048 self.projects
1049 .get(&project.entity_id())
1050 .map(|project| {
1051 project.context.update(cx, |context, cx| {
1052 context.related_files_with_buffers(cx).collect()
1053 })
1054 })
1055 .unwrap_or_default()
1056 }
1057
1058 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1059 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1060 self.user_store.read(cx).edit_prediction_usage()
1061 } else {
1062 None
1063 }
1064 }
1065
1066 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1067 self.get_or_init_project(project, cx);
1068 }
1069
1070 pub fn register_buffer(
1071 &mut self,
1072 buffer: &Entity<Buffer>,
1073 project: &Entity<Project>,
1074 cx: &mut Context<Self>,
1075 ) {
1076 let project_state = self.get_or_init_project(project, cx);
1077 Self::register_buffer_impl(project_state, buffer, project, cx);
1078 }
1079
1080 fn get_or_init_project(
1081 &mut self,
1082 project: &Entity<Project>,
1083 cx: &mut Context<Self>,
1084 ) -> &mut ProjectState {
1085 let entity_id = project.entity_id();
1086 self.projects
1087 .entry(entity_id)
1088 .or_insert_with(|| ProjectState {
1089 context: {
1090 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1091 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1092 this.handle_excerpt_store_event(entity_id, event);
1093 })
1094 .detach();
1095 related_excerpt_store
1096 },
1097 events: VecDeque::new(),
1098 last_event: None,
1099 recent_paths: VecDeque::new(),
1100 debug_tx: None,
1101 registered_buffers: HashMap::default(),
1102 current_prediction: None,
1103 cancelled_predictions: HashSet::default(),
1104 pending_predictions: ArrayVec::new(),
1105 next_pending_prediction_id: 0,
1106 last_edit_prediction_refresh: None,
1107 last_jump_prediction_refresh: None,
1108 license_detection_watchers: HashMap::default(),
1109 _subscriptions: [
1110 cx.subscribe(&project, Self::handle_project_event),
1111 cx.observe_release(&project, move |this, _, cx| {
1112 this.projects.remove(&entity_id);
1113 cx.notify();
1114 }),
1115 ],
1116 copilot: None,
1117 })
1118 }
1119
1120 pub fn remove_project(&mut self, project: &Entity<Project>) {
1121 self.projects.remove(&project.entity_id());
1122 }
1123
1124 fn handle_excerpt_store_event(
1125 &mut self,
1126 project_entity_id: EntityId,
1127 event: &RelatedExcerptStoreEvent,
1128 ) {
1129 if let Some(project_state) = self.projects.get(&project_entity_id) {
1130 if let Some(debug_tx) = project_state.debug_tx.clone() {
1131 match event {
1132 RelatedExcerptStoreEvent::StartedRefresh => {
1133 debug_tx
1134 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1135 ContextRetrievalStartedDebugEvent {
1136 project_entity_id: project_entity_id,
1137 timestamp: Instant::now(),
1138 search_prompt: String::new(),
1139 },
1140 ))
1141 .ok();
1142 }
1143 RelatedExcerptStoreEvent::FinishedRefresh {
1144 cache_hit_count,
1145 cache_miss_count,
1146 mean_definition_latency,
1147 max_definition_latency,
1148 } => {
1149 debug_tx
1150 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1151 ContextRetrievalFinishedDebugEvent {
1152 project_entity_id: project_entity_id,
1153 timestamp: Instant::now(),
1154 metadata: vec![
1155 (
1156 "Cache Hits",
1157 format!(
1158 "{}/{}",
1159 cache_hit_count,
1160 cache_hit_count + cache_miss_count
1161 )
1162 .into(),
1163 ),
1164 (
1165 "Max LSP Time",
1166 format!("{} ms", max_definition_latency.as_millis())
1167 .into(),
1168 ),
1169 (
1170 "Mean LSP Time",
1171 format!("{} ms", mean_definition_latency.as_millis())
1172 .into(),
1173 ),
1174 ],
1175 },
1176 ))
1177 .ok();
1178 }
1179 }
1180 }
1181 }
1182 }
1183
1184 pub fn debug_info(
1185 &mut self,
1186 project: &Entity<Project>,
1187 cx: &mut Context<Self>,
1188 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1189 let project_state = self.get_or_init_project(project, cx);
1190 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1191 project_state.debug_tx = Some(debug_watch_tx);
1192 debug_watch_rx
1193 }
1194
1195 fn handle_project_event(
1196 &mut self,
1197 project: Entity<Project>,
1198 event: &project::Event,
1199 cx: &mut Context<Self>,
1200 ) {
1201 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1202 return;
1203 }
1204 // TODO [zeta2] init with recent paths
1205 match event {
1206 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1207 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1208 return;
1209 };
1210 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1211 if let Some(path) = path {
1212 if let Some(ix) = project_state
1213 .recent_paths
1214 .iter()
1215 .position(|probe| probe == &path)
1216 {
1217 project_state.recent_paths.remove(ix);
1218 }
1219 project_state.recent_paths.push_front(path);
1220 }
1221 }
1222 project::Event::DiagnosticsUpdated { .. } => {
1223 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1224 self.refresh_prediction_from_diagnostics(
1225 project,
1226 DiagnosticSearchScope::Global,
1227 cx,
1228 );
1229 }
1230 }
1231 _ => (),
1232 }
1233 }
1234
1235 fn register_buffer_impl<'a>(
1236 project_state: &'a mut ProjectState,
1237 buffer: &Entity<Buffer>,
1238 project: &Entity<Project>,
1239 cx: &mut Context<Self>,
1240 ) -> &'a mut RegisteredBuffer {
1241 let buffer_id = buffer.entity_id();
1242
1243 if let Some(file) = buffer.read(cx).file() {
1244 let worktree_id = file.worktree_id(cx);
1245 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1246 project_state
1247 .license_detection_watchers
1248 .entry(worktree_id)
1249 .or_insert_with(|| {
1250 let project_entity_id = project.entity_id();
1251 cx.observe_release(&worktree, move |this, _worktree, _cx| {
1252 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1253 else {
1254 return;
1255 };
1256 project_state
1257 .license_detection_watchers
1258 .remove(&worktree_id);
1259 })
1260 .detach();
1261 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1262 });
1263 }
1264 }
1265
1266 match project_state.registered_buffers.entry(buffer_id) {
1267 hash_map::Entry::Occupied(entry) => entry.into_mut(),
1268 hash_map::Entry::Vacant(entry) => {
1269 let buf = buffer.read(cx);
1270 let snapshot = buf.text_snapshot();
1271 let file = buf.file().cloned();
1272 let project_entity_id = project.entity_id();
1273 entry.insert(RegisteredBuffer {
1274 snapshot,
1275 file,
1276 last_position: None,
1277 pending_predictions: Vec::new(),
1278 _subscriptions: [
1279 cx.subscribe(buffer, {
1280 let project = project.downgrade();
1281 move |this, buffer, event, cx| {
1282 if let language::BufferEvent::Edited { is_local } = event
1283 && let Some(project) = project.upgrade()
1284 {
1285 this.report_changes_for_buffer(
1286 &buffer, &project, false, *is_local, cx,
1287 );
1288 }
1289 }
1290 }),
1291 cx.observe_release(buffer, move |this, _buffer, _cx| {
1292 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1293 else {
1294 return;
1295 };
1296 project_state.registered_buffers.remove(&buffer_id);
1297 }),
1298 ],
1299 })
1300 }
1301 }
1302 }
1303
    /// Records the buffer's changes since its last registered snapshot into the
    /// project's edit history.
    ///
    /// The change is either merged into the in-progress `last_event` or the
    /// in-progress event is finalized into `events` and a new one is started.
    ///
    /// * `is_predicted` is stored as `predicted` on the event; a change in this
    ///   flag prevents coalescing with the previous event.
    /// * `is_local` edits are always considered for history; non-local edits
    ///   only when `collaborator_edit_overlaps_locality_region` returns `true`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Same version as the stored snapshot: nothing new to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // Advance the registered buffer to the new state, keeping the old
        // file/snapshot as the "before" side of the event.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Fold every edit since the old version into one range running from
        // the first edit's start to the last edit's end.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Refresh `last_edit_at` on pending settled predictions whose editable
        // range overlaps this edit.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // No diff could be computed for this edit: finalize the in-progress
        // event (if any) and stop without starting a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Evict the oldest event before pushing once the queue is
                    // about to reach EVENT_COUNT_MAX.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The edit continues the in-progress event only if it starts from
            // exactly the snapshot that event ended on.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // When at least LAST_CHANGE_GROUPING_TIME passed since the
                // previous edit, remember the state before this edit as the
                // latest pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the in-progress event (if any) into the
        // history queue before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1436
1437 fn prediction_at(
1438 &mut self,
1439 buffer: &Entity<Buffer>,
1440 position: Option<language::Anchor>,
1441 project: &Entity<Project>,
1442 cx: &App,
1443 ) -> Option<BufferEditPrediction<'_>> {
1444 let project_state = self.projects.get_mut(&project.entity_id())?;
1445 if let Some(position) = position
1446 && let Some(buffer) = project_state
1447 .registered_buffers
1448 .get_mut(&buffer.entity_id())
1449 {
1450 buffer.last_position = Some(position);
1451 }
1452
1453 let CurrentEditPrediction {
1454 requested_by,
1455 prediction,
1456 ..
1457 } = project_state.current_prediction.as_ref()?;
1458
1459 if prediction.targets_buffer(buffer.read(cx)) {
1460 Some(BufferEditPrediction::Local { prediction })
1461 } else {
1462 let show_jump = match requested_by {
1463 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1464 requested_by_buffer_id == &buffer.entity_id()
1465 }
1466 PredictionRequestedBy::DiagnosticsUpdate => true,
1467 };
1468
1469 if show_jump {
1470 Some(BufferEditPrediction::Jump { prediction })
1471 } else {
1472 None
1473 }
1474 }
1475 }
1476
1477 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1478 let Some(current_prediction) = self
1479 .projects
1480 .get_mut(&project.entity_id())
1481 .and_then(|project_state| project_state.current_prediction.take())
1482 else {
1483 return;
1484 };
1485
1486 self.report_changes_for_buffer(
1487 ¤t_prediction.prediction.buffer,
1488 project,
1489 true,
1490 true,
1491 cx,
1492 );
1493
1494 // can't hold &mut project_state ref across report_changes_for_buffer_call
1495 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1496 return;
1497 };
1498
1499 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1500 project_state.cancel_pending_prediction(pending_prediction, cx);
1501 }
1502
1503 match self.edit_prediction_model {
1504 EditPredictionModel::Mercury => {
1505 mercury::edit_prediction_accepted(
1506 current_prediction.prediction.id,
1507 self.client.http_client(),
1508 cx,
1509 );
1510 }
1511 EditPredictionModel::Zeta => {
1512 let is_cloud = !matches!(
1513 all_language_settings(None, cx).edit_predictions.provider,
1514 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1515 );
1516 if is_cloud {
1517 zeta::edit_prediction_accepted(self, current_prediction, cx)
1518 }
1519 }
1520 EditPredictionModel::Fim { .. } => {}
1521 }
1522 }
1523
    /// Drains rejection payloads from `rx`, batches them, and reports each
    /// batch to the `/predict_edits/reject` endpoint.
    ///
    /// Runs until the channel is closed. Rejections from a failed request stay
    /// in the batch and are sent again with a later flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        // `organization_id` is rebound on every iteration, so a flush uses the
        // organization of the most recently received payload.
        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // While the batch is under half the per-request maximum, wait for
            // either another queued rejection or the debounce timer. If another
            // rejection is available, keep accumulating instead of flushing.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Send at most the per-request maximum, taking the newest entries
            // from the tail of the batch.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the sent entries once the request succeeded; errors
            // are logged and the entries are kept.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1587
1588 async fn run_settled_predictions_worker(
1589 this: WeakEntity<Self>,
1590 mut rx: UnboundedReceiver<Instant>,
1591 cx: &mut AsyncApp,
1592 ) {
1593 let mut next_wake_time: Option<Instant> = None;
1594 loop {
1595 let now = cx.background_executor().now();
1596 if let Some(wake_time) = next_wake_time.take() {
1597 cx.background_executor()
1598 .timer(wake_time.duration_since(now))
1599 .await;
1600 } else {
1601 let Some(new_enqueue_time) = rx.next().await else {
1602 break;
1603 };
1604 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1605 while rx.next().now_or_never().flatten().is_some() {}
1606 continue;
1607 }
1608
1609 let Some(this) = this.upgrade() else {
1610 break;
1611 };
1612
1613 let now = cx.background_executor().now();
1614 let mut oldest_edited_at = None;
1615 let mut ready_predictions = Vec::new();
1616
1617 this.update(cx, |this, _| {
1618 for (_, project_state) in this.projects.iter_mut() {
1619 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1620 let mut pending_index = 0;
1621 while pending_index < registered_buffer.pending_predictions.len() {
1622 let pending_prediction =
1623 ®istered_buffer.pending_predictions[pending_index];
1624 let age = now.saturating_duration_since(pending_prediction.enqueued_at);
1625 if age >= EDIT_PREDICTION_SETTLED_TTL {
1626 registered_buffer.pending_predictions.remove(pending_index);
1627 continue;
1628 }
1629
1630 let quiet_for =
1631 now.saturating_duration_since(pending_prediction.last_edit_at);
1632 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1633 let pending_prediction =
1634 registered_buffer.pending_predictions.remove(pending_index);
1635 let settled_editable_region = registered_buffer
1636 .snapshot
1637 .text_for_range(
1638 pending_prediction.editable_anchor_range.clone(),
1639 )
1640 .collect::<String>();
1641 ready_predictions
1642 .push((pending_prediction, settled_editable_region));
1643 continue;
1644 }
1645
1646 if oldest_edited_at
1647 .is_none_or(|time| pending_prediction.last_edit_at < time)
1648 {
1649 oldest_edited_at = Some(pending_prediction.last_edit_at);
1650 }
1651 pending_index += 1;
1652 }
1653 }
1654 }
1655 });
1656
1657 for (pending_prediction, settled_editable_region) in ready_predictions {
1658 let PendingSettledPrediction {
1659 request_id,
1660 editable_region_before_prediction,
1661 predicted_editable_region,
1662 ts_error_count_before_prediction,
1663 ts_error_count_after_prediction,
1664 example,
1665 e2e_latency,
1666 ..
1667 } = pending_prediction;
1668 let settled_editable_region_for_metrics = settled_editable_region.clone();
1669 let kept_rate_result = cx
1670 .background_spawn(async move {
1671 compute_kept_rate(
1672 &editable_region_before_prediction,
1673 &predicted_editable_region,
1674 &settled_editable_region_for_metrics,
1675 )
1676 })
1677 .await;
1678
1679 #[cfg(test)]
1680 {
1681 let request_id = request_id.clone();
1682 let settled_editable_region = settled_editable_region.clone();
1683 this.update(cx, |this, _| {
1684 if let Some(callback) = &this.settled_event_callback {
1685 callback(request_id, settled_editable_region);
1686 }
1687 });
1688 }
1689
1690 telemetry::event!(
1691 EDIT_PREDICTION_SETTLED_EVENT,
1692 request_id = request_id.0.clone(),
1693 settled_editable_region,
1694 ts_error_count_before_prediction,
1695 ts_error_count_after_prediction,
1696 edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
1697 edit_bytes_reference_new = kept_rate_result.reference_new_chars,
1698 edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
1699 edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
1700 edit_bytes_kept = kept_rate_result.kept_chars,
1701 edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
1702 edit_bytes_discarded = kept_rate_result.discarded_chars,
1703 edit_bytes_context = kept_rate_result.context_chars,
1704 edit_bytes_kept_rate = kept_rate_result.kept_rate,
1705 edit_bytes_recall_rate = kept_rate_result.recall_rate,
1706 example,
1707 e2e_latency = e2e_latency.as_millis(),
1708 );
1709 }
1710
1711 next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1712 }
1713 }
1714
1715 pub(crate) fn enqueue_settled_prediction(
1716 &mut self,
1717 request_id: EditPredictionId,
1718 project: &Entity<Project>,
1719 edited_buffer: &Entity<Buffer>,
1720 edited_buffer_snapshot: &BufferSnapshot,
1721 editable_offset_range: Range<usize>,
1722 edit_preview: &EditPreview,
1723 example: Option<ExampleSpec>,
1724 e2e_latency: std::time::Duration,
1725 cx: &mut Context<Self>,
1726 ) {
1727 let this = &mut *self;
1728 let project_state = this.get_or_init_project(project, cx);
1729 let Some(registered_buffer) = project_state
1730 .registered_buffers
1731 .get_mut(&edited_buffer.entity_id())
1732 else {
1733 return;
1734 };
1735
1736 let editable_region_before_prediction = edited_buffer_snapshot
1737 .text_for_range(editable_offset_range.clone())
1738 .collect::<String>();
1739 let editable_anchor_range_for_result =
1740 edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1741 let predicted_editable_region = edit_preview
1742 .result_text_snapshot()
1743 .text_for_range(editable_anchor_range_for_result.clone())
1744 .collect();
1745 let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1746 edited_buffer_snapshot
1747 .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1748 );
1749 let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1750 edit_preview.result_syntax_snapshot().layers_for_range(
1751 editable_anchor_range_for_result,
1752 edit_preview.result_text_snapshot(),
1753 true,
1754 ),
1755 );
1756 let editable_anchor_range =
1757 edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1758 let now = cx.background_executor().now();
1759 registered_buffer
1760 .pending_predictions
1761 .push(PendingSettledPrediction {
1762 request_id,
1763 editable_anchor_range,
1764 editable_region_before_prediction,
1765 predicted_editable_region,
1766 ts_error_count_before_prediction,
1767 ts_error_count_after_prediction,
1768 example,
1769 e2e_latency,
1770 enqueued_at: now,
1771 last_edit_at: now,
1772 });
1773 this.settled_predictions_tx.unbounded_send(now).ok();
1774 }
1775
1776 fn reject_current_prediction(
1777 &mut self,
1778 reason: EditPredictionRejectReason,
1779 project: &Entity<Project>,
1780 cx: &App,
1781 ) {
1782 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1783 project_state.pending_predictions.clear();
1784 if let Some(prediction) = project_state.current_prediction.take() {
1785 let model_version = prediction.prediction.model_version.clone();
1786 self.reject_prediction(
1787 prediction.prediction.id,
1788 reason,
1789 prediction.was_shown,
1790 model_version,
1791 Some(prediction.e2e_latency),
1792 cx,
1793 );
1794 }
1795 };
1796 }
1797
1798 fn did_show_current_prediction(
1799 &mut self,
1800 project: &Entity<Project>,
1801 display_type: edit_prediction_types::SuggestionDisplayType,
1802 _cx: &mut Context<Self>,
1803 ) {
1804 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1805 return;
1806 };
1807
1808 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1809 return;
1810 };
1811
1812 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1813 let previous_shown_with = current_prediction.shown_with;
1814
1815 if previous_shown_with.is_none() || !is_jump {
1816 current_prediction.shown_with = Some(display_type);
1817 }
1818
1819 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1820
1821 if is_first_non_jump_show {
1822 current_prediction.was_shown = true;
1823 }
1824
1825 if is_first_non_jump_show {
1826 self.shown_predictions
1827 .push_front(current_prediction.prediction.clone());
1828 if self.shown_predictions.len() > 50 {
1829 let completion = self.shown_predictions.pop_back().unwrap();
1830 self.rated_predictions.remove(&completion.id);
1831 }
1832 }
1833 }
1834
1835 fn reject_prediction(
1836 &mut self,
1837 prediction_id: EditPredictionId,
1838 reason: EditPredictionRejectReason,
1839 was_shown: bool,
1840 model_version: Option<String>,
1841 e2e_latency: Option<std::time::Duration>,
1842 cx: &App,
1843 ) {
1844 match self.edit_prediction_model {
1845 EditPredictionModel::Zeta => {
1846 let is_cloud = !matches!(
1847 all_language_settings(None, cx).edit_predictions.provider,
1848 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1849 );
1850
1851 if is_cloud {
1852 let organization_id = self
1853 .user_store
1854 .read(cx)
1855 .current_organization()
1856 .map(|organization| organization.id.clone());
1857
1858 self.reject_predictions_tx
1859 .unbounded_send(EditPredictionRejectionPayload {
1860 rejection: EditPredictionRejection {
1861 request_id: prediction_id.to_string(),
1862 reason,
1863 was_shown,
1864 model_version,
1865 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1866 },
1867 organization_id,
1868 })
1869 .log_err();
1870 }
1871 }
1872 EditPredictionModel::Mercury => {
1873 mercury::edit_prediction_rejected(
1874 prediction_id,
1875 was_shown,
1876 reason,
1877 self.client.http_client(),
1878 cx,
1879 );
1880 }
1881 EditPredictionModel::Fim { .. } => {}
1882 }
1883 }
1884
1885 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1886 self.projects
1887 .get(&project.entity_id())
1888 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1889 }
1890
1891 pub fn refresh_prediction_from_buffer(
1892 &mut self,
1893 project: Entity<Project>,
1894 buffer: Entity<Buffer>,
1895 position: language::Anchor,
1896 cx: &mut Context<Self>,
1897 ) {
1898 self.queue_prediction_refresh(
1899 project.clone(),
1900 PredictEditsRequestTrigger::Other,
1901 buffer.entity_id(),
1902 cx,
1903 move |this, cx| {
1904 let Some(request_task) = this
1905 .update(cx, |this, cx| {
1906 this.request_prediction(
1907 &project,
1908 &buffer,
1909 position,
1910 PredictEditsRequestTrigger::Other,
1911 cx,
1912 )
1913 })
1914 .log_err()
1915 else {
1916 return Task::ready(anyhow::Ok(None));
1917 };
1918
1919 cx.spawn(async move |_cx| {
1920 request_task.await.map(|prediction_result| {
1921 prediction_result.map(|prediction_result| {
1922 (
1923 prediction_result,
1924 PredictionRequestedBy::Buffer(buffer.entity_id()),
1925 )
1926 })
1927 })
1928 })
1929 },
1930 )
1931 }
1932
    /// Queues a diagnostics-driven prediction refresh for `project`.
    ///
    /// Skipped when the configured provider is not store-backed, when a
    /// workspace for this project is following a leader, when the project is
    /// not tracked, or when a current prediction already exists. Otherwise the
    /// refresh looks up the next diagnostic location relative to the active
    /// buffer (restricted to nearby lines for `DiagnosticSearchScope::Local`)
    /// and requests a prediction there with the `Diagnostics` trigger.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        // The project's own entity id is used as the throttle entity here.
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor; bail out with "no
                // prediction" if either is unavailable or predictions are
                // disabled at that position.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope: DIAGNOSTIC_LINES_RANGE rows on either side
                    // of the cursor. Global scope passes a default range.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // If a current prediction appeared while this request was
                    // in flight, report this result as rejected with
                    // `CurrentPreferred` instead of replacing it.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2049
2050 fn predictions_enabled_at(
2051 snapshot: &BufferSnapshot,
2052 position: Option<language::Anchor>,
2053 cx: &App,
2054 ) -> bool {
2055 let file = snapshot.file();
2056 let all_settings = all_language_settings(file, cx);
2057 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2058 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2059 {
2060 return false;
2061 }
2062
2063 if let Some(last_position) = position {
2064 let settings = snapshot.settings_at(last_position, cx);
2065
2066 if !settings.edit_predictions_disabled_in.is_empty()
2067 && let Some(scope) = snapshot.language_scope_at(last_position)
2068 && let Some(scope_name) = scope.override_name()
2069 && settings
2070 .edit_predictions_disabled_in
2071 .iter()
2072 .any(|s| s == scope_name)
2073 {
2074 return false;
2075 }
2076 }
2077
2078 true
2079 }
2080
    /// Throttle interval (300 ms) that `queue_prediction_refresh` reads as its
    /// `throttle_timeout`.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2082}
2083
2084fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2085 let Some(app_state) = AppState::try_global(cx) else {
2086 return false;
2087 };
2088
2089 app_state
2090 .workspace_store
2091 .read(cx)
2092 .workspaces()
2093 .filter_map(|workspace| workspace.upgrade())
2094 .any(|workspace| {
2095 workspace.read(cx).project().entity_id() == project.entity_id()
2096 && workspace
2097 .read(cx)
2098 .leader_for_pane(workspace.read(cx).active_pane())
2099 .is_some()
2100 })
2101}
2102
2103fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2104 match provider {
2105 EditPredictionProvider::Zed
2106 | EditPredictionProvider::Mercury
2107 | EditPredictionProvider::Ollama
2108 | EditPredictionProvider::OpenAiCompatibleApi
2109 | EditPredictionProvider::Experimental(_) => true,
2110 EditPredictionProvider::None
2111 | EditPredictionProvider::Copilot
2112 | EditPredictionProvider::Codestral => false,
2113 }
2114}
2115
2116impl EditPredictionStore {
2117 fn queue_prediction_refresh(
2118 &mut self,
2119 project: Entity<Project>,
2120 request_trigger: PredictEditsRequestTrigger,
2121 throttle_entity: EntityId,
2122 cx: &mut Context<Self>,
2123 do_refresh: impl FnOnce(
2124 WeakEntity<Self>,
2125 &mut AsyncApp,
2126 )
2127 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2128 + 'static,
2129 ) {
2130 fn select_throttle(
2131 project_state: &mut ProjectState,
2132 request_trigger: PredictEditsRequestTrigger,
2133 ) -> &mut Option<(EntityId, Instant)> {
2134 match request_trigger {
2135 PredictEditsRequestTrigger::Diagnostics => {
2136 &mut project_state.last_jump_prediction_refresh
2137 }
2138 _ => &mut project_state.last_edit_prediction_refresh,
2139 }
2140 }
2141
2142 let (needs_acceptance_tracking, max_pending_predictions) =
2143 match all_language_settings(None, cx).edit_predictions.provider {
2144 EditPredictionProvider::Zed
2145 | EditPredictionProvider::Mercury
2146 | EditPredictionProvider::Experimental(_) => (true, 2),
2147 EditPredictionProvider::Ollama => (false, 1),
2148 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2149 EditPredictionProvider::None
2150 | EditPredictionProvider::Copilot
2151 | EditPredictionProvider::Codestral => {
2152 log::error!("queue_prediction_refresh called with non-store provider");
2153 return;
2154 }
2155 };
2156
2157 let drop_on_cancel = !needs_acceptance_tracking;
2158 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2159 let project_state = self.get_or_init_project(&project, cx);
2160 let pending_prediction_id = project_state.next_pending_prediction_id;
2161 project_state.next_pending_prediction_id += 1;
2162 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2163
2164 let task = cx.spawn(async move |this, cx| {
2165 let throttle_wait = this
2166 .update(cx, |this, cx| {
2167 let project_state = this.get_or_init_project(&project, cx);
2168 let throttle = *select_throttle(project_state, request_trigger);
2169
2170 let now = cx.background_executor().now();
2171 throttle.and_then(|(last_entity, last_timestamp)| {
2172 if throttle_entity != last_entity {
2173 return None;
2174 }
2175 (last_timestamp + throttle_timeout).checked_duration_since(now)
2176 })
2177 })
2178 .ok()
2179 .flatten();
2180
2181 if let Some(timeout) = throttle_wait {
2182 cx.background_executor().timer(timeout).await;
2183 }
2184
2185 // If this task was cancelled before the throttle timeout expired,
2186 // do not perform a request. Also skip if another task already
2187 // proceeded since we were enqueued (duplicate).
2188 let mut is_cancelled = true;
2189 this.update(cx, |this, cx| {
2190 let project_state = this.get_or_init_project(&project, cx);
2191 let was_cancelled = project_state
2192 .cancelled_predictions
2193 .remove(&pending_prediction_id);
2194 if was_cancelled {
2195 return;
2196 }
2197
2198 // Another request has been already sent since this was enqueued
2199 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2200 return;
2201 }
2202
2203 let new_refresh = (throttle_entity, cx.background_executor().now());
2204 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2205 is_cancelled = false;
2206 })
2207 .ok();
2208 if is_cancelled {
2209 return None;
2210 }
2211
2212 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2213 let new_prediction_id = new_prediction_result
2214 .as_ref()
2215 .map(|(prediction, _)| prediction.id.clone());
2216
2217 // When a prediction completes, remove it from the pending list, and cancel
2218 // any pending predictions that were enqueued before it.
2219 this.update(cx, |this, cx| {
2220 let project_state = this.get_or_init_project(&project, cx);
2221
2222 let is_cancelled = project_state
2223 .cancelled_predictions
2224 .remove(&pending_prediction_id);
2225
2226 let new_current_prediction = if !is_cancelled
2227 && let Some((prediction_result, requested_by)) = new_prediction_result
2228 {
2229 match prediction_result.prediction {
2230 Ok(prediction) => {
2231 let new_prediction = CurrentEditPrediction {
2232 requested_by,
2233 prediction,
2234 was_shown: false,
2235 shown_with: None,
2236 e2e_latency: prediction_result.e2e_latency,
2237 };
2238
2239 if let Some(current_prediction) =
2240 project_state.current_prediction.as_ref()
2241 {
2242 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2243 {
2244 this.reject_current_prediction(
2245 EditPredictionRejectReason::Replaced,
2246 &project,
2247 cx,
2248 );
2249
2250 Some(new_prediction)
2251 } else {
2252 this.reject_prediction(
2253 new_prediction.prediction.id,
2254 EditPredictionRejectReason::CurrentPreferred,
2255 false,
2256 new_prediction.prediction.model_version,
2257 Some(new_prediction.e2e_latency),
2258 cx,
2259 );
2260 None
2261 }
2262 } else {
2263 Some(new_prediction)
2264 }
2265 }
2266 Err(reject_reason) => {
2267 this.reject_prediction(
2268 prediction_result.id,
2269 reject_reason,
2270 false,
2271 None,
2272 Some(prediction_result.e2e_latency),
2273 cx,
2274 );
2275 None
2276 }
2277 }
2278 } else {
2279 None
2280 };
2281
2282 let project_state = this.get_or_init_project(&project, cx);
2283
2284 if let Some(new_prediction) = new_current_prediction {
2285 project_state.current_prediction = Some(new_prediction);
2286 }
2287
2288 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2289 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2290 if pending_prediction.id == pending_prediction_id {
2291 pending_predictions.remove(ix);
2292 for pending_prediction in pending_predictions.drain(0..ix) {
2293 project_state.cancel_pending_prediction(pending_prediction, cx)
2294 }
2295 break;
2296 }
2297 }
2298 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2299 cx.notify();
2300 })
2301 .ok();
2302
2303 new_prediction_id
2304 });
2305
2306 if project_state.pending_predictions.len() < max_pending_predictions {
2307 project_state
2308 .pending_predictions
2309 .push(PendingPrediction {
2310 id: pending_prediction_id,
2311 task,
2312 drop_on_cancel,
2313 })
2314 .unwrap();
2315 } else {
2316 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2317 project_state
2318 .pending_predictions
2319 .push(PendingPrediction {
2320 id: pending_prediction_id,
2321 task,
2322 drop_on_cancel,
2323 })
2324 .unwrap();
2325 project_state.cancel_pending_prediction(pending_prediction, cx);
2326 }
2327 }
2328
2329 pub fn request_prediction(
2330 &mut self,
2331 project: &Entity<Project>,
2332 active_buffer: &Entity<Buffer>,
2333 position: language::Anchor,
2334 trigger: PredictEditsRequestTrigger,
2335 cx: &mut Context<Self>,
2336 ) -> Task<Result<Option<EditPredictionResult>>> {
2337 self.request_prediction_internal(
2338 project.clone(),
2339 active_buffer.clone(),
2340 position,
2341 trigger,
2342 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2343 cx,
2344 )
2345 }
2346
2347 fn request_prediction_internal(
2348 &mut self,
2349 project: Entity<Project>,
2350 active_buffer: Entity<Buffer>,
2351 position: language::Anchor,
2352 trigger: PredictEditsRequestTrigger,
2353 allow_jump: bool,
2354 cx: &mut Context<Self>,
2355 ) -> Task<Result<Option<EditPredictionResult>>> {
2356 self.get_or_init_project(&project, cx);
2357 let project_state = self.projects.get(&project.entity_id()).unwrap();
2358 let stored_events = project_state.events(cx);
2359 let has_events = !stored_events.is_empty();
2360 let events: Vec<Arc<zeta_prompt::Event>> =
2361 stored_events.iter().map(|e| e.event.clone()).collect();
2362 let debug_tx = project_state.debug_tx.clone();
2363
2364 let snapshot = active_buffer.read(cx).snapshot();
2365 let cursor_point = position.to_point(&snapshot);
2366 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2367 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2368 let diagnostic_search_range =
2369 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2370
2371 let related_files = self.context_for_project(&project, cx);
2372 let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
2373 EditPredictionsMode::Eager => PredictEditsMode::Eager,
2374 EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
2375 };
2376
2377 let is_open_source = snapshot
2378 .file()
2379 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2380 && events.iter().all(|event| event.in_open_source_repo())
2381 && related_files.iter().all(|file| file.in_open_source_repo);
2382
2383 let can_collect_data = !cfg!(test)
2384 && is_open_source
2385 && self.is_data_collection_enabled(cx)
2386 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2387 let inputs = EditPredictionModelInput {
2388 project: project.clone(),
2389 buffer: active_buffer,
2390 snapshot,
2391 position,
2392 events,
2393 related_files,
2394 mode,
2395 trigger,
2396 diagnostic_search_range,
2397 debug_tx,
2398 can_collect_data,
2399 is_open_source,
2400 };
2401
2402 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2403
2404 let task = match self.edit_prediction_model {
2405 EditPredictionModel::Zeta => {
2406 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2407 }
2408 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2409 EditPredictionModel::Mercury => {
2410 self.mercury
2411 .request_prediction(inputs, self.credentials_provider.clone(), cx)
2412 }
2413 };
2414
2415 cx.spawn(async move |this, cx| {
2416 let prediction = task.await?;
2417
2418 // Only fall back to diagnostics-based prediction if we got a
2419 // the model had nothing to suggest for the buffer
2420 if prediction.is_none()
2421 && allow_jump
2422 && has_events
2423 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2424 {
2425 this.update(cx, |this, cx| {
2426 this.refresh_prediction_from_diagnostics(
2427 project,
2428 DiagnosticSearchScope::Local,
2429 cx,
2430 );
2431 })?;
2432 return anyhow::Ok(None);
2433 }
2434
2435 Ok(prediction)
2436 })
2437 }
2438
2439 pub(crate) async fn next_diagnostic_location(
2440 active_buffer: Entity<Buffer>,
2441 active_buffer_snapshot: &BufferSnapshot,
2442 active_buffer_diagnostic_search_range: Range<Point>,
2443 active_buffer_cursor_point: Point,
2444 project: &Entity<Project>,
2445 cx: &mut AsyncApp,
2446 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2447 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2448 .selections_in_range(
2449 Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
2450 false,
2451 )
2452 .flat_map(|(_, _, _, selections)| {
2453 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2454 })
2455 .collect();
2456
2457 let mut jump_location = active_buffer_snapshot
2458 .diagnostic_groups(None)
2459 .into_iter()
2460 .filter_map(|(_, group)| {
2461 let range = &group.entries[group.primary_ix]
2462 .range
2463 .to_point(&active_buffer_snapshot);
2464 if range.overlaps(&active_buffer_diagnostic_search_range) {
2465 return None;
2466 }
2467 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2468 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2469 });
2470 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2471 <= DIAGNOSTIC_LINES_RANGE;
2472 if near_collaborator && !near_local {
2473 return None;
2474 }
2475 Some(range.start)
2476 })
2477 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2478 .map(|position| {
2479 (
2480 active_buffer.clone(),
2481 active_buffer_snapshot.anchor_before(position),
2482 )
2483 });
2484
2485 if jump_location.is_none() {
2486 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2487 let file = buffer.file()?;
2488
2489 Some(ProjectPath {
2490 worktree_id: file.worktree_id(cx),
2491 path: file.path().clone(),
2492 })
2493 });
2494
2495 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2496 project
2497 .diagnostic_summaries(false, cx)
2498 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2499 .map(|(path, _, _)| {
2500 let shared_prefix = path
2501 .path
2502 .components()
2503 .zip(
2504 active_buffer_path
2505 .as_ref()
2506 .map(|p| p.path.components())
2507 .unwrap_or_default(),
2508 )
2509 .take_while(|(a, b)| a == b)
2510 .count();
2511 (path, shared_prefix)
2512 })
2513 .collect()
2514 });
2515
2516 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2517
2518 for (path, _) in candidates {
2519 let candidate_buffer = project
2520 .update(cx, |project, cx| project.open_buffer(path, cx))
2521 .await?;
2522
2523 let (has_collaborators, diagnostic_position) =
2524 candidate_buffer.read_with(cx, |buffer, _cx| {
2525 let snapshot = buffer.snapshot();
2526 let has_collaborators = snapshot
2527 .selections_in_range(
2528 Anchor::min_max_range_for_buffer(snapshot.remote_id()),
2529 false,
2530 )
2531 .next()
2532 .is_some();
2533 let position = buffer
2534 .buffer_diagnostics(None)
2535 .into_iter()
2536 .min_by_key(|entry| entry.diagnostic.severity)
2537 .map(|entry| entry.range.start);
2538 (has_collaborators, position)
2539 });
2540
2541 if has_collaborators {
2542 continue;
2543 }
2544
2545 if let Some(position) = diagnostic_position {
2546 jump_location = Some((candidate_buffer, position));
2547 break;
2548 }
2549 }
2550 }
2551
2552 anyhow::Ok(jump_location)
2553 }
2554
2555 async fn send_raw_llm_request(
2556 request: RawCompletionRequest,
2557 client: Arc<Client>,
2558 custom_url: Option<Arc<Url>>,
2559 llm_token: LlmApiToken,
2560 organization_id: Option<OrganizationId>,
2561 app_version: Version,
2562 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2563 let url = if let Some(custom_url) = custom_url {
2564 custom_url.as_ref().clone()
2565 } else {
2566 client
2567 .http_client()
2568 .build_zed_llm_url("/predict_edits/raw", &[])?
2569 };
2570
2571 Self::send_api_request(
2572 |builder| {
2573 let req = builder
2574 .uri(url.as_ref())
2575 .body(serde_json::to_string(&request)?.into());
2576 Ok(req?)
2577 },
2578 client,
2579 llm_token,
2580 organization_id,
2581 app_version,
2582 true,
2583 )
2584 .await
2585 }
2586
2587 pub(crate) async fn send_v3_request(
2588 input: ZetaPromptInput,
2589 client: Arc<Client>,
2590 llm_token: LlmApiToken,
2591 organization_id: Option<OrganizationId>,
2592 app_version: Version,
2593 trigger: PredictEditsRequestTrigger,
2594 mode: PredictEditsMode,
2595 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2596 let url = client
2597 .http_client()
2598 .build_zed_llm_url("/predict_edits/v3", &[])?;
2599
2600 let request = PredictEditsV3Request { input, trigger };
2601
2602 let json_bytes = serde_json::to_vec(&request)?;
2603 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2604
2605 Self::send_api_request(
2606 |builder| {
2607 let req = builder
2608 .uri(url.as_ref())
2609 .header("Content-Encoding", "zstd")
2610 .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref())
2611 .body(compressed.clone().into());
2612 Ok(req?)
2613 },
2614 client,
2615 llm_token,
2616 organization_id,
2617 app_version,
2618 true,
2619 )
2620 .await
2621 }
2622
2623 async fn send_api_request<Res>(
2624 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2625 client: Arc<Client>,
2626 llm_token: LlmApiToken,
2627 organization_id: Option<OrganizationId>,
2628 app_version: Version,
2629 require_auth: bool,
2630 ) -> Result<(Res, Option<EditPredictionUsage>)>
2631 where
2632 Res: DeserializeOwned,
2633 {
2634 let http_client = client.http_client();
2635 let mut token = if require_auth {
2636 Some(
2637 client
2638 .acquire_llm_token(&llm_token, organization_id.clone())
2639 .await?,
2640 )
2641 } else {
2642 client
2643 .acquire_llm_token(&llm_token, organization_id.clone())
2644 .await
2645 .ok()
2646 };
2647 let mut did_retry = false;
2648
2649 loop {
2650 let request_builder = http_client::Request::builder().method(Method::POST);
2651
2652 let mut request_builder = request_builder
2653 .header("Content-Type", "application/json")
2654 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2655
2656 // Only add Authorization header if we have a token
2657 if let Some(ref token_value) = token {
2658 request_builder =
2659 request_builder.header("Authorization", format!("Bearer {}", token_value));
2660 }
2661
2662 let request = build(request_builder)?;
2663
2664 let mut response = http_client.send(request).await?;
2665
2666 if let Some(minimum_required_version) = response
2667 .headers()
2668 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2669 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2670 {
2671 anyhow::ensure!(
2672 app_version >= minimum_required_version,
2673 ZedUpdateRequiredError {
2674 minimum_version: minimum_required_version
2675 }
2676 );
2677 }
2678
2679 if response.status().is_success() {
2680 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2681
2682 let mut body = Vec::new();
2683 response.body_mut().read_to_end(&mut body).await?;
2684 return Ok((serde_json::from_slice(&body)?, usage));
2685 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2686 did_retry = true;
2687 token = Some(
2688 client
2689 .refresh_llm_token(&llm_token, organization_id.clone())
2690 .await?,
2691 );
2692 } else {
2693 let mut body = String::new();
2694 response.body_mut().read_to_string(&mut body).await?;
2695 anyhow::bail!(
2696 "Request failed with status: {:?}\nBody: {}",
2697 response.status(),
2698 body
2699 );
2700 }
2701 }
2702 }
2703
2704 pub fn refresh_context(
2705 &mut self,
2706 project: &Entity<Project>,
2707 buffer: &Entity<language::Buffer>,
2708 cursor_position: language::Anchor,
2709 cx: &mut Context<Self>,
2710 ) {
2711 self.get_or_init_project(project, cx)
2712 .context
2713 .update(cx, |store, cx| {
2714 store.refresh(buffer.clone(), cursor_position, cx);
2715 });
2716 }
2717
2718 #[cfg(feature = "cli-support")]
2719 pub fn set_context_for_buffer(
2720 &mut self,
2721 project: &Entity<Project>,
2722 related_files: Vec<RelatedFile>,
2723 cx: &mut Context<Self>,
2724 ) {
2725 self.get_or_init_project(project, cx)
2726 .context
2727 .update(cx, |store, cx| {
2728 store.set_related_files(related_files, cx);
2729 });
2730 }
2731
2732 #[cfg(feature = "cli-support")]
2733 pub fn set_recent_paths_for_project(
2734 &mut self,
2735 project: &Entity<Project>,
2736 paths: impl IntoIterator<Item = project::ProjectPath>,
2737 cx: &mut Context<Self>,
2738 ) {
2739 let project_state = self.get_or_init_project(project, cx);
2740 project_state.recent_paths = paths.into_iter().collect();
2741 }
2742
2743 fn is_file_open_source(
2744 &self,
2745 project: &Entity<Project>,
2746 file: &Arc<dyn File>,
2747 cx: &App,
2748 ) -> bool {
2749 if !file.is_local() || file.is_private() {
2750 return false;
2751 }
2752 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2753 return false;
2754 };
2755 project_state
2756 .license_detection_watchers
2757 .get(&file.worktree_id(cx))
2758 .as_ref()
2759 .is_some_and(|watcher| watcher.is_project_open_source())
2760 }
2761
2762 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2763 self.data_collection_choice.is_enabled(cx)
2764 }
2765
2766 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2767 let choice = KeyValueStore::global(cx)
2768 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2769 .log_err()
2770 .flatten();
2771
2772 match choice.as_deref() {
2773 Some("true") => DataCollectionChoice::Enabled,
2774 Some("false") => DataCollectionChoice::Disabled,
2775 Some(_) => {
2776 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2777 DataCollectionChoice::NotAnswered
2778 }
2779 None => DataCollectionChoice::NotAnswered,
2780 }
2781 }
2782
2783 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2784 self.data_collection_choice = self.data_collection_choice.toggle();
2785 let new_choice = self.data_collection_choice;
2786 let is_enabled = new_choice.is_enabled(cx);
2787 let kvp = KeyValueStore::global(cx);
2788 db::write_and_log(cx, move || async move {
2789 kvp.write_kvp(
2790 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2791 is_enabled.to_string(),
2792 )
2793 .await
2794 });
2795 }
2796
    /// Iterates over the predictions stored in `self.shown_predictions`; the
    /// iterator can be walked from either end.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2800
    /// Number of predictions currently stored in `self.shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2804
    /// Whether the prediction with the given `id` has been recorded as rated
    /// (see `rate_prediction`, which inserts into `rated_predictions`).
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2808
2809 pub fn rate_prediction(
2810 &mut self,
2811 prediction: &EditPrediction,
2812 rating: EditPredictionRating,
2813 feedback: String,
2814 cx: &mut Context<Self>,
2815 ) {
2816 let organization = self.user_store.read(cx).current_organization();
2817
2818 self.rated_predictions.insert(prediction.id.clone());
2819
2820 cx.background_spawn({
2821 let client = self.client.clone();
2822 let prediction_id = prediction.id.to_string();
2823 let inputs = serde_json::to_value(&prediction.inputs);
2824 let output = prediction
2825 .edit_preview
2826 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2827 async move {
2828 client
2829 .cloud_client()
2830 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2831 organization_id: organization.map(|organization| organization.id.clone()),
2832 request_id: prediction_id,
2833 rating: match rating {
2834 EditPredictionRating::Positive => "positive".to_string(),
2835 EditPredictionRating::Negative => "negative".to_string(),
2836 },
2837 inputs: inputs?,
2838 output,
2839 feedback,
2840 })
2841 .await?;
2842
2843 anyhow::Ok(())
2844 }
2845 })
2846 .detach_and_log_err(cx);
2847
2848 cx.notify();
2849 }
2850}
2851
2852fn collaborator_edit_overlaps_locality_region(
2853 project_state: &ProjectState,
2854 project: &Entity<Project>,
2855 buffer: &Entity<Buffer>,
2856 snapshot: &BufferSnapshot,
2857 edit_range: &Range<Anchor>,
2858 cx: &App,
2859) -> bool {
2860 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2861 return false;
2862 };
2863
2864 if active_buffer.entity_id() != buffer.entity_id() {
2865 return false;
2866 }
2867
2868 let locality_point_range = expand_context_syntactically_then_linewise(
2869 snapshot,
2870 (position..position).to_point(snapshot),
2871 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2872 );
2873 let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2874
2875 edit_range.overlaps(&locality_anchor_range, snapshot)
2876}
2877
/// Collapses the run of mergeable events at the back of `events` into a single
/// event whose diff spans from the oldest merged event's `old_snapshot` to
/// `end_snapshot`.
///
/// Nothing happens when the last event belongs to a different buffer than
/// `latest_snapshot`, when `latest_snapshot` has not observed the last event's
/// version, when fewer than two trailing events are mergeable, or when
/// `compute_diff_between_snapshots_in_range` yields no diff.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        // Events of another buffer are never merged with this one.
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        // The latest snapshot must already include the last event's edits.
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards from the newest event, counting events for as long as each
    // one can be merged with the event that follows it.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator, so the `all` call below
    // still visits every event that is being merged, including the oldest.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Grow the range to cover the edit ranges of all newer merged events.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        // Paths and the open-source flag are taken from the oldest event.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2960
2961fn merge_anchor_ranges(
2962 left: &Range<Anchor>,
2963 right: &Range<Anchor>,
2964 snapshot: &TextBufferSnapshot,
2965) -> Range<Anchor> {
2966 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2967 left.start
2968 } else {
2969 right.start
2970 };
2971 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2972 left.end
2973 } else {
2974 right.end
2975 };
2976 start..end
2977}
2978
/// Error raised by `send_api_request` when a response advertises a minimum
/// required version that is newer than the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version taken from the response header.
    minimum_version: Version,
}
2986
/// The user's answer to whether edit prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice is stored (or the stored value was not recognized).
    NotAnswered,
    /// Stored as `"true"`.
    Enabled,
    /// Stored as `"false"`.
    Disabled,
}
2993
2994impl DataCollectionChoice {
2995 pub fn is_enabled(self, cx: &App) -> bool {
2996 if cx.is_staff() {
2997 return true;
2998 }
2999 match self {
3000 Self::Enabled => true,
3001 Self::NotAnswered | Self::Disabled => false,
3002 }
3003 }
3004
3005 #[must_use]
3006 pub fn toggle(&self) -> DataCollectionChoice {
3007 match self {
3008 Self::Enabled => Self::Disabled,
3009 Self::Disabled => Self::Enabled,
3010 Self::NotAnswered => Self::Enabled,
3011 }
3012 }
3013}
3014
3015impl From<bool> for DataCollectionChoice {
3016 fn from(value: bool) -> Self {
3017 match value {
3018 true => DataCollectionChoice::Enabled,
3019 false => DataCollectionChoice::Disabled,
3020 }
3021 }
3022}
3023
/// Marker type whose `Dismissable` implementation tracks whether the edit
/// prediction upsell has been dismissed.
struct ZedPredictUpsell;
3025
3026impl Dismissable for ZedPredictUpsell {
3027 const KEY: &'static str = "dismissed-edit-predict-upsell";
3028
3029 fn dismissed(cx: &App) -> bool {
3030 // To make this backwards compatible with older versions of Zed, we
3031 // check if the user has seen the previous Edit Prediction Onboarding
3032 // before, by checking the data collection choice which was written to
3033 // the database once the user clicked on "Accept and Enable"
3034 let kvp = KeyValueStore::global(cx);
3035 if kvp
3036 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3037 .log_err()
3038 .is_some_and(|s| s.is_some())
3039 {
3040 return true;
3041 }
3042
3043 kvp.read_kvp(Self::KEY)
3044 .log_err()
3045 .is_some_and(|s| s.is_some())
3046 }
3047}
3048
/// Returns `true` as long as `ZedPredictUpsell` has not been dismissed.
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !ZedPredictUpsell::dismissed(cx)
}
3052
3053pub fn init(cx: &mut App) {
3054 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3055 workspace.register_action(
3056 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3057 ZedPredictModal::toggle(
3058 workspace,
3059 workspace.user_store().clone(),
3060 workspace.client().clone(),
3061 window,
3062 cx,
3063 )
3064 },
3065 );
3066
3067 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3068 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3069 settings
3070 .project
3071 .all_languages
3072 .edit_predictions
3073 .get_or_insert_default()
3074 .provider = Some(EditPredictionProvider::None)
3075 });
3076 });
3077 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3078 EditPredictionStore::try_global(cx).and_then(|store| {
3079 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3080 })
3081 }
3082
3083 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3084 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3085 copilot_ui::initiate_sign_in(copilot, window, cx);
3086 }
3087 });
3088 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3089 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3090 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3091 }
3092 });
3093 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3094 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3095 copilot_ui::initiate_sign_out(copilot, window, cx);
3096 }
3097 });
3098 })
3099 .detach();
3100}