1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use credentials_provider::CredentialsProvider;
16use db::kvp::{Dismissable, KeyValueStore};
17use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
19use futures::{
20 AsyncReadExt as _, FutureExt as _, StreamExt as _,
21 channel::mpsc::{self, UnboundedReceiver},
22 select_biased,
23};
24use gpui::BackgroundExecutor;
25use gpui::http_client::Url;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use heapless::Vec as ArrayVec;
32use language::language_settings::all_language_settings;
33use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
34use language::{BufferSnapshot, OffsetRangeExt};
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{
40 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
41};
42use std::collections::{VecDeque, hash_map};
43use std::env;
44use text::{AnchorRangeExt, Edit};
45use workspace::{AppState, Workspace};
46use zeta_prompt::{ZetaFormat, ZetaPromptInput};
47
48use std::mem;
49use std::ops::Range;
50use std::path::Path;
51use std::rc::Rc;
52use std::str::FromStr as _;
53use std::sync::Arc;
54use std::time::{Duration, Instant};
55
56use thiserror::Error;
57use util::{RangeExt as _, ResultExt as _};
58
59pub mod cursor_excerpt;
60pub mod example_spec;
61pub mod fim;
62mod license_detection;
63pub mod mercury;
64pub mod metrics;
65pub mod ollama;
66mod onboarding_modal;
67pub mod open_ai_response;
68mod prediction;
69
70pub mod udiff;
71
72mod capture_example;
73pub mod open_ai_compatible;
74mod zed_edit_prediction_delegate;
75pub mod zeta;
76
77#[cfg(test)]
78mod edit_prediction_tests;
79
80use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
81use crate::example_spec::ExampleSpec;
82use crate::license_detection::LicenseDetectionWatcher;
83use crate::mercury::Mercury;
84pub use crate::metrics::{KeptRateResult, compute_kept_rate};
85use crate::onboarding_modal::ZedPredictModal;
86pub use crate::prediction::EditPrediction;
87pub use crate::prediction::EditPredictionId;
88use crate::prediction::EditPredictionResult;
89pub use capture_example::capture_example;
90pub use language_model::ApiKeyState;
91pub use telemetry_events::EditPredictionRating;
92pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
93
// Actions registered under the `edit_prediction` namespace. The `///` lines below are passed
// into the `actions!` macro together with each action.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
103
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line distance used when deciding whether two edit events belong together
/// (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Byte cap applied to each side of an edit-history diff region; regions larger than this are
/// dropped. Sized as ~2048 tokens at roughly 3 bytes per token (~50% of a typical prompt budget).
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 3 * 2048;
/// NOTE(review): not referenced in this chunk; presumably a token budget for context around
/// collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
/// NOTE(review): presumably the time window for grouping consecutive edits — confirm at use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the data-collection choice is stored (presumably in the key-value store).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// NOTE(review): presumably the debounce applied before flushing rejection reports — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the telemetry event reported when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// NOTE(review): presumably how long a pending settled prediction is tracked — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(5 * 60);
/// NOTE(review): presumably the idle period after which a prediction counts as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
115
/// Feature flag named `edit_prediction_jumps`.
///
/// NOTE(review): presumably gates jump-style predictions (cf. `BufferEditPrediction::Jump`) —
/// confirm at the flag's use sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// App-global handle to the singleton [`EditPredictionStore`]; installed by
/// `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
126
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds this from the environment:
/// `ZED_ZETA_FORMAT` (required, parsed into `format`), `ZED_ZETA_MODEL` (`model_id`) and
/// `ZED_ZETA_ENVIRONMENT` (`environment`).
///
/// NOTE(review): an earlier comment said "the version is also used as the Baseten environment
/// name (lowercased)", but there is no version field; confirm whether `format` is used that way
/// when `environment` is `None`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier, taken from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Optional deployment environment, taken from `ZED_ZETA_ENVIRONMENT`.
    pub environment: Option<String>,
    /// Prompt format used when constructing the raw prompt.
    pub format: ZetaFormat,
}
136
/// Central store that tracks per-project edit history and requests, reports and rates edit
/// predictions. A single instance is installed as an app global via
/// `EditPredictionStore::global`.
pub struct EditPredictionStore {
    /// Client used for authenticated HTTP requests to the LLM service.
    client: Arc<Client>,
    /// User store; consulted for the current user, organization and prediction usage.
    user_store: Entity<UserStore>,
    /// Token handle passed to `Client::acquire_llm_token`.
    llm_token: LlmApiToken,
    /// Waits for a signed-in user, then refreshes the experiment list for staff. Held so the
    /// task lives as long as the store.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false` in `new`. NOTE(review): presumably set when the server demands a
    /// newer client version — confirm where it is written.
    update_required: bool,
    /// Which prediction backend is active.
    edit_prediction_model: EditPredictionModel,
    /// Raw-endpoint configuration; read from the environment at construction.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly chosen by the user, if any.
    preferred_experiment: Option<String>,
    /// Experiments fetched from `/edit_prediction_experiments`.
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    /// Data-collection choice loaded at construction.
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task started in `new` that runs `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the worker started in `new` that runs `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` reads their `model_version`.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook receiving settled-prediction notifications.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
158
/// A rejection queued on `EditPredictionStore::reject_predictions_tx`, tagged with the
/// organization it should be reported under (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
163
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's hosted Zeta model (the default set in `EditPredictionStore::new`).
    Zeta,
    /// Fill-in-the-middle prompting with the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    /// Mercury backend.
    Mercury,
}
170
/// Everything a model backend needs to produce a prediction for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken for this request.
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is anchored at.
    position: Anchor,
    /// Recent edit history included in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts retrieved as additional context.
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    /// Row range in which diagnostics are searched.
    diagnostic_search_range: Range<Point>,
    /// Optional sink for debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
185
/// Events emitted on a project's debug channel to observe context retrieval and prediction.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval completes, with key/value metadata describing the result.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts; `prompt` is present when the prompt is known.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes; `model_output` carries the raw output if any.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
221
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event as sent to the model.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Range (in the post-edit buffer) covered by the event's edits.
    pub total_edit_range: Range<Anchor>,
}
230
231impl StoredEvent {
232 fn can_merge(
233 &self,
234 next_old_event: &StoredEvent,
235 latest_snapshot: &TextBufferSnapshot,
236 latest_edit_range: &Range<Anchor>,
237 ) -> bool {
238 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
239 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
240 return false;
241 }
242 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
243 return false;
244 }
245 if self.new_snapshot_version != next_old_event.old_snapshot.version {
246 return false;
247 }
248 if !latest_snapshot
249 .version
250 .observed_all(&next_old_event.new_snapshot_version)
251 {
252 return false;
253 }
254
255 let a_is_predicted = matches!(
256 self.event.as_ref(),
257 zeta_prompt::Event::BufferChange {
258 predicted: true,
259 ..
260 }
261 );
262 let b_is_predicted = matches!(
263 next_old_event.event.as_ref(),
264 zeta_prompt::Event::BufferChange {
265 predicted: true,
266 ..
267 }
268 );
269
270 // If events come from the same source (both predicted or both manual) then
271 // we would have coalesced them already.
272 if a_is_predicted == b_is_predicted {
273 return false;
274 }
275
276 let left_range = self.total_edit_range.to_point(latest_snapshot);
277 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
278 let latest_range = latest_edit_range.to_point(latest_snapshot);
279
280 // Events near to the latest edit are not merged if their sources differ.
281 if lines_between_ranges(&left_range, &latest_range)
282 .min(lines_between_ranges(&right_range, &latest_range))
283 <= CHANGE_GROUPING_LINE_SPAN
284 {
285 return false;
286 }
287
288 // Events that are distant from each other are not merged.
289 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
290 return false;
291 }
292
293 true
294 }
295}
296
297fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
298 if left.start > right.end {
299 return left.start.row - right.end.row;
300 }
301 if right.start > left.end {
302 return right.start.row - left.end.row;
303 }
304 0
305}
306
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit history.
    events: VecDeque<StoredEvent>,
    /// The in-progress event; `events()` splits it at the last pause and finalizes it on read.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id to assign to the next pending prediction.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (capacity 2).
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably (buffer id, time) of the last refresh of each kind — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids recorded by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store used for prompt context.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; consulted when deciding `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    /// Lazily created by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
324
325impl ProjectState {
326 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
327 self.events
328 .iter()
329 .cloned()
330 .chain(self.last_event.as_ref().iter().flat_map(|event| {
331 let (one, two) = event.split_by_pause();
332 let one = one.finalize(&self.license_detection_watchers, cx);
333 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
334 one.into_iter().chain(two)
335 }))
336 .collect()
337 }
338
339 fn cancel_pending_prediction(
340 &mut self,
341 pending_prediction: PendingPrediction,
342 cx: &mut Context<EditPredictionStore>,
343 ) {
344 self.cancelled_predictions.insert(pending_prediction.id);
345
346 if pending_prediction.drop_on_cancel {
347 drop(pending_prediction.task);
348 } else {
349 cx.spawn(async move |this, cx| {
350 let Some(prediction_id) = pending_prediction.task.await else {
351 return;
352 };
353
354 this.update(cx, |this, cx| {
355 this.reject_prediction(
356 prediction_id,
357 EditPredictionRejectReason::Canceled,
358 false,
359 None,
360 None,
361 cx,
362 );
363 })
364 .ok();
365 })
366 .detach()
367 }
368 }
369
370 fn active_buffer(
371 &self,
372 project: &Entity<Project>,
373 cx: &App,
374 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
375 let project = project.read(cx);
376 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
377 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
378 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
379 Some((active_buffer, registered_buffer.last_position))
380 }
381}
382
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// NOTE(review): presumably end-to-end request latency — confirm where it is measured.
    pub e2e_latency: std::time::Duration,
}
391
392impl CurrentEditPrediction {
393 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
394 let Some(new_edits) = self
395 .prediction
396 .interpolate(&self.prediction.buffer.read(cx))
397 else {
398 return false;
399 };
400
401 if self.prediction.buffer != old_prediction.prediction.buffer {
402 return true;
403 }
404
405 let Some(old_edits) = old_prediction
406 .prediction
407 .interpolate(&old_prediction.prediction.buffer.read(cx))
408 else {
409 return true;
410 };
411
412 let requested_by_buffer_id = self.requested_by.buffer_id();
413
414 // This reduces the occurrence of UI thrash from replacing edits
415 //
416 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
417 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
418 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
419 && old_edits.len() == 1
420 && new_edits.len() == 1
421 {
422 let (old_range, old_text) = &old_edits[0];
423 let (new_range, new_text) = &new_edits[0];
424 new_range == old_range && new_text.starts_with(old_text.as_ref())
425 } else {
426 true
427 }
428 }
429}
430
431#[derive(Debug, Clone)]
432enum PredictionRequestedBy {
433 DiagnosticsUpdate,
434 Buffer(EntityId),
435}
436
437impl PredictionRequestedBy {
438 pub fn buffer_id(&self) -> Option<EntityId> {
439 match self {
440 PredictionRequestedBy::DiagnosticsUpdate => None,
441 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
442 }
443 }
444}
445
/// NOTE(review): presumably the number of lines around the cursor searched for diagnostics —
/// confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Whether a diagnostic search is restricted locally or spans globally.
/// NOTE(review): exact extent of each scope is defined at the use site — confirm.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
453
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Id assigned from `ProjectState::next_pending_prediction_id`.
    id: usize,
    /// Resolves to the prediction id if the request produced one.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
462
463/// A prediction from the perspective of a buffer.
464#[derive(Debug)]
465enum BufferEditPrediction<'a> {
466 Local { prediction: &'a EditPrediction },
467 Jump { prediction: &'a EditPrediction },
468}
469
470#[cfg(test)]
471impl std::ops::Deref for BufferEditPrediction<'_> {
472 type Target = EditPrediction;
473
474 fn deref(&self) -> &Self::Target {
475 match self {
476 BufferEditPrediction::Local { prediction } => prediction,
477 BufferEditPrediction::Jump { prediction } => prediction,
478 }
479 }
480}
481
/// A shown prediction awaiting a "settled" report, stored per buffer in
/// `RegisteredBuffer::pending_predictions`.
/// NOTE(review): field semantics below are inferred from names; confirm at the enqueue site.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Editable region of the request, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
495
/// Tracking state for a buffer registered with a project's prediction state.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// NOTE(review): presumably the last snapshot observed for this buffer — confirm.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
503
/// The in-progress (not yet finalized) edit event for a buffer.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering all edits in this event.
    total_edit_range: Range<Anchor>,
    /// Total range as of the last editing pause; used by `split_by_pause` for the first half.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from an accepted prediction.
    predicted: bool,
    /// Snapshot taken at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
517
518impl LastEvent {
519 pub fn finalize(
520 &self,
521 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
522 cx: &App,
523 ) -> Option<StoredEvent> {
524 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
525 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
526
527 let in_open_source_repo =
528 [self.new_file.as_ref(), self.old_file.as_ref()]
529 .iter()
530 .all(|file| {
531 file.is_some_and(|file| {
532 license_detection_watchers
533 .get(&file.worktree_id(cx))
534 .is_some_and(|watcher| watcher.is_project_open_source())
535 })
536 });
537
538 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
539 &self.old_snapshot,
540 &self.new_snapshot,
541 &self.total_edit_range,
542 )?;
543
544 if path == old_path && diff.is_empty() {
545 None
546 } else {
547 Some(StoredEvent {
548 event: Arc::new(zeta_prompt::Event::BufferChange {
549 old_path,
550 path,
551 diff,
552 in_open_source_repo,
553 predicted: self.predicted,
554 }),
555 old_snapshot: self.old_snapshot.clone(),
556 new_snapshot_version: self.new_snapshot.version.clone(),
557 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
558 ..self.new_snapshot.anchor_before(edit_range.end),
559 })
560 }
561 }
562
563 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
564 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
565 return (self.clone(), None);
566 };
567
568 let total_edit_range_before_pause = self
569 .total_edit_range_at_last_pause_boundary
570 .clone()
571 .unwrap_or_else(|| self.total_edit_range.clone());
572
573 let Some(total_edit_range_after_pause) =
574 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
575 else {
576 return (self.clone(), None);
577 };
578
579 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
580 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
581
582 let before = LastEvent {
583 old_snapshot: self.old_snapshot.clone(),
584 new_snapshot: boundary_snapshot.clone(),
585 old_file: self.old_file.clone(),
586 new_file: self.new_file.clone(),
587 latest_edit_range: latest_edit_range_before_pause,
588 total_edit_range: total_edit_range_before_pause,
589 total_edit_range_at_last_pause_boundary: None,
590 predicted: self.predicted,
591 snapshot_after_last_editing_pause: None,
592 last_edit_time: self.last_edit_time,
593 };
594
595 let after = LastEvent {
596 old_snapshot: boundary_snapshot.clone(),
597 new_snapshot: self.new_snapshot.clone(),
598 old_file: self.old_file.clone(),
599 new_file: self.new_file.clone(),
600 latest_edit_range: latest_edit_range_after_pause,
601 total_edit_range: total_edit_range_after_pause,
602 total_edit_range_at_last_pause_boundary: None,
603 predicted: self.predicted,
604 snapshot_after_last_editing_pause: None,
605 last_edit_time: self.last_edit_time,
606 };
607
608 (before, Some(after))
609 }
610}
611
612fn compute_total_edit_range_between_snapshots(
613 old_snapshot: &TextBufferSnapshot,
614 new_snapshot: &TextBufferSnapshot,
615) -> Option<Range<Anchor>> {
616 let edits: Vec<Edit<usize>> = new_snapshot
617 .edits_since::<usize>(&old_snapshot.version)
618 .collect();
619
620 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
621 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
622 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
623
624 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
625}
626
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to a point range in `old_snapshot`
/// by replaying the edits made between the two versions.
///
/// An endpoint located before an edit is shifted back by the net length change (`delta`) of the
/// edits preceding it. An endpoint located inside (or at the end of) an edit snaps to that
/// edit's old start/end. An endpoint after every edit is shifted by the total delta, saturating
/// at zero. Returns `None` if a checked shift overflows/underflows.
///
/// NOTE(review): the `delta` bookkeeping assumes `edits_since` yields edits in ascending order.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new length - old length) of the edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // The first edit ending at or after the start offset determines the old start.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Start lies before this edit: undo the shift from earlier edits.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Start lies inside this edit: snap to the edit's old start.
                edit.old.start
            });
        }

        // Same logic for the end offset, snapping to the edit's old end.
        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints after the last edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
672
/// Builds a unified diff between `old_snapshot` and `new_snapshot`, limited to the rows around
/// `total_edit_range` (which is anchored in `new_snapshot`).
///
/// Returns the diff text and the edit range as points in `new_snapshot`. Returns `None` if the
/// old range cannot be mapped, or if either side's region exceeds `EDIT_HISTORY_DIFF_SIZE_LIMIT`
/// bytes.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Widen each side by CONTEXT_LINES rows before the edit and, after clamping to the last row,
    // through `end.row + 1 + CONTEXT_LINES` inclusive.
    // NOTE(review): together with the `+ 1` used for the end offset below, this includes up to
    // CONTEXT_LINES + 1 trailing rows — confirm whether that extra row is intended.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Regions run from the start of the first context row to the start of the row after the last
    // context row (or the end of the buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Oversized regions are skipped entirely rather than truncated.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Row offsets let the diff hunk headers refer to real buffer rows.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
720
721fn buffer_path_with_id_fallback(
722 file: Option<&Arc<dyn File>>,
723 snapshot: &TextBufferSnapshot,
724 cx: &App,
725) -> Arc<Path> {
726 if let Some(file) = file {
727 file.full_path(cx).into()
728 } else {
729 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
730 }
731}
732
733impl EditPredictionStore {
734 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
735 cx.try_global::<EditPredictionStoreGlobal>()
736 .map(|global| global.0.clone())
737 }
738
739 pub fn global(
740 client: &Arc<Client>,
741 user_store: &Entity<UserStore>,
742 cx: &mut App,
743 ) -> Entity<Self> {
744 cx.try_global::<EditPredictionStoreGlobal>()
745 .map(|global| global.0.clone())
746 .unwrap_or_else(|| {
747 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
748 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
749 ep_store
750 })
751 }
752
753 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
754 let data_collection_choice = Self::load_data_collection_choice(cx);
755
756 let llm_token = global_llm_token(cx);
757
758 let (reject_tx, reject_rx) = mpsc::unbounded();
759 cx.background_spawn({
760 let client = client.clone();
761 let llm_token = llm_token.clone();
762 let app_version = AppVersion::global(cx);
763 let background_executor = cx.background_executor().clone();
764 async move {
765 Self::handle_rejected_predictions(
766 reject_rx,
767 client,
768 llm_token,
769 app_version,
770 background_executor,
771 )
772 .await
773 }
774 })
775 .detach();
776
777 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
778 cx.spawn(async move |this, cx| {
779 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
780 })
781 .detach();
782
783 let mut current_user = user_store.read(cx).watch_current_user();
784 let fetch_experiments_task = cx.spawn(async move |this, cx| {
785 while current_user.borrow().is_none() {
786 current_user.next().await;
787 }
788
789 this.update(cx, |this, cx| {
790 if cx.is_staff() {
791 this.refresh_available_experiments(cx);
792 }
793 })
794 .log_err();
795 });
796
797 let credentials_provider = zed_credentials_provider::global(cx);
798
799 let this = Self {
800 projects: HashMap::default(),
801 client,
802 user_store,
803 llm_token,
804 _fetch_experiments_task: fetch_experiments_task,
805 update_required: false,
806 edit_prediction_model: EditPredictionModel::Zeta,
807 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
808 preferred_experiment: None,
809 available_experiments: Vec::new(),
810 mercury: Mercury::new(cx),
811
812 data_collection_choice,
813 reject_predictions_tx: reject_tx,
814 settled_predictions_tx,
815 rated_predictions: Default::default(),
816 shown_predictions: Default::default(),
817 #[cfg(test)]
818 settled_event_callback: None,
819
820 credentials_provider,
821 };
822
823 this
824 }
825
826 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
827 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
828 let format = ZetaFormat::parse(&version_str).ok()?;
829 let model_id = env::var("ZED_ZETA_MODEL").ok();
830 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
831 Some(Zeta2RawConfig {
832 model_id,
833 environment,
834 format,
835 })
836 }
837
838 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
839 self.edit_prediction_model = model;
840 }
841
842 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
843 self.zeta2_raw_config = Some(config);
844 }
845
846 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
847 self.zeta2_raw_config.as_ref()
848 }
849
850 pub fn preferred_experiment(&self) -> Option<&str> {
851 self.preferred_experiment.as_deref()
852 }
853
854 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
855 self.preferred_experiment = experiment;
856 }
857
858 pub fn available_experiments(&self) -> &[String] {
859 &self.available_experiments
860 }
861
862 pub fn active_experiment(&self) -> Option<&str> {
863 self.preferred_experiment.as_deref().or_else(|| {
864 self.shown_predictions
865 .iter()
866 .find_map(|p| p.model_version.as_ref())
867 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
868 })
869 }
870
871 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
872 let client = self.client.clone();
873 let llm_token = self.llm_token.clone();
874 let app_version = AppVersion::global(cx);
875 let organization_id = self
876 .user_store
877 .read(cx)
878 .current_organization()
879 .map(|organization| organization.id.clone());
880
881 cx.spawn(async move |this, cx| {
882 let experiments = cx
883 .background_spawn(async move {
884 let http_client = client.http_client();
885 let token = client
886 .acquire_llm_token(&llm_token, organization_id.clone())
887 .await?;
888 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
889 let request = http_client::Request::builder()
890 .method(Method::GET)
891 .uri(url.as_ref())
892 .header("Authorization", format!("Bearer {}", token))
893 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
894 .body(Default::default())?;
895 let mut response = http_client.send(request).await?;
896 if response.status().is_success() {
897 let mut body = Vec::new();
898 response.body_mut().read_to_end(&mut body).await?;
899 let experiments: Vec<String> = serde_json::from_slice(&body)?;
900 Ok(experiments)
901 } else {
902 let mut body = String::new();
903 response.body_mut().read_to_string(&mut body).await?;
904 anyhow::bail!(
905 "Failed to fetch experiments: {:?}\nBody: {}",
906 response.status(),
907 body
908 );
909 }
910 })
911 .await?;
912 this.update(cx, |this, cx| {
913 this.available_experiments = experiments;
914 cx.notify();
915 })?;
916 anyhow::Ok(())
917 })
918 .detach_and_log_err(cx);
919 }
920
921 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
922 use ui::IconName;
923 match self.edit_prediction_model {
924 EditPredictionModel::Mercury => {
925 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
926 }
927 EditPredictionModel::Zeta => {
928 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
929 .with_disabled(IconName::ZedPredictDisabled)
930 .with_up(IconName::ZedPredictUp)
931 .with_down(IconName::ZedPredictDown)
932 .with_error(IconName::ZedPredictError)
933 }
934 EditPredictionModel::Fim { .. } => {
935 let settings = &all_language_settings(None, cx).edit_predictions;
936 match settings.provider {
937 EditPredictionProvider::Ollama => {
938 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
939 }
940 _ => {
941 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
942 }
943 }
944 }
945 }
946 }
947
948 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
949 self.mercury.api_token.read(cx).has_key()
950 }
951
952 pub fn mercury_has_payment_required_error(&self) -> bool {
953 self.mercury.has_payment_required_error()
954 }
955
956 pub fn clear_history(&mut self) {
957 for project_state in self.projects.values_mut() {
958 project_state.events.clear();
959 project_state.last_event.take();
960 }
961 }
962
963 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
964 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
965 project_state.events.clear();
966 project_state.last_event.take();
967 }
968 }
969
970 pub fn edit_history_for_project(
971 &self,
972 project: &Entity<Project>,
973 cx: &App,
974 ) -> Vec<StoredEvent> {
975 self.projects
976 .get(&project.entity_id())
977 .map(|project_state| project_state.events(cx))
978 .unwrap_or_default()
979 }
980
981 pub fn context_for_project<'a>(
982 &'a self,
983 project: &Entity<Project>,
984 cx: &'a mut App,
985 ) -> Vec<RelatedFile> {
986 self.projects
987 .get(&project.entity_id())
988 .map(|project_state| {
989 project_state.context.update(cx, |context, cx| {
990 context
991 .related_files_with_buffers(cx)
992 .map(|(mut related_file, buffer)| {
993 related_file.in_open_source_repo = buffer
994 .read(cx)
995 .file()
996 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
997 related_file
998 })
999 .collect()
1000 })
1001 })
1002 .unwrap_or_default()
1003 }
1004
1005 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1006 self.projects
1007 .get(&project.entity_id())
1008 .and_then(|project| project.copilot.clone())
1009 }
1010
1011 pub fn start_copilot_for_project(
1012 &mut self,
1013 project: &Entity<Project>,
1014 cx: &mut Context<Self>,
1015 ) -> Option<Entity<Copilot>> {
1016 if DisableAiSettings::get(None, cx).disable_ai {
1017 return None;
1018 }
1019 let state = self.get_or_init_project(project, cx);
1020
1021 if state.copilot.is_some() {
1022 return state.copilot.clone();
1023 }
1024 let _project = project.clone();
1025 let project = project.read(cx);
1026
1027 let node = project.node_runtime().cloned();
1028 if let Some(node) = node {
1029 let next_id = project.languages().next_language_server_id();
1030 let fs = project.fs().clone();
1031
1032 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1033 state.copilot = Some(copilot.clone());
1034 Some(copilot)
1035 } else {
1036 None
1037 }
1038 }
1039
1040 pub fn context_for_project_with_buffers<'a>(
1041 &'a self,
1042 project: &Entity<Project>,
1043 cx: &'a mut App,
1044 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1045 self.projects
1046 .get(&project.entity_id())
1047 .map(|project| {
1048 project.context.update(cx, |context, cx| {
1049 context.related_files_with_buffers(cx).collect()
1050 })
1051 })
1052 .unwrap_or_default()
1053 }
1054
1055 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1056 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1057 self.user_store.read(cx).edit_prediction_usage()
1058 } else {
1059 None
1060 }
1061 }
1062
1063 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1064 self.get_or_init_project(project, cx);
1065 }
1066
1067 pub fn register_buffer(
1068 &mut self,
1069 buffer: &Entity<Buffer>,
1070 project: &Entity<Project>,
1071 cx: &mut Context<Self>,
1072 ) {
1073 let project_state = self.get_or_init_project(project, cx);
1074 Self::register_buffer_impl(project_state, buffer, project, cx);
1075 }
1076
    /// Returns the state tracked for `project`, creating it on first access.
    ///
    /// A newly created `ProjectState` gets:
    /// - a `RelatedExcerptStore` whose events are forwarded to
    ///   `handle_excerpt_store_event`,
    /// - a subscription routing project events to `handle_project_event`, and
    /// - a release observer that removes this state (and notifies) when the
    ///   project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Only the (Copy) entity id is captured, so the callback holds
                    // no project handle.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1116
1117 pub fn remove_project(&mut self, project: &Entity<Project>) {
1118 self.projects.remove(&project.entity_id());
1119 }
1120
1121 fn handle_excerpt_store_event(
1122 &mut self,
1123 project_entity_id: EntityId,
1124 event: &RelatedExcerptStoreEvent,
1125 ) {
1126 if let Some(project_state) = self.projects.get(&project_entity_id) {
1127 if let Some(debug_tx) = project_state.debug_tx.clone() {
1128 match event {
1129 RelatedExcerptStoreEvent::StartedRefresh => {
1130 debug_tx
1131 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1132 ContextRetrievalStartedDebugEvent {
1133 project_entity_id: project_entity_id,
1134 timestamp: Instant::now(),
1135 search_prompt: String::new(),
1136 },
1137 ))
1138 .ok();
1139 }
1140 RelatedExcerptStoreEvent::FinishedRefresh {
1141 cache_hit_count,
1142 cache_miss_count,
1143 mean_definition_latency,
1144 max_definition_latency,
1145 } => {
1146 debug_tx
1147 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1148 ContextRetrievalFinishedDebugEvent {
1149 project_entity_id: project_entity_id,
1150 timestamp: Instant::now(),
1151 metadata: vec![
1152 (
1153 "Cache Hits",
1154 format!(
1155 "{}/{}",
1156 cache_hit_count,
1157 cache_hit_count + cache_miss_count
1158 )
1159 .into(),
1160 ),
1161 (
1162 "Max LSP Time",
1163 format!("{} ms", max_definition_latency.as_millis())
1164 .into(),
1165 ),
1166 (
1167 "Mean LSP Time",
1168 format!("{} ms", mean_definition_latency.as_millis())
1169 .into(),
1170 ),
1171 ],
1172 },
1173 ))
1174 .ok();
1175 }
1176 }
1177 }
1178 }
1179 }
1180
1181 pub fn debug_info(
1182 &mut self,
1183 project: &Entity<Project>,
1184 cx: &mut Context<Self>,
1185 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1186 let project_state = self.get_or_init_project(project, cx);
1187 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1188 project_state.debug_tx = Some(debug_watch_tx);
1189 debug_watch_rx
1190 }
1191
1192 fn handle_project_event(
1193 &mut self,
1194 project: Entity<Project>,
1195 event: &project::Event,
1196 cx: &mut Context<Self>,
1197 ) {
1198 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1199 return;
1200 }
1201 // TODO [zeta2] init with recent paths
1202 match event {
1203 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1204 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1205 return;
1206 };
1207 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1208 if let Some(path) = path {
1209 if let Some(ix) = project_state
1210 .recent_paths
1211 .iter()
1212 .position(|probe| probe == &path)
1213 {
1214 project_state.recent_paths.remove(ix);
1215 }
1216 project_state.recent_paths.push_front(path);
1217 }
1218 }
1219 project::Event::DiagnosticsUpdated { .. } => {
1220 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1221 self.refresh_prediction_from_diagnostics(
1222 project,
1223 DiagnosticSearchScope::Global,
1224 cx,
1225 );
1226 }
1227 }
1228 _ => (),
1229 }
1230 }
1231
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer has a file whose worktree is still present in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time one
    /// of its buffers is seen. The watcher is removed again when the worktree
    /// is released.
    ///
    /// On first registration, the buffer's current text snapshot and file are
    /// recorded and two subscriptions are installed. One forwards `Edited`
    /// events to `report_changes_for_buffer` as non-predicted changes. The other
    /// drops the registration when the buffer is released. An already
    /// registered buffer is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget this worktree's watcher once the worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // The project is captured weakly; edits that arrive after it
                        // can no longer be upgraded are ignored.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1300
    /// Records the changes made to `buffer` since the store last observed it.
    ///
    /// The buffer is registered if necessary, and its stored snapshot and file
    /// are swapped for the current ones. All edits in between are collapsed
    /// into one anchor range. That range is used to:
    /// - refresh `last_edit_at` on pending settled predictions whose editable
    ///   range it overlaps, and
    /// - extend the project's edit history.
    ///
    /// Remote edits (`is_local == false`) enter the history only when
    /// `collaborator_edit_overlaps_locality_region` returns true. If
    /// `compute_diff_between_snapshots_in_range` yields nothing for the range,
    /// the in-progress `last_event` is finalized and nothing new is recorded.
    ///
    /// Otherwise the edit is coalesced into `last_event` when all of these hold:
    /// - it directly follows that event's snapshot of the same buffer,
    /// - it has the same `is_predicted` flag, and
    /// - it is within `CHANGE_GROUPING_LINE_SPAN` lines of the event's latest edit.
    ///
    /// If not, `last_event` is finalized into `events` and a new event is
    /// started for this edit.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing has changed since this buffer was last observed.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Span from the first edit's start to the last edit's end. This assumes
        // `anchored_edits_since` yields edits in buffer order.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's region restart its quiescence
        // clock (read by `run_settled_predictions_worker`).
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // Not recordable: close out the in-progress event and stop.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Evicting here keeps at most EVENT_COUNT_MAX - 1 finalized
                    // events, presumably leaving room for `last_event`.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember
                // the event's state at the pause boundary before extending it.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (same bound as above).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // Let `merge_trailing_events_if_needed` fold finalized trailing events
        // with this edit (see that helper for the rules).
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1433
1434 fn prediction_at(
1435 &mut self,
1436 buffer: &Entity<Buffer>,
1437 position: Option<language::Anchor>,
1438 project: &Entity<Project>,
1439 cx: &App,
1440 ) -> Option<BufferEditPrediction<'_>> {
1441 let project_state = self.projects.get_mut(&project.entity_id())?;
1442 if let Some(position) = position
1443 && let Some(buffer) = project_state
1444 .registered_buffers
1445 .get_mut(&buffer.entity_id())
1446 {
1447 buffer.last_position = Some(position);
1448 }
1449
1450 let CurrentEditPrediction {
1451 requested_by,
1452 prediction,
1453 ..
1454 } = project_state.current_prediction.as_ref()?;
1455
1456 if prediction.targets_buffer(buffer.read(cx)) {
1457 Some(BufferEditPrediction::Local { prediction })
1458 } else {
1459 let show_jump = match requested_by {
1460 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1461 requested_by_buffer_id == &buffer.entity_id()
1462 }
1463 PredictionRequestedBy::DiagnosticsUpdate => true,
1464 };
1465
1466 if show_jump {
1467 Some(BufferEditPrediction::Jump { prediction })
1468 } else {
1469 None
1470 }
1471 }
1472 }
1473
1474 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1475 let Some(current_prediction) = self
1476 .projects
1477 .get_mut(&project.entity_id())
1478 .and_then(|project_state| project_state.current_prediction.take())
1479 else {
1480 return;
1481 };
1482
1483 self.report_changes_for_buffer(
1484 ¤t_prediction.prediction.buffer,
1485 project,
1486 true,
1487 true,
1488 cx,
1489 );
1490
1491 // can't hold &mut project_state ref across report_changes_for_buffer_call
1492 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1493 return;
1494 };
1495
1496 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1497 project_state.cancel_pending_prediction(pending_prediction, cx);
1498 }
1499
1500 match self.edit_prediction_model {
1501 EditPredictionModel::Mercury => {
1502 mercury::edit_prediction_accepted(
1503 current_prediction.prediction.id,
1504 self.client.http_client(),
1505 cx,
1506 );
1507 }
1508 EditPredictionModel::Zeta => {
1509 let is_cloud = !matches!(
1510 all_language_settings(None, cx).edit_predictions.provider,
1511 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1512 );
1513 if is_cloud {
1514 zeta::edit_prediction_accepted(self, current_prediction, cx)
1515 }
1516 }
1517 EditPredictionModel::Fim { .. } => {}
1518 }
1519 }
1520
    /// Background loop that batches prediction rejections and posts them to
    /// `/predict_edits/reject`.
    ///
    /// Batching: while fewer than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections are batched, each
    /// new rejection waits up to `REJECT_REQUEST_DEBOUNCE` for another one. Any
    /// arrival restarts the wait.
    ///
    /// Flushing: a flush sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    /// of the most recently batched rejections, using the organization id of the
    /// latest payload. Sent rejections are removed from the batch only when the
    /// request succeeds, so failed ones are attempted again on a later flush.
    /// When the channel closes, the remaining batch is flushed once more and the
    /// loop ends.
    ///
    /// NOTE(review): on repeated failures `batched` is never trimmed, so it can
    /// grow without bound. Entries older than the newest
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` are not attempted until later
    /// successes drain the tail.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // `peek` is polled first, so an already-queued rejection is
                // batched immediately instead of racing the debounce timer.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): this `unwrap` panics the task if URL construction fails.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop what was actually delivered; failures stay queued.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1584
1585 async fn run_settled_predictions_worker(
1586 this: WeakEntity<Self>,
1587 mut rx: UnboundedReceiver<Instant>,
1588 cx: &mut AsyncApp,
1589 ) {
1590 let mut next_wake_time: Option<Instant> = None;
1591 loop {
1592 let now = cx.background_executor().now();
1593 if let Some(wake_time) = next_wake_time.take() {
1594 cx.background_executor()
1595 .timer(wake_time.duration_since(now))
1596 .await;
1597 } else {
1598 let Some(new_enqueue_time) = rx.next().await else {
1599 break;
1600 };
1601 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1602 while rx.next().now_or_never().flatten().is_some() {}
1603 continue;
1604 }
1605
1606 let Some(this) = this.upgrade() else {
1607 break;
1608 };
1609
1610 let now = cx.background_executor().now();
1611 let mut oldest_edited_at = None;
1612 let mut ready_predictions = Vec::new();
1613
1614 this.update(cx, |this, _| {
1615 for (_, project_state) in this.projects.iter_mut() {
1616 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1617 let mut pending_index = 0;
1618 while pending_index < registered_buffer.pending_predictions.len() {
1619 let pending_prediction =
1620 ®istered_buffer.pending_predictions[pending_index];
1621 let age = now.saturating_duration_since(pending_prediction.enqueued_at);
1622 if age >= EDIT_PREDICTION_SETTLED_TTL {
1623 registered_buffer.pending_predictions.remove(pending_index);
1624 continue;
1625 }
1626
1627 let quiet_for =
1628 now.saturating_duration_since(pending_prediction.last_edit_at);
1629 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1630 let pending_prediction =
1631 registered_buffer.pending_predictions.remove(pending_index);
1632 let settled_editable_region = registered_buffer
1633 .snapshot
1634 .text_for_range(
1635 pending_prediction.editable_anchor_range.clone(),
1636 )
1637 .collect::<String>();
1638 ready_predictions
1639 .push((pending_prediction, settled_editable_region));
1640 continue;
1641 }
1642
1643 if oldest_edited_at
1644 .is_none_or(|time| pending_prediction.last_edit_at < time)
1645 {
1646 oldest_edited_at = Some(pending_prediction.last_edit_at);
1647 }
1648 pending_index += 1;
1649 }
1650 }
1651 }
1652 });
1653
1654 for (pending_prediction, settled_editable_region) in ready_predictions {
1655 let PendingSettledPrediction {
1656 request_id,
1657 editable_region_before_prediction,
1658 predicted_editable_region,
1659 ts_error_count_before_prediction,
1660 ts_error_count_after_prediction,
1661 example,
1662 e2e_latency,
1663 ..
1664 } = pending_prediction;
1665 let settled_editable_region_for_metrics = settled_editable_region.clone();
1666 let kept_rate_result = cx
1667 .background_spawn(async move {
1668 compute_kept_rate(
1669 &editable_region_before_prediction,
1670 &predicted_editable_region,
1671 &settled_editable_region_for_metrics,
1672 )
1673 })
1674 .await;
1675
1676 #[cfg(test)]
1677 {
1678 let request_id = request_id.clone();
1679 let settled_editable_region = settled_editable_region.clone();
1680 this.update(cx, |this, _| {
1681 if let Some(callback) = &this.settled_event_callback {
1682 callback(request_id, settled_editable_region);
1683 }
1684 });
1685 }
1686
1687 telemetry::event!(
1688 EDIT_PREDICTION_SETTLED_EVENT,
1689 request_id = request_id.0.clone(),
1690 settled_editable_region,
1691 ts_error_count_before_prediction,
1692 ts_error_count_after_prediction,
1693 edit_bytes_predicted_new = kept_rate_result.predicted_new_chars,
1694 edit_bytes_final_new = kept_rate_result.final_new_chars,
1695 edit_bytes_kept = kept_rate_result.kept_chars,
1696 edit_bytes_discarded = kept_rate_result.discarded_chars,
1697 edit_bytes_context = kept_rate_result.context_chars,
1698 edit_bytes_kept_rate = kept_rate_result.kept_rate,
1699 example,
1700 e2e_latency = e2e_latency.as_millis(),
1701 );
1702 }
1703
1704 next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1705 }
1706 }
1707
1708 pub(crate) fn enqueue_settled_prediction(
1709 &mut self,
1710 request_id: EditPredictionId,
1711 project: &Entity<Project>,
1712 edited_buffer: &Entity<Buffer>,
1713 edited_buffer_snapshot: &BufferSnapshot,
1714 editable_offset_range: Range<usize>,
1715 edit_preview: &EditPreview,
1716 example: Option<ExampleSpec>,
1717 e2e_latency: std::time::Duration,
1718 cx: &mut Context<Self>,
1719 ) {
1720 let this = &mut *self;
1721 let project_state = this.get_or_init_project(project, cx);
1722 let Some(registered_buffer) = project_state
1723 .registered_buffers
1724 .get_mut(&edited_buffer.entity_id())
1725 else {
1726 return;
1727 };
1728
1729 let editable_region_before_prediction = edited_buffer_snapshot
1730 .text_for_range(editable_offset_range.clone())
1731 .collect::<String>();
1732 let editable_anchor_range_for_result =
1733 edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1734 let predicted_editable_region = edit_preview
1735 .result_text_snapshot()
1736 .text_for_range(editable_anchor_range_for_result.clone())
1737 .collect();
1738 let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1739 edited_buffer_snapshot
1740 .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1741 );
1742 let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1743 edit_preview.result_syntax_snapshot().layers_for_range(
1744 editable_anchor_range_for_result,
1745 edit_preview.result_text_snapshot(),
1746 true,
1747 ),
1748 );
1749 let editable_anchor_range =
1750 edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1751 let now = cx.background_executor().now();
1752 registered_buffer
1753 .pending_predictions
1754 .push(PendingSettledPrediction {
1755 request_id,
1756 editable_anchor_range,
1757 editable_region_before_prediction,
1758 predicted_editable_region,
1759 ts_error_count_before_prediction,
1760 ts_error_count_after_prediction,
1761 example,
1762 e2e_latency,
1763 enqueued_at: now,
1764 last_edit_at: now,
1765 });
1766 this.settled_predictions_tx.unbounded_send(now).ok();
1767 }
1768
1769 fn reject_current_prediction(
1770 &mut self,
1771 reason: EditPredictionRejectReason,
1772 project: &Entity<Project>,
1773 cx: &App,
1774 ) {
1775 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1776 project_state.pending_predictions.clear();
1777 if let Some(prediction) = project_state.current_prediction.take() {
1778 let model_version = prediction.prediction.model_version.clone();
1779 self.reject_prediction(
1780 prediction.prediction.id,
1781 reason,
1782 prediction.was_shown,
1783 model_version,
1784 Some(prediction.e2e_latency),
1785 cx,
1786 );
1787 }
1788 };
1789 }
1790
1791 fn did_show_current_prediction(
1792 &mut self,
1793 project: &Entity<Project>,
1794 display_type: edit_prediction_types::SuggestionDisplayType,
1795 _cx: &mut Context<Self>,
1796 ) {
1797 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1798 return;
1799 };
1800
1801 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1802 return;
1803 };
1804
1805 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1806 let previous_shown_with = current_prediction.shown_with;
1807
1808 if previous_shown_with.is_none() || !is_jump {
1809 current_prediction.shown_with = Some(display_type);
1810 }
1811
1812 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1813
1814 if is_first_non_jump_show {
1815 current_prediction.was_shown = true;
1816 }
1817
1818 if is_first_non_jump_show {
1819 self.shown_predictions
1820 .push_front(current_prediction.prediction.clone());
1821 if self.shown_predictions.len() > 50 {
1822 let completion = self.shown_predictions.pop_back().unwrap();
1823 self.rated_predictions.remove(&completion.id);
1824 }
1825 }
1826 }
1827
1828 fn reject_prediction(
1829 &mut self,
1830 prediction_id: EditPredictionId,
1831 reason: EditPredictionRejectReason,
1832 was_shown: bool,
1833 model_version: Option<String>,
1834 e2e_latency: Option<std::time::Duration>,
1835 cx: &App,
1836 ) {
1837 match self.edit_prediction_model {
1838 EditPredictionModel::Zeta => {
1839 let is_cloud = !matches!(
1840 all_language_settings(None, cx).edit_predictions.provider,
1841 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1842 );
1843
1844 if is_cloud {
1845 let organization_id = self
1846 .user_store
1847 .read(cx)
1848 .current_organization()
1849 .map(|organization| organization.id.clone());
1850
1851 self.reject_predictions_tx
1852 .unbounded_send(EditPredictionRejectionPayload {
1853 rejection: EditPredictionRejection {
1854 request_id: prediction_id.to_string(),
1855 reason,
1856 was_shown,
1857 model_version,
1858 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1859 },
1860 organization_id,
1861 })
1862 .log_err();
1863 }
1864 }
1865 EditPredictionModel::Mercury => {
1866 mercury::edit_prediction_rejected(
1867 prediction_id,
1868 was_shown,
1869 reason,
1870 self.client.http_client(),
1871 cx,
1872 );
1873 }
1874 EditPredictionModel::Fim { .. } => {}
1875 }
1876 }
1877
1878 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1879 self.projects
1880 .get(&project.entity_id())
1881 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1882 }
1883
    /// Queues a prediction refresh for `position` in `buffer`.
    ///
    /// The refresh goes through `queue_prediction_refresh` with the `Other`
    /// trigger, using the buffer's entity id as the throttle entity. When it
    /// runs, it calls `request_prediction` and tags any result as requested by
    /// this buffer. If updating the store fails (e.g. it was released), the
    /// error is logged and the refresh resolves to `Ok(None)`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                // Attach the requesting buffer so the result can later be shown
                // as a jump from this buffer.
                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1925
    /// Queues a prediction at the next diagnostic location relative to the cursor.
    ///
    /// Nothing is queued when:
    /// - the configured provider is not an edit-prediction-store provider,
    /// - a workspace for `project` is following someone in its active pane,
    /// - the project is not tracked, or
    /// - a current prediction already exists.
    ///
    /// Otherwise a refresh is queued with the `Diagnostics` trigger, using the
    /// project's entity id as the throttle entity. When it runs, it:
    /// 1. reads the project's active buffer and cursor, bailing out if
    ///    predictions are disabled there;
    /// 2. asks `next_diagnostic_location` for a jump target. `Local` scope
    ///    searches rows within `DIAGNOSTIC_LINES_RANGE` of the cursor. `Global`
    ///    scope passes `Default::default()` as the range; how that is
    ///    interpreted is up to `next_diagnostic_location`;
    /// 3. requests a prediction at that location; and
    /// 4. if a current prediction appeared in the meantime, converts the result
    ///    into a `CurrentPreferred` rejection so the existing prediction wins.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        // NOTE(review): only read access is needed here; `get` would suffice.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Snapshot the active buffer and cursor up front. Read/log errors
                // and disabled predictions both end the refresh with `Ok(None)`.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction may have arrived while this request was in
                    // flight; if so, keep it and report this one as rejected.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2042
2043 fn predictions_enabled_at(
2044 snapshot: &BufferSnapshot,
2045 position: Option<language::Anchor>,
2046 cx: &App,
2047 ) -> bool {
2048 let file = snapshot.file();
2049 let all_settings = all_language_settings(file, cx);
2050 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2051 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2052 {
2053 return false;
2054 }
2055
2056 if let Some(last_position) = position {
2057 let settings = snapshot.settings_at(last_position, cx);
2058
2059 if !settings.edit_predictions_disabled_in.is_empty()
2060 && let Some(scope) = snapshot.language_scope_at(last_position)
2061 && let Some(scope_name) = scope.override_name()
2062 && settings
2063 .edit_predictions_disabled_in
2064 .iter()
2065 .any(|s| s == scope_name)
2066 {
2067 return false;
2068 }
2069 }
2070
2071 true
2072 }
2073
    /// Throttle window applied by `queue_prediction_refresh`. A refresh for the
    /// same throttle entity as the previous one waits until this long after
    /// that previous refresh.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2075}
2076
2077fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2078 let Some(app_state) = AppState::try_global(cx) else {
2079 return false;
2080 };
2081
2082 app_state
2083 .workspace_store
2084 .read(cx)
2085 .workspaces()
2086 .filter_map(|workspace| workspace.upgrade())
2087 .any(|workspace| {
2088 workspace.read(cx).project().entity_id() == project.entity_id()
2089 && workspace
2090 .read(cx)
2091 .leader_for_pane(workspace.read(cx).active_pane())
2092 .is_some()
2093 })
2094}
2095
2096fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2097 match provider {
2098 EditPredictionProvider::Zed
2099 | EditPredictionProvider::Mercury
2100 | EditPredictionProvider::Ollama
2101 | EditPredictionProvider::OpenAiCompatibleApi
2102 | EditPredictionProvider::Experimental(_) => true,
2103 EditPredictionProvider::None
2104 | EditPredictionProvider::Copilot
2105 | EditPredictionProvider::Codestral => false,
2106 }
2107}
2108
2109impl EditPredictionStore {
2110 fn queue_prediction_refresh(
2111 &mut self,
2112 project: Entity<Project>,
2113 request_trigger: PredictEditsRequestTrigger,
2114 throttle_entity: EntityId,
2115 cx: &mut Context<Self>,
2116 do_refresh: impl FnOnce(
2117 WeakEntity<Self>,
2118 &mut AsyncApp,
2119 )
2120 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2121 + 'static,
2122 ) {
2123 fn select_throttle(
2124 project_state: &mut ProjectState,
2125 request_trigger: PredictEditsRequestTrigger,
2126 ) -> &mut Option<(EntityId, Instant)> {
2127 match request_trigger {
2128 PredictEditsRequestTrigger::Diagnostics => {
2129 &mut project_state.last_jump_prediction_refresh
2130 }
2131 _ => &mut project_state.last_edit_prediction_refresh,
2132 }
2133 }
2134
2135 let (needs_acceptance_tracking, max_pending_predictions) =
2136 match all_language_settings(None, cx).edit_predictions.provider {
2137 EditPredictionProvider::Zed
2138 | EditPredictionProvider::Mercury
2139 | EditPredictionProvider::Experimental(_) => (true, 2),
2140 EditPredictionProvider::Ollama => (false, 1),
2141 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2142 EditPredictionProvider::None
2143 | EditPredictionProvider::Copilot
2144 | EditPredictionProvider::Codestral => {
2145 log::error!("queue_prediction_refresh called with non-store provider");
2146 return;
2147 }
2148 };
2149
2150 let drop_on_cancel = !needs_acceptance_tracking;
2151 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2152 let project_state = self.get_or_init_project(&project, cx);
2153 let pending_prediction_id = project_state.next_pending_prediction_id;
2154 project_state.next_pending_prediction_id += 1;
2155 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2156
2157 let task = cx.spawn(async move |this, cx| {
2158 let throttle_wait = this
2159 .update(cx, |this, cx| {
2160 let project_state = this.get_or_init_project(&project, cx);
2161 let throttle = *select_throttle(project_state, request_trigger);
2162
2163 let now = cx.background_executor().now();
2164 throttle.and_then(|(last_entity, last_timestamp)| {
2165 if throttle_entity != last_entity {
2166 return None;
2167 }
2168 (last_timestamp + throttle_timeout).checked_duration_since(now)
2169 })
2170 })
2171 .ok()
2172 .flatten();
2173
2174 if let Some(timeout) = throttle_wait {
2175 cx.background_executor().timer(timeout).await;
2176 }
2177
2178 // If this task was cancelled before the throttle timeout expired,
2179 // do not perform a request. Also skip if another task already
2180 // proceeded since we were enqueued (duplicate).
2181 let mut is_cancelled = true;
2182 this.update(cx, |this, cx| {
2183 let project_state = this.get_or_init_project(&project, cx);
2184 let was_cancelled = project_state
2185 .cancelled_predictions
2186 .remove(&pending_prediction_id);
2187 if was_cancelled {
2188 return;
2189 }
2190
2191 // Another request has been already sent since this was enqueued
2192 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2193 return;
2194 }
2195
2196 let new_refresh = (throttle_entity, cx.background_executor().now());
2197 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2198 is_cancelled = false;
2199 })
2200 .ok();
2201 if is_cancelled {
2202 return None;
2203 }
2204
2205 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2206 let new_prediction_id = new_prediction_result
2207 .as_ref()
2208 .map(|(prediction, _)| prediction.id.clone());
2209
2210 // When a prediction completes, remove it from the pending list, and cancel
2211 // any pending predictions that were enqueued before it.
2212 this.update(cx, |this, cx| {
2213 let project_state = this.get_or_init_project(&project, cx);
2214
2215 let is_cancelled = project_state
2216 .cancelled_predictions
2217 .remove(&pending_prediction_id);
2218
2219 let new_current_prediction = if !is_cancelled
2220 && let Some((prediction_result, requested_by)) = new_prediction_result
2221 {
2222 match prediction_result.prediction {
2223 Ok(prediction) => {
2224 let new_prediction = CurrentEditPrediction {
2225 requested_by,
2226 prediction,
2227 was_shown: false,
2228 shown_with: None,
2229 e2e_latency: prediction_result.e2e_latency,
2230 };
2231
2232 if let Some(current_prediction) =
2233 project_state.current_prediction.as_ref()
2234 {
2235 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2236 {
2237 this.reject_current_prediction(
2238 EditPredictionRejectReason::Replaced,
2239 &project,
2240 cx,
2241 );
2242
2243 Some(new_prediction)
2244 } else {
2245 this.reject_prediction(
2246 new_prediction.prediction.id,
2247 EditPredictionRejectReason::CurrentPreferred,
2248 false,
2249 new_prediction.prediction.model_version,
2250 Some(new_prediction.e2e_latency),
2251 cx,
2252 );
2253 None
2254 }
2255 } else {
2256 Some(new_prediction)
2257 }
2258 }
2259 Err(reject_reason) => {
2260 this.reject_prediction(
2261 prediction_result.id,
2262 reject_reason,
2263 false,
2264 None,
2265 Some(prediction_result.e2e_latency),
2266 cx,
2267 );
2268 None
2269 }
2270 }
2271 } else {
2272 None
2273 };
2274
2275 let project_state = this.get_or_init_project(&project, cx);
2276
2277 if let Some(new_prediction) = new_current_prediction {
2278 project_state.current_prediction = Some(new_prediction);
2279 }
2280
2281 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2282 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2283 if pending_prediction.id == pending_prediction_id {
2284 pending_predictions.remove(ix);
2285 for pending_prediction in pending_predictions.drain(0..ix) {
2286 project_state.cancel_pending_prediction(pending_prediction, cx)
2287 }
2288 break;
2289 }
2290 }
2291 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2292 cx.notify();
2293 })
2294 .ok();
2295
2296 new_prediction_id
2297 });
2298
2299 if project_state.pending_predictions.len() < max_pending_predictions {
2300 project_state
2301 .pending_predictions
2302 .push(PendingPrediction {
2303 id: pending_prediction_id,
2304 task,
2305 drop_on_cancel,
2306 })
2307 .unwrap();
2308 } else {
2309 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2310 project_state
2311 .pending_predictions
2312 .push(PendingPrediction {
2313 id: pending_prediction_id,
2314 task,
2315 drop_on_cancel,
2316 })
2317 .unwrap();
2318 project_state.cancel_pending_prediction(pending_prediction, cx);
2319 }
2320 }
2321
2322 pub fn request_prediction(
2323 &mut self,
2324 project: &Entity<Project>,
2325 active_buffer: &Entity<Buffer>,
2326 position: language::Anchor,
2327 trigger: PredictEditsRequestTrigger,
2328 cx: &mut Context<Self>,
2329 ) -> Task<Result<Option<EditPredictionResult>>> {
2330 self.request_prediction_internal(
2331 project.clone(),
2332 active_buffer.clone(),
2333 position,
2334 trigger,
2335 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2336 cx,
2337 )
2338 }
2339
2340 fn request_prediction_internal(
2341 &mut self,
2342 project: Entity<Project>,
2343 active_buffer: Entity<Buffer>,
2344 position: language::Anchor,
2345 trigger: PredictEditsRequestTrigger,
2346 allow_jump: bool,
2347 cx: &mut Context<Self>,
2348 ) -> Task<Result<Option<EditPredictionResult>>> {
2349 self.get_or_init_project(&project, cx);
2350 let project_state = self.projects.get(&project.entity_id()).unwrap();
2351 let stored_events = project_state.events(cx);
2352 let has_events = !stored_events.is_empty();
2353 let events: Vec<Arc<zeta_prompt::Event>> =
2354 stored_events.iter().map(|e| e.event.clone()).collect();
2355 let debug_tx = project_state.debug_tx.clone();
2356
2357 let snapshot = active_buffer.read(cx).snapshot();
2358 let cursor_point = position.to_point(&snapshot);
2359 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2360 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2361 let diagnostic_search_range =
2362 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2363
2364 let related_files = self.context_for_project(&project, cx);
2365
2366 let is_open_source = snapshot
2367 .file()
2368 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2369 && events.iter().all(|event| event.in_open_source_repo())
2370 && related_files.iter().all(|file| file.in_open_source_repo);
2371
2372 let can_collect_data = !cfg!(test)
2373 && is_open_source
2374 && self.is_data_collection_enabled(cx)
2375 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2376
2377 let inputs = EditPredictionModelInput {
2378 project: project.clone(),
2379 buffer: active_buffer,
2380 snapshot,
2381 position,
2382 events,
2383 related_files,
2384 trigger,
2385 diagnostic_search_range: diagnostic_search_range,
2386 debug_tx,
2387 can_collect_data,
2388 is_open_source,
2389 };
2390
2391 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2392
2393 let task = match self.edit_prediction_model {
2394 EditPredictionModel::Zeta => {
2395 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2396 }
2397 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2398 EditPredictionModel::Mercury => {
2399 self.mercury
2400 .request_prediction(inputs, self.credentials_provider.clone(), cx)
2401 }
2402 };
2403
2404 cx.spawn(async move |this, cx| {
2405 let prediction = task.await?;
2406
2407 // Only fall back to diagnostics-based prediction if we got a
2408 // the model had nothing to suggest for the buffer
2409 if prediction.is_none()
2410 && allow_jump
2411 && has_events
2412 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2413 {
2414 this.update(cx, |this, cx| {
2415 this.refresh_prediction_from_diagnostics(
2416 project,
2417 DiagnosticSearchScope::Local,
2418 cx,
2419 );
2420 })?;
2421 return anyhow::Ok(None);
2422 }
2423
2424 Ok(prediction)
2425 })
2426 }
2427
2428 pub(crate) async fn next_diagnostic_location(
2429 active_buffer: Entity<Buffer>,
2430 active_buffer_snapshot: &BufferSnapshot,
2431 active_buffer_diagnostic_search_range: Range<Point>,
2432 active_buffer_cursor_point: Point,
2433 project: &Entity<Project>,
2434 cx: &mut AsyncApp,
2435 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2436 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2437 .selections_in_range(
2438 Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
2439 false,
2440 )
2441 .flat_map(|(_, _, _, selections)| {
2442 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2443 })
2444 .collect();
2445
2446 let mut jump_location = active_buffer_snapshot
2447 .diagnostic_groups(None)
2448 .into_iter()
2449 .filter_map(|(_, group)| {
2450 let range = &group.entries[group.primary_ix]
2451 .range
2452 .to_point(&active_buffer_snapshot);
2453 if range.overlaps(&active_buffer_diagnostic_search_range) {
2454 return None;
2455 }
2456 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2457 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2458 });
2459 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2460 <= DIAGNOSTIC_LINES_RANGE;
2461 if near_collaborator && !near_local {
2462 return None;
2463 }
2464 Some(range.start)
2465 })
2466 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2467 .map(|position| {
2468 (
2469 active_buffer.clone(),
2470 active_buffer_snapshot.anchor_before(position),
2471 )
2472 });
2473
2474 if jump_location.is_none() {
2475 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2476 let file = buffer.file()?;
2477
2478 Some(ProjectPath {
2479 worktree_id: file.worktree_id(cx),
2480 path: file.path().clone(),
2481 })
2482 });
2483
2484 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2485 project
2486 .diagnostic_summaries(false, cx)
2487 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2488 .map(|(path, _, _)| {
2489 let shared_prefix = path
2490 .path
2491 .components()
2492 .zip(
2493 active_buffer_path
2494 .as_ref()
2495 .map(|p| p.path.components())
2496 .unwrap_or_default(),
2497 )
2498 .take_while(|(a, b)| a == b)
2499 .count();
2500 (path, shared_prefix)
2501 })
2502 .collect()
2503 });
2504
2505 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2506
2507 for (path, _) in candidates {
2508 let candidate_buffer = project
2509 .update(cx, |project, cx| project.open_buffer(path, cx))
2510 .await?;
2511
2512 let (has_collaborators, diagnostic_position) =
2513 candidate_buffer.read_with(cx, |buffer, _cx| {
2514 let snapshot = buffer.snapshot();
2515 let has_collaborators = snapshot
2516 .selections_in_range(
2517 Anchor::min_max_range_for_buffer(snapshot.remote_id()),
2518 false,
2519 )
2520 .next()
2521 .is_some();
2522 let position = buffer
2523 .buffer_diagnostics(None)
2524 .into_iter()
2525 .min_by_key(|entry| entry.diagnostic.severity)
2526 .map(|entry| entry.range.start);
2527 (has_collaborators, position)
2528 });
2529
2530 if has_collaborators {
2531 continue;
2532 }
2533
2534 if let Some(position) = diagnostic_position {
2535 jump_location = Some((candidate_buffer, position));
2536 break;
2537 }
2538 }
2539 }
2540
2541 anyhow::Ok(jump_location)
2542 }
2543
2544 async fn send_raw_llm_request(
2545 request: RawCompletionRequest,
2546 client: Arc<Client>,
2547 custom_url: Option<Arc<Url>>,
2548 llm_token: LlmApiToken,
2549 organization_id: Option<OrganizationId>,
2550 app_version: Version,
2551 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2552 let url = if let Some(custom_url) = custom_url {
2553 custom_url.as_ref().clone()
2554 } else {
2555 client
2556 .http_client()
2557 .build_zed_llm_url("/predict_edits/raw", &[])?
2558 };
2559
2560 Self::send_api_request(
2561 |builder| {
2562 let req = builder
2563 .uri(url.as_ref())
2564 .body(serde_json::to_string(&request)?.into());
2565 Ok(req?)
2566 },
2567 client,
2568 llm_token,
2569 organization_id,
2570 app_version,
2571 true,
2572 )
2573 .await
2574 }
2575
2576 pub(crate) async fn send_v3_request(
2577 input: ZetaPromptInput,
2578 client: Arc<Client>,
2579 llm_token: LlmApiToken,
2580 organization_id: Option<OrganizationId>,
2581 app_version: Version,
2582 trigger: PredictEditsRequestTrigger,
2583 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2584 let url = client
2585 .http_client()
2586 .build_zed_llm_url("/predict_edits/v3", &[])?;
2587
2588 let request = PredictEditsV3Request { input, trigger };
2589
2590 let json_bytes = serde_json::to_vec(&request)?;
2591 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2592
2593 Self::send_api_request(
2594 |builder| {
2595 let req = builder
2596 .uri(url.as_ref())
2597 .header("Content-Encoding", "zstd")
2598 .body(compressed.clone().into());
2599 Ok(req?)
2600 },
2601 client,
2602 llm_token,
2603 organization_id,
2604 app_version,
2605 true,
2606 )
2607 .await
2608 }
2609
2610 async fn send_api_request<Res>(
2611 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2612 client: Arc<Client>,
2613 llm_token: LlmApiToken,
2614 organization_id: Option<OrganizationId>,
2615 app_version: Version,
2616 require_auth: bool,
2617 ) -> Result<(Res, Option<EditPredictionUsage>)>
2618 where
2619 Res: DeserializeOwned,
2620 {
2621 let http_client = client.http_client();
2622 let mut token = if require_auth {
2623 Some(
2624 client
2625 .acquire_llm_token(&llm_token, organization_id.clone())
2626 .await?,
2627 )
2628 } else {
2629 client
2630 .acquire_llm_token(&llm_token, organization_id.clone())
2631 .await
2632 .ok()
2633 };
2634 let mut did_retry = false;
2635
2636 loop {
2637 let request_builder = http_client::Request::builder().method(Method::POST);
2638
2639 let mut request_builder = request_builder
2640 .header("Content-Type", "application/json")
2641 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2642
2643 // Only add Authorization header if we have a token
2644 if let Some(ref token_value) = token {
2645 request_builder =
2646 request_builder.header("Authorization", format!("Bearer {}", token_value));
2647 }
2648
2649 let request = build(request_builder)?;
2650
2651 let mut response = http_client.send(request).await?;
2652
2653 if let Some(minimum_required_version) = response
2654 .headers()
2655 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2656 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2657 {
2658 anyhow::ensure!(
2659 app_version >= minimum_required_version,
2660 ZedUpdateRequiredError {
2661 minimum_version: minimum_required_version
2662 }
2663 );
2664 }
2665
2666 if response.status().is_success() {
2667 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2668
2669 let mut body = Vec::new();
2670 response.body_mut().read_to_end(&mut body).await?;
2671 return Ok((serde_json::from_slice(&body)?, usage));
2672 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2673 did_retry = true;
2674 token = Some(
2675 client
2676 .refresh_llm_token(&llm_token, organization_id.clone())
2677 .await?,
2678 );
2679 } else {
2680 let mut body = String::new();
2681 response.body_mut().read_to_string(&mut body).await?;
2682 anyhow::bail!(
2683 "Request failed with status: {:?}\nBody: {}",
2684 response.status(),
2685 body
2686 );
2687 }
2688 }
2689 }
2690
2691 pub fn refresh_context(
2692 &mut self,
2693 project: &Entity<Project>,
2694 buffer: &Entity<language::Buffer>,
2695 cursor_position: language::Anchor,
2696 cx: &mut Context<Self>,
2697 ) {
2698 self.get_or_init_project(project, cx)
2699 .context
2700 .update(cx, |store, cx| {
2701 store.refresh(buffer.clone(), cursor_position, cx);
2702 });
2703 }
2704
2705 #[cfg(feature = "cli-support")]
2706 pub fn set_context_for_buffer(
2707 &mut self,
2708 project: &Entity<Project>,
2709 related_files: Vec<RelatedFile>,
2710 cx: &mut Context<Self>,
2711 ) {
2712 self.get_or_init_project(project, cx)
2713 .context
2714 .update(cx, |store, cx| {
2715 store.set_related_files(related_files, cx);
2716 });
2717 }
2718
2719 #[cfg(feature = "cli-support")]
2720 pub fn set_recent_paths_for_project(
2721 &mut self,
2722 project: &Entity<Project>,
2723 paths: impl IntoIterator<Item = project::ProjectPath>,
2724 cx: &mut Context<Self>,
2725 ) {
2726 let project_state = self.get_or_init_project(project, cx);
2727 project_state.recent_paths = paths.into_iter().collect();
2728 }
2729
2730 fn is_file_open_source(
2731 &self,
2732 project: &Entity<Project>,
2733 file: &Arc<dyn File>,
2734 cx: &App,
2735 ) -> bool {
2736 if !file.is_local() || file.is_private() {
2737 return false;
2738 }
2739 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2740 return false;
2741 };
2742 project_state
2743 .license_detection_watchers
2744 .get(&file.worktree_id(cx))
2745 .as_ref()
2746 .is_some_and(|watcher| watcher.is_project_open_source())
2747 }
2748
2749 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2750 self.data_collection_choice.is_enabled(cx)
2751 }
2752
2753 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2754 let choice = KeyValueStore::global(cx)
2755 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2756 .log_err()
2757 .flatten();
2758
2759 match choice.as_deref() {
2760 Some("true") => DataCollectionChoice::Enabled,
2761 Some("false") => DataCollectionChoice::Disabled,
2762 Some(_) => {
2763 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2764 DataCollectionChoice::NotAnswered
2765 }
2766 None => DataCollectionChoice::NotAnswered,
2767 }
2768 }
2769
2770 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2771 self.data_collection_choice = self.data_collection_choice.toggle();
2772 let new_choice = self.data_collection_choice;
2773 let is_enabled = new_choice.is_enabled(cx);
2774 let kvp = KeyValueStore::global(cx);
2775 db::write_and_log(cx, move || async move {
2776 kvp.write_kvp(
2777 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2778 is_enabled.to_string(),
2779 )
2780 .await
2781 });
2782 }
2783
2784 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2785 self.shown_predictions.iter()
2786 }
2787
2788 pub fn shown_completions_len(&self) -> usize {
2789 self.shown_predictions.len()
2790 }
2791
2792 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2793 self.rated_predictions.contains(id)
2794 }
2795
2796 pub fn rate_prediction(
2797 &mut self,
2798 prediction: &EditPrediction,
2799 rating: EditPredictionRating,
2800 feedback: String,
2801 cx: &mut Context<Self>,
2802 ) {
2803 let organization = self.user_store.read(cx).current_organization();
2804
2805 self.rated_predictions.insert(prediction.id.clone());
2806
2807 cx.background_spawn({
2808 let client = self.client.clone();
2809 let prediction_id = prediction.id.to_string();
2810 let inputs = serde_json::to_value(&prediction.inputs);
2811 let output = prediction
2812 .edit_preview
2813 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2814 async move {
2815 client
2816 .cloud_client()
2817 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2818 organization_id: organization.map(|organization| organization.id.clone()),
2819 request_id: prediction_id,
2820 rating: match rating {
2821 EditPredictionRating::Positive => "positive".to_string(),
2822 EditPredictionRating::Negative => "negative".to_string(),
2823 },
2824 inputs: inputs?,
2825 output,
2826 feedback,
2827 })
2828 .await?;
2829
2830 anyhow::Ok(())
2831 }
2832 })
2833 .detach_and_log_err(cx);
2834
2835 cx.notify();
2836 }
2837}
2838
2839fn collaborator_edit_overlaps_locality_region(
2840 project_state: &ProjectState,
2841 project: &Entity<Project>,
2842 buffer: &Entity<Buffer>,
2843 snapshot: &BufferSnapshot,
2844 edit_range: &Range<Anchor>,
2845 cx: &App,
2846) -> bool {
2847 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2848 return false;
2849 };
2850
2851 if active_buffer.entity_id() != buffer.entity_id() {
2852 return false;
2853 }
2854
2855 let locality_point_range = expand_context_syntactically_then_linewise(
2856 snapshot,
2857 (position..position).to_point(snapshot),
2858 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2859 );
2860 let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2861
2862 edit_range.overlaps(&locality_anchor_range, snapshot)
2863}
2864
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// Nothing happens unless the newest stored event belongs to the same buffer as
/// `latest_snapshot` (same `remote_id`) and `latest_snapshot` has observed that
/// event's `new_snapshot_version`. Walking backwards from the newest event,
/// events are counted while each older one `can_merge` with the event after it
/// (checked against `latest_snapshot` and `latest_edit_range`). If at least two
/// events qualify, their edit ranges are unioned and one diff is computed from
/// the oldest event's `old_snapshot` to `end_snapshot` within that range. That
/// run of events is then replaced by a single event, which is marked `predicted`
/// only if every merged event was predicted. If no diff is produced, `events` is
/// left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest event is for this buffer and is not ahead of it.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the trailing run of events (newest first) that can be merged with
    // their successor; the newest event always counts.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union the edit ranges of the remaining (newer) events into the oldest one.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek` did not consume, so this visits every merged event.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2947
2948fn merge_anchor_ranges(
2949 left: &Range<Anchor>,
2950 right: &Range<Anchor>,
2951 snapshot: &TextBufferSnapshot,
2952) -> Range<Anchor> {
2953 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2954 left.start
2955 } else {
2956 right.start
2957 };
2958 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2959 left.end
2960 } else {
2961 right.end
2962 };
2963 start..end
2964}
2965
/// Error returned when the server's minimum-required-version response header
/// names a Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version the server reported as required.
    minimum_version: Version,
}
2973
/// The user's choice about sharing edit prediction data.
///
/// Stored as `"true"`/`"false"` under `ZED_PREDICT_DATA_COLLECTION_CHOICE`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No value (or an unrecognized value) is stored.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2980
2981impl DataCollectionChoice {
2982 pub fn is_enabled(self, cx: &App) -> bool {
2983 if cx.is_staff() {
2984 return true;
2985 }
2986 match self {
2987 Self::Enabled => true,
2988 Self::NotAnswered | Self::Disabled => false,
2989 }
2990 }
2991
2992 #[must_use]
2993 pub fn toggle(&self) -> DataCollectionChoice {
2994 match self {
2995 Self::Enabled => Self::Disabled,
2996 Self::Disabled => Self::Enabled,
2997 Self::NotAnswered => Self::Enabled,
2998 }
2999 }
3000}
3001
3002impl From<bool> for DataCollectionChoice {
3003 fn from(value: bool) -> Self {
3004 match value {
3005 true => DataCollectionChoice::Enabled,
3006 false => DataCollectionChoice::Disabled,
3007 }
3008 }
3009}
3010
/// Marker type for the dismissable edit prediction upsell; its dismissal state
/// lives in the key-value store (see its `Dismissable` impl).
struct ZedPredictUpsell;
3012
3013impl Dismissable for ZedPredictUpsell {
3014 const KEY: &'static str = "dismissed-edit-predict-upsell";
3015
3016 fn dismissed(cx: &App) -> bool {
3017 // To make this backwards compatible with older versions of Zed, we
3018 // check if the user has seen the previous Edit Prediction Onboarding
3019 // before, by checking the data collection choice which was written to
3020 // the database once the user clicked on "Accept and Enable"
3021 let kvp = KeyValueStore::global(cx);
3022 if kvp
3023 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3024 .log_err()
3025 .is_some_and(|s| s.is_some())
3026 {
3027 return true;
3028 }
3029
3030 kvp.read_kvp(Self::KEY)
3031 .log_err()
3032 .is_some_and(|s| s.is_some())
3033 }
3034}
3035
3036pub fn should_show_upsell_modal(cx: &App) -> bool {
3037 !ZedPredictUpsell::dismissed(cx)
3038}
3039
3040pub fn init(cx: &mut App) {
3041 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3042 workspace.register_action(
3043 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3044 ZedPredictModal::toggle(
3045 workspace,
3046 workspace.user_store().clone(),
3047 workspace.client().clone(),
3048 window,
3049 cx,
3050 )
3051 },
3052 );
3053
3054 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3055 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3056 settings
3057 .project
3058 .all_languages
3059 .edit_predictions
3060 .get_or_insert_default()
3061 .provider = Some(EditPredictionProvider::None)
3062 });
3063 });
3064 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3065 EditPredictionStore::try_global(cx).and_then(|store| {
3066 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3067 })
3068 }
3069
3070 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3071 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3072 copilot_ui::initiate_sign_in(copilot, window, cx);
3073 }
3074 });
3075 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3076 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3077 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3078 }
3079 });
3080 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3081 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3082 copilot_ui::initiate_sign_out(copilot, window, cx);
3083 }
3084 });
3085 })
3086 .detach();
3087}