1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use credentials_provider::CredentialsProvider;
16use db::kvp::{Dismissable, KeyValueStore};
17use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
19use futures::{
20 AsyncReadExt as _, FutureExt as _, StreamExt as _,
21 channel::mpsc::{self, UnboundedReceiver},
22 select_biased,
23};
24use gpui::BackgroundExecutor;
25use gpui::http_client::Url;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, TaskExt, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use heapless::Vec as ArrayVec;
32use language::language_settings::all_language_settings;
33use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
34use language::{BufferSnapshot, OffsetRangeExt};
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{
40 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
41};
42use std::collections::{VecDeque, hash_map};
43use std::env;
44use text::{AnchorRangeExt, Edit};
45use workspace::{AppState, Workspace};
46use zeta_prompt::{ZetaFormat, ZetaPromptInput};
47
48use std::mem;
49use std::ops::Range;
50use std::path::Path;
51use std::rc::Rc;
52use std::str::FromStr as _;
53use std::sync::Arc;
54use std::time::{Duration, Instant};
55
56use thiserror::Error;
57use util::{RangeExt as _, ResultExt as _};
58
59pub mod cursor_excerpt;
60pub mod example_spec;
61pub mod fim;
62mod license_detection;
63pub mod mercury;
64pub mod metrics;
65pub mod ollama;
66mod onboarding_modal;
67pub mod open_ai_response;
68mod prediction;
69
70pub mod udiff;
71
72mod capture_example;
73pub mod open_ai_compatible;
74mod zed_edit_prediction_delegate;
75pub mod zeta;
76
77#[cfg(test)]
78mod edit_prediction_tests;
79
80use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
81use crate::example_spec::ExampleSpec;
82use crate::license_detection::LicenseDetectionWatcher;
83use crate::mercury::Mercury;
84pub use crate::metrics::{KeptRateResult, compute_kept_rate};
85use crate::onboarding_modal::ZedPredictModal;
86pub use crate::prediction::EditPrediction;
87pub use crate::prediction::EditPredictionId;
88use crate::prediction::EditPredictionResult;
89pub use capture_example::capture_example;
90pub use language_model::ApiKeyState;
91pub use telemetry_events::EditPredictionRating;
92pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
93
// Actions declared through gpui's `actions!` macro in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
103
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold (as measured by `lines_between_ranges`) used by
/// `StoredEvent::can_merge` to decide whether edit ranges are "near" each other.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length of the old/new context windows diffed for an edit
/// history event; larger regions yield no diff (see `compute_diff_between_snapshots_in_range`).
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not referenced in the visible part of this file; a token budget
// for context around collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): not referenced in the visible part of this file; presumably the
// maximum pause between edits that are still grouped into one event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key-value store key for the user's data collection choice — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval for sending rejection batches — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Event name string for settled predictions (presumably telemetry — confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a pending settled prediction is kept — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably the quiet period after the last edit before a prediction counts as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
115
/// Feature flag registered under the name `edit_prediction_jumps`.
/// NOTE(review): presumably gates jump-style predictions
/// (`BufferEditPrediction::Jump`) — confirm at the use sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
121
/// gpui `Global` wrapper holding the shared [`EditPredictionStore`] entity;
/// read by `EditPredictionStore::try_global` and populated by `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
126
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// NOTE(review): an earlier comment said "the version is also used as the
/// Baseten environment name (lowercased)", but there is no version field;
/// confirm whether that now refers to `environment`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    // Optional model id; `zeta2_raw_config_from_env` reads it from `ZED_ZETA_MODEL`.
    pub model_id: Option<String>,
    // Optional environment name; read from `ZED_ZETA_ENVIRONMENT`.
    pub environment: Option<String>,
    // Prompt format; parsed from `ZED_ZETA_FORMAT`.
    pub format: ZetaFormat,
}
136
/// Central edit prediction state.
///
/// Holds per-project tracking, the selected prediction backend, experiment
/// selection, and the senders feeding the background workers spawned in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token passed to `Client::acquire_llm_token` for authenticated requests.
    llm_token: LlmApiToken,
    // Kept so the experiments-fetch task is not dropped (dropping a gpui `Task` cancels it).
    _fetch_experiments_task: Task<()>,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    // Backend used to produce predictions; initialized to `Zeta` in `new`.
    edit_prediction_model: EditPredictionModel,
    // Raw Zeta2 endpoint configuration; initialized from `ZED_ZETA_*` env vars.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    // Override set via `set_preferred_experiment`; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    // Experiment names fetched by `refresh_available_experiments` (staff only, see `new`).
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sender for the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    // Sender for the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    // Predictions that were shown; `active_experiment` reads their model versions.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
158
/// A rejection sent over `EditPredictionStore::reject_predictions_tx`,
/// together with the organization it should be attributed to (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
163
/// Backend that produces edit predictions. `EditPredictionStore::new` starts with `Zeta`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    // Fill-in-the-middle model using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
170
/// Inputs handed to a prediction backend for a single request.
/// NOTE(review): constructed outside the visible part of this file; the field
/// notes below are inferred from names and types — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    // Buffer the prediction is for, plus a snapshot of it.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Position the prediction is requested at.
    position: Anchor,
    // Edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
185
/// Debug notifications emitted over a `debug_tx` channel while context is
/// retrieved and predictions are produced.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload for [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Free-form key/value details about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload for [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Payload for [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
221
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    // Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    // Buffer version after the change.
    pub new_snapshot_version: clock::Global,
    // Anchors (in the post-change buffer) covering the changed region.
    pub total_edit_range: Range<Anchor>,
}
230
231impl StoredEvent {
232 fn can_merge(
233 &self,
234 next_old_event: &StoredEvent,
235 latest_snapshot: &TextBufferSnapshot,
236 latest_edit_range: &Range<Anchor>,
237 ) -> bool {
238 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
239 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
240 return false;
241 }
242 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
243 return false;
244 }
245 if self.new_snapshot_version != next_old_event.old_snapshot.version {
246 return false;
247 }
248 if !latest_snapshot
249 .version
250 .observed_all(&next_old_event.new_snapshot_version)
251 {
252 return false;
253 }
254
255 let a_is_predicted = matches!(
256 self.event.as_ref(),
257 zeta_prompt::Event::BufferChange {
258 predicted: true,
259 ..
260 }
261 );
262 let b_is_predicted = matches!(
263 next_old_event.event.as_ref(),
264 zeta_prompt::Event::BufferChange {
265 predicted: true,
266 ..
267 }
268 );
269
270 // If events come from the same source (both predicted or both manual) then
271 // we would have coalesced them already.
272 if a_is_predicted == b_is_predicted {
273 return false;
274 }
275
276 let left_range = self.total_edit_range.to_point(latest_snapshot);
277 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
278 let latest_range = latest_edit_range.to_point(latest_snapshot);
279
280 // Events near to the latest edit are not merged if their sources differ.
281 if lines_between_ranges(&left_range, &latest_range)
282 .min(lines_between_ranges(&right_range, &latest_range))
283 <= CHANGE_GROUPING_LINE_SPAN
284 {
285 return false;
286 }
287
288 // Events that are distant from each other are not merged.
289 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
290 return false;
291 }
292
293 true
294 }
295}
296
297fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
298 if left.start > right.end {
299 return left.start.row - right.end.row;
300 }
301 if right.start > left.end {
302 return right.start.row - left.end.row;
303 }
304 0
305}
306
/// Per-project edit prediction state owned by `EditPredictionStore::projects`.
struct ProjectState {
    // Finalized edit history events.
    events: VecDeque<StoredEvent>,
    // In-progress event; `events()` finalizes it on the fly.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // Requests in flight; the `ArrayVec` type caps this at two entries.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably the (buffer id, time) of the last refresh of
    // each kind, used for throttling — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // License watchers per worktree; consulted by `LastEvent::finalize`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    // Created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
324
325impl ProjectState {
326 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
327 self.events
328 .iter()
329 .cloned()
330 .chain(self.last_event.as_ref().iter().flat_map(|event| {
331 let (one, two) = event.split_by_pause();
332 let one = one.finalize(&self.license_detection_watchers, cx);
333 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
334 one.into_iter().chain(two)
335 }))
336 .collect()
337 }
338
339 fn cancel_pending_prediction(
340 &mut self,
341 pending_prediction: PendingPrediction,
342 cx: &mut Context<EditPredictionStore>,
343 ) {
344 self.cancelled_predictions.insert(pending_prediction.id);
345
346 if pending_prediction.drop_on_cancel {
347 drop(pending_prediction.task);
348 } else {
349 cx.spawn(async move |this, cx| {
350 let Some(prediction_id) = pending_prediction.task.await else {
351 return;
352 };
353
354 this.update(cx, |this, cx| {
355 this.reject_prediction(
356 prediction_id,
357 EditPredictionRejectReason::Canceled,
358 false,
359 None,
360 None,
361 cx,
362 );
363 })
364 .ok();
365 })
366 .detach()
367 }
368 }
369
370 fn active_buffer(
371 &self,
372 project: &Entity<Project>,
373 cx: &App,
374 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
375 let project = project.read(cx);
376 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
377 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
378 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
379 Some((active_buffer, registered_buffer.last_position))
380 }
381}
382
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered the request (a buffer or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    // NOTE(review): presumably end-to-end request latency — confirm.
    pub e2e_latency: std::time::Duration,
}
391
392impl CurrentEditPrediction {
393 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
394 let Some(new_edits) = self
395 .prediction
396 .interpolate(&self.prediction.buffer.read(cx))
397 else {
398 return false;
399 };
400
401 if self.prediction.buffer != old_prediction.prediction.buffer {
402 return true;
403 }
404
405 let Some(old_edits) = old_prediction
406 .prediction
407 .interpolate(&old_prediction.prediction.buffer.read(cx))
408 else {
409 return true;
410 };
411
412 let requested_by_buffer_id = self.requested_by.buffer_id();
413
414 // This reduces the occurrence of UI thrash from replacing edits
415 //
416 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
417 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
418 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
419 && old_edits.len() == 1
420 && new_edits.len() == 1
421 {
422 let (old_range, old_text) = &old_edits[0];
423 let (new_range, new_text) = &new_edits[0];
424 new_range == old_range && new_text.starts_with(old_text.as_ref())
425 } else {
426 true
427 }
428 }
429}
430
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    // Triggered by a diagnostics update rather than a specific buffer.
    DiagnosticsUpdate,
    // Triggered by the buffer with this entity id.
    Buffer(EntityId),
}
436
437impl PredictionRequestedBy {
438 pub fn buffer_id(&self) -> Option<EntityId> {
439 match self {
440 PredictionRequestedBy::DiagnosticsUpdate => None,
441 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
442 }
443 }
444}
445
// NOTE(review): not referenced in the visible part of this file; presumably the
// number of lines around a position searched for diagnostics — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostics search.
/// NOTE(review): presumably `Local` restricts the search to the vicinity of
/// the cursor and `Global` searches more widely — confirm against call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
453
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    // Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
462
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // NOTE(review): presumably a prediction whose edits apply to this buffer — confirm.
    Local { prediction: &'a EditPrediction },
    // NOTE(review): presumably a prediction located elsewhere that the user can jump to — confirm.
    Jump { prediction: &'a EditPrediction },
}
469
/// Test-only convenience: both variants wrap an [`EditPrediction`], so expose it directly.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
481
/// A prediction waiting to "settle", stored in `RegisteredBuffer::pending_predictions`.
/// NOTE(review): the settle logic lives outside this view; field notes are
/// inferred from names — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    // Editable region of the prediction, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    // Tree-sitter error counts before/after applying the prediction.
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
495
/// Tracking state for a buffer registered in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    // NOTE(review): presumably the last snapshot observed for this buffer — confirm.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    // Returned alongside the buffer by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
503
/// An edit event still being accumulated; turned into a [`StoredEvent`] by `finalize`.
#[derive(Clone)]
struct LastEvent {
    // Buffer snapshot/file at the start and at the end of the event.
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    // Anchors covering everything changed during the event.
    total_edit_range: Range<Anchor>,
    // Value of `total_edit_range` when the last editing pause was recorded.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    // Whether the change came from an accepted prediction.
    predicted: bool,
    // Snapshot taken at the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
517
518impl LastEvent {
519 pub fn finalize(
520 &self,
521 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
522 cx: &App,
523 ) -> Option<StoredEvent> {
524 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
525 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
526
527 let in_open_source_repo =
528 [self.new_file.as_ref(), self.old_file.as_ref()]
529 .iter()
530 .all(|file| {
531 file.is_some_and(|file| {
532 license_detection_watchers
533 .get(&file.worktree_id(cx))
534 .is_some_and(|watcher| watcher.is_project_open_source())
535 })
536 });
537
538 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
539 &self.old_snapshot,
540 &self.new_snapshot,
541 &self.total_edit_range,
542 )?;
543
544 if path == old_path && diff.is_empty() {
545 None
546 } else {
547 Some(StoredEvent {
548 event: Arc::new(zeta_prompt::Event::BufferChange {
549 old_path,
550 path,
551 diff,
552 in_open_source_repo,
553 predicted: self.predicted,
554 }),
555 old_snapshot: self.old_snapshot.clone(),
556 new_snapshot_version: self.new_snapshot.version.clone(),
557 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
558 ..self.new_snapshot.anchor_before(edit_range.end),
559 })
560 }
561 }
562
563 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
564 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
565 return (self.clone(), None);
566 };
567
568 let total_edit_range_before_pause = self
569 .total_edit_range_at_last_pause_boundary
570 .clone()
571 .unwrap_or_else(|| self.total_edit_range.clone());
572
573 let Some(total_edit_range_after_pause) =
574 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
575 else {
576 return (self.clone(), None);
577 };
578
579 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
580 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
581
582 let before = LastEvent {
583 old_snapshot: self.old_snapshot.clone(),
584 new_snapshot: boundary_snapshot.clone(),
585 old_file: self.old_file.clone(),
586 new_file: self.new_file.clone(),
587 latest_edit_range: latest_edit_range_before_pause,
588 total_edit_range: total_edit_range_before_pause,
589 total_edit_range_at_last_pause_boundary: None,
590 predicted: self.predicted,
591 snapshot_after_last_editing_pause: None,
592 last_edit_time: self.last_edit_time,
593 };
594
595 let after = LastEvent {
596 old_snapshot: boundary_snapshot.clone(),
597 new_snapshot: self.new_snapshot.clone(),
598 old_file: self.old_file.clone(),
599 new_file: self.new_file.clone(),
600 latest_edit_range: latest_edit_range_after_pause,
601 total_edit_range: total_edit_range_after_pause,
602 total_edit_range_at_last_pause_boundary: None,
603 predicted: self.predicted,
604 snapshot_after_last_editing_pause: None,
605 last_edit_time: self.last_edit_time,
606 };
607
608 (before, Some(after))
609 }
610}
611
612fn compute_total_edit_range_between_snapshots(
613 old_snapshot: &TextBufferSnapshot,
614 new_snapshot: &TextBufferSnapshot,
615) -> Option<Range<Anchor>> {
616 let edits: Vec<Edit<usize>> = new_snapshot
617 .edits_since::<usize>(&old_snapshot.version)
618 .collect();
619
620 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
621 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
622 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
623
624 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
625}
626
627fn compute_old_range_for_new_range(
628 old_snapshot: &TextBufferSnapshot,
629 new_snapshot: &TextBufferSnapshot,
630 total_edit_range: &Range<Anchor>,
631) -> Option<Range<Point>> {
632 let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
633 let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
634
635 let edits: Vec<Edit<usize>> = new_snapshot
636 .edits_since::<usize>(&old_snapshot.version)
637 .collect();
638 let mut old_start_offset = None;
639 let mut old_end_offset = None;
640 let mut delta: isize = 0;
641
642 for edit in &edits {
643 if old_start_offset.is_none() && new_start_offset <= edit.new.end {
644 old_start_offset = Some(if new_start_offset < edit.new.start {
645 new_start_offset.checked_add_signed(-delta)?
646 } else {
647 edit.old.start
648 });
649 }
650
651 if old_end_offset.is_none() && new_end_offset <= edit.new.end {
652 old_end_offset = Some(if new_end_offset < edit.new.start {
653 new_end_offset.checked_add_signed(-delta)?
654 } else {
655 edit.old.end
656 });
657 }
658
659 delta += edit.new.len() as isize - edit.old.len() as isize;
660 }
661
662 let old_start_offset =
663 old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
664 let old_end_offset =
665 old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
666
667 Some(
668 old_snapshot.offset_to_point(old_start_offset)
669 ..old_snapshot.offset_to_point(old_end_offset),
670 )
671}
672
673fn compute_diff_between_snapshots_in_range(
674 old_snapshot: &TextBufferSnapshot,
675 new_snapshot: &TextBufferSnapshot,
676 total_edit_range: &Range<Anchor>,
677) -> Option<(String, Range<Point>)> {
678 let new_start_point = total_edit_range.start.to_point(new_snapshot);
679 let new_end_point = total_edit_range.end.to_point(new_snapshot);
680 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
681 let old_start_point = old_range.start;
682 let old_end_point = old_range.end;
683
684 const CONTEXT_LINES: u32 = 3;
685
686 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
687 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
688 let old_context_end_row =
689 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
690 let new_context_end_row =
691 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
692
693 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
694 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
695 let old_end_line_offset = old_snapshot
696 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
697 let new_end_line_offset = new_snapshot
698 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
699 let old_edit_range = old_start_line_offset..old_end_line_offset;
700 let new_edit_range = new_start_line_offset..new_end_line_offset;
701
702 if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
703 || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
704 {
705 return None;
706 }
707
708 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
709 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
710
711 let diff = language::unified_diff_with_offsets(
712 &old_region_text,
713 &new_region_text,
714 old_context_start_row,
715 new_context_start_row,
716 );
717
718 Some((diff, new_start_point..new_end_point))
719}
720
721fn buffer_path_with_id_fallback(
722 file: Option<&Arc<dyn File>>,
723 snapshot: &TextBufferSnapshot,
724 cx: &App,
725) -> Arc<Path> {
726 if let Some(file) = file {
727 file.full_path(cx).into()
728 } else {
729 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
730 }
731}
732
733impl EditPredictionStore {
734 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
735 cx.try_global::<EditPredictionStoreGlobal>()
736 .map(|global| global.0.clone())
737 }
738
739 pub fn global(
740 client: &Arc<Client>,
741 user_store: &Entity<UserStore>,
742 cx: &mut App,
743 ) -> Entity<Self> {
744 cx.try_global::<EditPredictionStoreGlobal>()
745 .map(|global| global.0.clone())
746 .unwrap_or_else(|| {
747 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
748 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
749 ep_store
750 })
751 }
752
753 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
754 let data_collection_choice = Self::load_data_collection_choice(cx);
755
756 let llm_token = global_llm_token(cx);
757
758 let (reject_tx, reject_rx) = mpsc::unbounded();
759 cx.background_spawn({
760 let client = client.clone();
761 let llm_token = llm_token.clone();
762 let app_version = AppVersion::global(cx);
763 let background_executor = cx.background_executor().clone();
764 async move {
765 Self::handle_rejected_predictions(
766 reject_rx,
767 client,
768 llm_token,
769 app_version,
770 background_executor,
771 )
772 .await
773 }
774 })
775 .detach();
776
777 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
778 cx.spawn(async move |this, cx| {
779 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
780 })
781 .detach();
782
783 let mut current_user = user_store.read(cx).watch_current_user();
784 let fetch_experiments_task = cx.spawn(async move |this, cx| {
785 while current_user.borrow().is_none() {
786 current_user.next().await;
787 }
788
789 this.update(cx, |this, cx| {
790 if cx.is_staff() {
791 this.refresh_available_experiments(cx);
792 }
793 })
794 .log_err();
795 });
796
797 let credentials_provider = zed_credentials_provider::global(cx);
798
799 let this = Self {
800 projects: HashMap::default(),
801 client,
802 user_store,
803 llm_token,
804 _fetch_experiments_task: fetch_experiments_task,
805 update_required: false,
806 edit_prediction_model: EditPredictionModel::Zeta,
807 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
808 preferred_experiment: None,
809 available_experiments: Vec::new(),
810 mercury: Mercury::new(cx),
811
812 data_collection_choice,
813 reject_predictions_tx: reject_tx,
814 settled_predictions_tx,
815 rated_predictions: Default::default(),
816 shown_predictions: Default::default(),
817 #[cfg(test)]
818 settled_event_callback: None,
819
820 credentials_provider,
821 };
822
823 this
824 }
825
826 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
827 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
828 let format = ZetaFormat::parse(&version_str).ok()?;
829 let model_id = env::var("ZED_ZETA_MODEL").ok();
830 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
831 Some(Zeta2RawConfig {
832 model_id,
833 environment,
834 format,
835 })
836 }
837
838 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
839 self.edit_prediction_model = model;
840 }
841
842 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
843 self.zeta2_raw_config = Some(config);
844 }
845
846 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
847 self.zeta2_raw_config.as_ref()
848 }
849
850 pub fn preferred_experiment(&self) -> Option<&str> {
851 self.preferred_experiment.as_deref()
852 }
853
854 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
855 self.preferred_experiment = experiment;
856 }
857
858 pub fn available_experiments(&self) -> &[String] {
859 &self.available_experiments
860 }
861
862 pub fn active_experiment(&self) -> Option<&str> {
863 self.preferred_experiment.as_deref().or_else(|| {
864 self.shown_predictions
865 .iter()
866 .find_map(|p| p.model_version.as_ref())
867 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
868 })
869 }
870
871 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
872 let client = self.client.clone();
873 let llm_token = self.llm_token.clone();
874 let app_version = AppVersion::global(cx);
875 let organization_id = self
876 .user_store
877 .read(cx)
878 .current_organization()
879 .map(|organization| organization.id.clone());
880
881 cx.spawn(async move |this, cx| {
882 let experiments = cx
883 .background_spawn(async move {
884 let http_client = client.http_client();
885 let token = client
886 .acquire_llm_token(&llm_token, organization_id.clone())
887 .await?;
888 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
889 let request = http_client::Request::builder()
890 .method(Method::GET)
891 .uri(url.as_ref())
892 .header("Authorization", format!("Bearer {}", token))
893 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
894 .body(Default::default())?;
895 let mut response = http_client.send(request).await?;
896 if response.status().is_success() {
897 let mut body = Vec::new();
898 response.body_mut().read_to_end(&mut body).await?;
899 let experiments: Vec<String> = serde_json::from_slice(&body)?;
900 Ok(experiments)
901 } else {
902 let mut body = String::new();
903 response.body_mut().read_to_string(&mut body).await?;
904 anyhow::bail!(
905 "Failed to fetch experiments: {:?}\nBody: {}",
906 response.status(),
907 body
908 );
909 }
910 })
911 .await?;
912 this.update(cx, |this, cx| {
913 this.available_experiments = experiments;
914 cx.notify();
915 })?;
916 anyhow::Ok(())
917 })
918 .detach_and_log_err(cx);
919 }
920
921 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
922 use ui::IconName;
923 match self.edit_prediction_model {
924 EditPredictionModel::Mercury => {
925 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
926 }
927 EditPredictionModel::Zeta => {
928 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
929 .with_disabled(IconName::ZedPredictDisabled)
930 .with_up(IconName::ZedPredictUp)
931 .with_down(IconName::ZedPredictDown)
932 .with_error(IconName::ZedPredictError)
933 }
934 EditPredictionModel::Fim { .. } => {
935 let settings = &all_language_settings(None, cx).edit_predictions;
936 match settings.provider {
937 EditPredictionProvider::Ollama => {
938 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
939 }
940 _ => {
941 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
942 }
943 }
944 }
945 }
946 }
947
948 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
949 self.mercury.api_token.read(cx).has_key()
950 }
951
952 pub fn mercury_has_payment_required_error(&self) -> bool {
953 self.mercury.has_payment_required_error()
954 }
955
956 pub fn clear_history(&mut self) {
957 for project_state in self.projects.values_mut() {
958 project_state.events.clear();
959 project_state.last_event.take();
960 }
961 }
962
963 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
964 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
965 project_state.events.clear();
966 project_state.last_event.take();
967 }
968 }
969
970 pub fn edit_history_for_project(
971 &self,
972 project: &Entity<Project>,
973 cx: &App,
974 ) -> Vec<StoredEvent> {
975 self.projects
976 .get(&project.entity_id())
977 .map(|project_state| project_state.events(cx))
978 .unwrap_or_default()
979 }
980
981 pub fn context_for_project<'a>(
982 &'a self,
983 project: &Entity<Project>,
984 cx: &'a mut App,
985 ) -> Vec<RelatedFile> {
986 self.projects
987 .get(&project.entity_id())
988 .map(|project_state| {
989 project_state.context.update(cx, |context, cx| {
990 context
991 .related_files_with_buffers(cx)
992 .map(|(mut related_file, buffer)| {
993 related_file.in_open_source_repo = buffer
994 .read(cx)
995 .file()
996 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
997 related_file
998 })
999 .collect()
1000 })
1001 })
1002 .unwrap_or_default()
1003 }
1004
1005 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1006 self.projects
1007 .get(&project.entity_id())
1008 .and_then(|project| project.copilot.clone())
1009 }
1010
1011 pub fn start_copilot_for_project(
1012 &mut self,
1013 project: &Entity<Project>,
1014 cx: &mut Context<Self>,
1015 ) -> Option<Entity<Copilot>> {
1016 if DisableAiSettings::get(None, cx).disable_ai {
1017 return None;
1018 }
1019 let state = self.get_or_init_project(project, cx);
1020
1021 if state.copilot.is_some() {
1022 return state.copilot.clone();
1023 }
1024 let _project = project.clone();
1025 let project = project.read(cx);
1026
1027 let node = project.node_runtime().cloned();
1028 if let Some(node) = node {
1029 let next_id = project.languages().next_language_server_id();
1030 let fs = project.fs().clone();
1031
1032 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1033 state.copilot = Some(copilot.clone());
1034 Some(copilot)
1035 } else {
1036 None
1037 }
1038 }
1039
1040 pub fn context_for_project_with_buffers<'a>(
1041 &'a self,
1042 project: &Entity<Project>,
1043 cx: &'a mut App,
1044 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1045 self.projects
1046 .get(&project.entity_id())
1047 .map(|project| {
1048 project.context.update(cx, |context, cx| {
1049 context.related_files_with_buffers(cx).collect()
1050 })
1051 })
1052 .unwrap_or_default()
1053 }
1054
1055 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1056 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1057 self.user_store.read(cx).edit_prediction_usage()
1058 } else {
1059 None
1060 }
1061 }
1062
1063 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1064 self.get_or_init_project(project, cx);
1065 }
1066
1067 pub fn register_buffer(
1068 &mut self,
1069 buffer: &Entity<Buffer>,
1070 project: &Entity<Project>,
1071 cx: &mut Context<Self>,
1072 ) {
1073 let project_state = self.get_or_init_project(project, cx);
1074 Self::register_buffer_impl(project_state, buffer, project, cx);
1075 }
1076
    /// Returns the state tracked for `project`, creating it on first use.
    ///
    /// On creation this:
    /// - builds a `RelatedExcerptStore` for the project and forwards its events
    ///   to `handle_excerpt_store_event` (that subscription is detached);
    /// - subscribes to project events via `handle_project_event`;
    /// - removes the state again, and notifies observers, when the project
    ///   entity is released.
    ///
    /// The last two subscriptions are kept alive by storing them in
    /// `_subscriptions` on the new state.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // Runs only for projects without existing state, so the subscriptions
            // below are created once per project.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project itself goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1116
1117 pub fn remove_project(&mut self, project: &Entity<Project>) {
1118 self.projects.remove(&project.entity_id());
1119 }
1120
1121 fn handle_excerpt_store_event(
1122 &mut self,
1123 project_entity_id: EntityId,
1124 event: &RelatedExcerptStoreEvent,
1125 ) {
1126 if let Some(project_state) = self.projects.get(&project_entity_id) {
1127 if let Some(debug_tx) = project_state.debug_tx.clone() {
1128 match event {
1129 RelatedExcerptStoreEvent::StartedRefresh => {
1130 debug_tx
1131 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1132 ContextRetrievalStartedDebugEvent {
1133 project_entity_id: project_entity_id,
1134 timestamp: Instant::now(),
1135 search_prompt: String::new(),
1136 },
1137 ))
1138 .ok();
1139 }
1140 RelatedExcerptStoreEvent::FinishedRefresh {
1141 cache_hit_count,
1142 cache_miss_count,
1143 mean_definition_latency,
1144 max_definition_latency,
1145 } => {
1146 debug_tx
1147 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1148 ContextRetrievalFinishedDebugEvent {
1149 project_entity_id: project_entity_id,
1150 timestamp: Instant::now(),
1151 metadata: vec![
1152 (
1153 "Cache Hits",
1154 format!(
1155 "{}/{}",
1156 cache_hit_count,
1157 cache_hit_count + cache_miss_count
1158 )
1159 .into(),
1160 ),
1161 (
1162 "Max LSP Time",
1163 format!("{} ms", max_definition_latency.as_millis())
1164 .into(),
1165 ),
1166 (
1167 "Mean LSP Time",
1168 format!("{} ms", mean_definition_latency.as_millis())
1169 .into(),
1170 ),
1171 ],
1172 },
1173 ))
1174 .ok();
1175 }
1176 }
1177 }
1178 }
1179 }
1180
1181 pub fn debug_info(
1182 &mut self,
1183 project: &Entity<Project>,
1184 cx: &mut Context<Self>,
1185 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1186 let project_state = self.get_or_init_project(project, cx);
1187 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1188 project_state.debug_tx = Some(debug_watch_tx);
1189 debug_watch_rx
1190 }
1191
1192 fn handle_project_event(
1193 &mut self,
1194 project: Entity<Project>,
1195 event: &project::Event,
1196 cx: &mut Context<Self>,
1197 ) {
1198 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1199 return;
1200 }
1201 // TODO [zeta2] init with recent paths
1202 match event {
1203 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1204 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1205 return;
1206 };
1207 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1208 if let Some(path) = path {
1209 if let Some(ix) = project_state
1210 .recent_paths
1211 .iter()
1212 .position(|probe| probe == &path)
1213 {
1214 project_state.recent_paths.remove(ix);
1215 }
1216 project_state.recent_paths.push_front(path);
1217 }
1218 }
1219 project::Event::DiagnosticsUpdated { .. } => {
1220 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1221 self.refresh_prediction_from_diagnostics(
1222 project,
1223 DiagnosticSearchScope::Global,
1224 cx,
1225 );
1226 }
1227 }
1228 _ => (),
1229 }
1230 }
1231
    /// Registers `buffer` under `project_state` (if not already registered) and
    /// returns its `RegisteredBuffer` entry.
    ///
    /// Side effects:
    /// - If the buffer has a file whose worktree is found in `project`, a
    ///   `LicenseDetectionWatcher` is created for that worktree the first time it
    ///   is seen. It is removed from the project state when the worktree entity
    ///   is released.
    /// - On first registration, the buffer's current text snapshot and file are
    ///   stored. Subscriptions are also installed: one reports edits through
    ///   `report_changes_for_buffer` while the project is still alive, and one
    ///   removes the registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree, shared by all buffers in it.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Weak project handle: edits are ignored once the project is gone.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1300
    /// Synchronizes the stored snapshot of `buffer` with its latest contents and
    /// records the change in the project's edit history.
    ///
    /// * `is_predicted`: the edit came from accepting a prediction
    ///   (`accept_current_prediction` passes `true`). Predicted and non-predicted
    ///   edits are never coalesced into the same history event.
    /// * `is_local`: the edit was made locally. Non-local edits are recorded
    ///   only when `collaborator_edit_overlaps_locality_region` returns `true`.
    ///
    /// The method runs these steps in order:
    /// 1. Swap in the new snapshot and file, returning early if the buffer
    ///    version did not change.
    /// 2. Build an anchor range spanning the edits since the previous snapshot.
    /// 3. Refresh `last_edit_at` on pending settled predictions whose editable
    ///    range overlaps that span.
    /// 4. Either extend the in-progress `last_event`, or finalize it into
    ///    `events` and start a new one.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Span from the start of the first edit to the end of the last one.
        // Assumes `anchored_edits_since` yields edits in buffer order; TODO confirm.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit touching a pending prediction's editable region restarts its
        // quiescence clock (read by `run_settled_predictions_worker`).
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        // `None` from the diff helper means this range yields nothing worth
        // recording (presumably an empty diff; verify against the helper).
        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        if !is_recordable_history_edit {
            // Close out the in-progress event so the next edit starts a new one.
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Keeps `events` at most EVENT_COUNT_MAX - 1 long; presumably the
                    // remaining slot is for `last_event`. Confirm.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only consecutive edits to the same buffer, of the same
            // origin (predicted vs. not), that are within CHANGE_GROUPING_LINE_SPAN
            // lines of the previous edit.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // snapshot and total range as they were at that pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (same capacity rule as above).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // Helper defined elsewhere; may fold recent finalized events together
        // with this edit. See its definition for the exact rule.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1433
1434 fn prediction_at(
1435 &mut self,
1436 buffer: &Entity<Buffer>,
1437 position: Option<language::Anchor>,
1438 project: &Entity<Project>,
1439 cx: &App,
1440 ) -> Option<BufferEditPrediction<'_>> {
1441 let project_state = self.projects.get_mut(&project.entity_id())?;
1442 if let Some(position) = position
1443 && let Some(buffer) = project_state
1444 .registered_buffers
1445 .get_mut(&buffer.entity_id())
1446 {
1447 buffer.last_position = Some(position);
1448 }
1449
1450 let CurrentEditPrediction {
1451 requested_by,
1452 prediction,
1453 ..
1454 } = project_state.current_prediction.as_ref()?;
1455
1456 if prediction.targets_buffer(buffer.read(cx)) {
1457 Some(BufferEditPrediction::Local { prediction })
1458 } else {
1459 let show_jump = match requested_by {
1460 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1461 requested_by_buffer_id == &buffer.entity_id()
1462 }
1463 PredictionRequestedBy::DiagnosticsUpdate => true,
1464 };
1465
1466 if show_jump {
1467 Some(BufferEditPrediction::Jump { prediction })
1468 } else {
1469 None
1470 }
1471 }
1472 }
1473
1474 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1475 let Some(current_prediction) = self
1476 .projects
1477 .get_mut(&project.entity_id())
1478 .and_then(|project_state| project_state.current_prediction.take())
1479 else {
1480 return;
1481 };
1482
1483 self.report_changes_for_buffer(
1484 ¤t_prediction.prediction.buffer,
1485 project,
1486 true,
1487 true,
1488 cx,
1489 );
1490
1491 // can't hold &mut project_state ref across report_changes_for_buffer_call
1492 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1493 return;
1494 };
1495
1496 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1497 project_state.cancel_pending_prediction(pending_prediction, cx);
1498 }
1499
1500 match self.edit_prediction_model {
1501 EditPredictionModel::Mercury => {
1502 mercury::edit_prediction_accepted(
1503 current_prediction.prediction.id,
1504 self.client.http_client(),
1505 cx,
1506 );
1507 }
1508 EditPredictionModel::Zeta => {
1509 let is_cloud = !matches!(
1510 all_language_settings(None, cx).edit_predictions.provider,
1511 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1512 );
1513 if is_cloud {
1514 zeta::edit_prediction_accepted(self, current_prediction, cx)
1515 }
1516 }
1517 EditPredictionModel::Fim { .. } => {}
1518 }
1519 }
1520
    /// Background loop that batches prediction rejections and posts them to the
    /// `/predict_edits/reject` LLM endpoint.
    ///
    /// Rejections accumulate until one of these happens:
    /// - the batch reaches half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`;
    /// - no further rejection arrives within `REJECT_REQUEST_DEBOUNCE`;
    /// - the channel closes.
    ///
    /// Each request sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    /// of the most recent entries. Those entries are removed from the batch
    /// only if the request succeeds, so failed ones are retried with the next
    /// batch. The loop ends when `rx` yields `None`.
    ///
    /// NOTE(review): the whole batch is sent with the `organization_id` of the
    /// most recently received payload. Earlier or retried rejections queued under
    /// a different organization are therefore attributed to it.
    ///
    /// NOTE(review): `batched` is never trimmed after a failed request, so it can
    /// grow without bound while the endpoint keeps failing.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: wait for either another rejection (loop to collect it)
            // or the debounce timer. `select_biased!` polls the peek first; a
            // `None` peek (channel closed) falls through to flush immediately.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics, ending this worker, if the URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections (the tail of `batched`).
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the sent entries only on success; failures are logged and retried.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1584
1585 async fn run_settled_predictions_worker(
1586 this: WeakEntity<Self>,
1587 mut rx: UnboundedReceiver<Instant>,
1588 cx: &mut AsyncApp,
1589 ) {
1590 let mut next_wake_time: Option<Instant> = None;
1591 loop {
1592 let now = cx.background_executor().now();
1593 if let Some(wake_time) = next_wake_time.take() {
1594 cx.background_executor()
1595 .timer(wake_time.duration_since(now))
1596 .await;
1597 } else {
1598 let Some(new_enqueue_time) = rx.next().await else {
1599 break;
1600 };
1601 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1602 while rx.next().now_or_never().flatten().is_some() {}
1603 continue;
1604 }
1605
1606 let Some(this) = this.upgrade() else {
1607 break;
1608 };
1609
1610 let now = cx.background_executor().now();
1611 let mut oldest_edited_at = None;
1612 let mut ready_predictions = Vec::new();
1613
1614 this.update(cx, |this, _| {
1615 for (_, project_state) in this.projects.iter_mut() {
1616 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1617 let mut pending_index = 0;
1618 while pending_index < registered_buffer.pending_predictions.len() {
1619 let pending_prediction =
1620 ®istered_buffer.pending_predictions[pending_index];
1621 let age = now.saturating_duration_since(pending_prediction.enqueued_at);
1622 if age >= EDIT_PREDICTION_SETTLED_TTL {
1623 registered_buffer.pending_predictions.remove(pending_index);
1624 continue;
1625 }
1626
1627 let quiet_for =
1628 now.saturating_duration_since(pending_prediction.last_edit_at);
1629 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1630 let pending_prediction =
1631 registered_buffer.pending_predictions.remove(pending_index);
1632 let settled_editable_region = registered_buffer
1633 .snapshot
1634 .text_for_range(
1635 pending_prediction.editable_anchor_range.clone(),
1636 )
1637 .collect::<String>();
1638 ready_predictions
1639 .push((pending_prediction, settled_editable_region));
1640 continue;
1641 }
1642
1643 if oldest_edited_at
1644 .is_none_or(|time| pending_prediction.last_edit_at < time)
1645 {
1646 oldest_edited_at = Some(pending_prediction.last_edit_at);
1647 }
1648 pending_index += 1;
1649 }
1650 }
1651 }
1652 });
1653
1654 for (pending_prediction, settled_editable_region) in ready_predictions {
1655 let PendingSettledPrediction {
1656 request_id,
1657 editable_region_before_prediction,
1658 predicted_editable_region,
1659 ts_error_count_before_prediction,
1660 ts_error_count_after_prediction,
1661 example,
1662 e2e_latency,
1663 ..
1664 } = pending_prediction;
1665 let settled_editable_region_for_metrics = settled_editable_region.clone();
1666 let kept_rate_result = cx
1667 .background_spawn(async move {
1668 compute_kept_rate(
1669 &editable_region_before_prediction,
1670 &predicted_editable_region,
1671 &settled_editable_region_for_metrics,
1672 )
1673 })
1674 .await;
1675
1676 #[cfg(test)]
1677 {
1678 let request_id = request_id.clone();
1679 let settled_editable_region = settled_editable_region.clone();
1680 this.update(cx, |this, _| {
1681 if let Some(callback) = &this.settled_event_callback {
1682 callback(request_id, settled_editable_region);
1683 }
1684 });
1685 }
1686
1687 telemetry::event!(
1688 EDIT_PREDICTION_SETTLED_EVENT,
1689 request_id = request_id.0.clone(),
1690 settled_editable_region,
1691 ts_error_count_before_prediction,
1692 ts_error_count_after_prediction,
1693 edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
1694 edit_bytes_reference_new = kept_rate_result.reference_new_chars,
1695 edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
1696 edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
1697 edit_bytes_kept = kept_rate_result.kept_chars,
1698 edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
1699 edit_bytes_discarded = kept_rate_result.discarded_chars,
1700 edit_bytes_context = kept_rate_result.context_chars,
1701 edit_bytes_kept_rate = kept_rate_result.kept_rate,
1702 edit_bytes_recall_rate = kept_rate_result.recall_rate,
1703 example,
1704 e2e_latency = e2e_latency.as_millis(),
1705 );
1706 }
1707
1708 next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1709 }
1710 }
1711
1712 pub(crate) fn enqueue_settled_prediction(
1713 &mut self,
1714 request_id: EditPredictionId,
1715 project: &Entity<Project>,
1716 edited_buffer: &Entity<Buffer>,
1717 edited_buffer_snapshot: &BufferSnapshot,
1718 editable_offset_range: Range<usize>,
1719 edit_preview: &EditPreview,
1720 example: Option<ExampleSpec>,
1721 e2e_latency: std::time::Duration,
1722 cx: &mut Context<Self>,
1723 ) {
1724 let this = &mut *self;
1725 let project_state = this.get_or_init_project(project, cx);
1726 let Some(registered_buffer) = project_state
1727 .registered_buffers
1728 .get_mut(&edited_buffer.entity_id())
1729 else {
1730 return;
1731 };
1732
1733 let editable_region_before_prediction = edited_buffer_snapshot
1734 .text_for_range(editable_offset_range.clone())
1735 .collect::<String>();
1736 let editable_anchor_range_for_result =
1737 edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1738 let predicted_editable_region = edit_preview
1739 .result_text_snapshot()
1740 .text_for_range(editable_anchor_range_for_result.clone())
1741 .collect();
1742 let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1743 edited_buffer_snapshot
1744 .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1745 );
1746 let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1747 edit_preview.result_syntax_snapshot().layers_for_range(
1748 editable_anchor_range_for_result,
1749 edit_preview.result_text_snapshot(),
1750 true,
1751 ),
1752 );
1753 let editable_anchor_range =
1754 edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1755 let now = cx.background_executor().now();
1756 registered_buffer
1757 .pending_predictions
1758 .push(PendingSettledPrediction {
1759 request_id,
1760 editable_anchor_range,
1761 editable_region_before_prediction,
1762 predicted_editable_region,
1763 ts_error_count_before_prediction,
1764 ts_error_count_after_prediction,
1765 example,
1766 e2e_latency,
1767 enqueued_at: now,
1768 last_edit_at: now,
1769 });
1770 this.settled_predictions_tx.unbounded_send(now).ok();
1771 }
1772
1773 fn reject_current_prediction(
1774 &mut self,
1775 reason: EditPredictionRejectReason,
1776 project: &Entity<Project>,
1777 cx: &App,
1778 ) {
1779 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1780 project_state.pending_predictions.clear();
1781 if let Some(prediction) = project_state.current_prediction.take() {
1782 let model_version = prediction.prediction.model_version.clone();
1783 self.reject_prediction(
1784 prediction.prediction.id,
1785 reason,
1786 prediction.was_shown,
1787 model_version,
1788 Some(prediction.e2e_latency),
1789 cx,
1790 );
1791 }
1792 };
1793 }
1794
1795 fn did_show_current_prediction(
1796 &mut self,
1797 project: &Entity<Project>,
1798 display_type: edit_prediction_types::SuggestionDisplayType,
1799 _cx: &mut Context<Self>,
1800 ) {
1801 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1802 return;
1803 };
1804
1805 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1806 return;
1807 };
1808
1809 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1810 let previous_shown_with = current_prediction.shown_with;
1811
1812 if previous_shown_with.is_none() || !is_jump {
1813 current_prediction.shown_with = Some(display_type);
1814 }
1815
1816 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1817
1818 if is_first_non_jump_show {
1819 current_prediction.was_shown = true;
1820 }
1821
1822 if is_first_non_jump_show {
1823 self.shown_predictions
1824 .push_front(current_prediction.prediction.clone());
1825 if self.shown_predictions.len() > 50 {
1826 let completion = self.shown_predictions.pop_back().unwrap();
1827 self.rated_predictions.remove(&completion.id);
1828 }
1829 }
1830 }
1831
1832 fn reject_prediction(
1833 &mut self,
1834 prediction_id: EditPredictionId,
1835 reason: EditPredictionRejectReason,
1836 was_shown: bool,
1837 model_version: Option<String>,
1838 e2e_latency: Option<std::time::Duration>,
1839 cx: &App,
1840 ) {
1841 match self.edit_prediction_model {
1842 EditPredictionModel::Zeta => {
1843 let is_cloud = !matches!(
1844 all_language_settings(None, cx).edit_predictions.provider,
1845 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1846 );
1847
1848 if is_cloud {
1849 let organization_id = self
1850 .user_store
1851 .read(cx)
1852 .current_organization()
1853 .map(|organization| organization.id.clone());
1854
1855 self.reject_predictions_tx
1856 .unbounded_send(EditPredictionRejectionPayload {
1857 rejection: EditPredictionRejection {
1858 request_id: prediction_id.to_string(),
1859 reason,
1860 was_shown,
1861 model_version,
1862 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1863 },
1864 organization_id,
1865 })
1866 .log_err();
1867 }
1868 }
1869 EditPredictionModel::Mercury => {
1870 mercury::edit_prediction_rejected(
1871 prediction_id,
1872 was_shown,
1873 reason,
1874 self.client.http_client(),
1875 cx,
1876 );
1877 }
1878 EditPredictionModel::Fim { .. } => {}
1879 }
1880 }
1881
1882 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1883 self.projects
1884 .get(&project.entity_id())
1885 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1886 }
1887
1888 pub fn refresh_prediction_from_buffer(
1889 &mut self,
1890 project: Entity<Project>,
1891 buffer: Entity<Buffer>,
1892 position: language::Anchor,
1893 cx: &mut Context<Self>,
1894 ) {
1895 self.queue_prediction_refresh(
1896 project.clone(),
1897 PredictEditsRequestTrigger::Other,
1898 buffer.entity_id(),
1899 cx,
1900 move |this, cx| {
1901 let Some(request_task) = this
1902 .update(cx, |this, cx| {
1903 this.request_prediction(
1904 &project,
1905 &buffer,
1906 position,
1907 PredictEditsRequestTrigger::Other,
1908 cx,
1909 )
1910 })
1911 .log_err()
1912 else {
1913 return Task::ready(anyhow::Ok(None));
1914 };
1915
1916 cx.spawn(async move |_cx| {
1917 request_task.await.map(|prediction_result| {
1918 prediction_result.map(|prediction_result| {
1919 (
1920 prediction_result,
1921 PredictionRequestedBy::Buffer(buffer.entity_id()),
1922 )
1923 })
1924 })
1925 })
1926 },
1927 )
1928 }
1929
    /// Queues a prediction request at the next diagnostic location, to offer a
    /// jump to a diagnostic.
    ///
    /// The refresh is skipped when:
    /// - the configured provider is not served by this store;
    /// - a workspace for this project has an active pane following a leader
    ///   (`currently_following`);
    /// - the project has no state yet;
    /// - a current prediction already exists.
    ///
    /// `scope` controls where diagnostics are searched:
    /// - `Local`: within `DIAGNOSTIC_LINES_RANGE` rows of the cursor.
    /// - `Global`: a default (empty) range is passed. This is presumably treated
    ///   as "no restriction" by `next_diagnostic_location`; verify.
    ///
    /// If a current prediction has appeared by the time the request finishes,
    /// the result is turned into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        // Throttled per project (the project's entity id is the throttle key).
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor. Bail out if there is none or
                // predictions are disabled at that position.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // No known cursor position: fall back to the buffer start.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction that arrived meanwhile (e.g. from a buffer edit)
                    // wins. Report this one as rejected in its favor.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2046
2047 fn predictions_enabled_at(
2048 snapshot: &BufferSnapshot,
2049 position: Option<language::Anchor>,
2050 cx: &App,
2051 ) -> bool {
2052 let file = snapshot.file();
2053 let all_settings = all_language_settings(file, cx);
2054 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2055 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2056 {
2057 return false;
2058 }
2059
2060 if let Some(last_position) = position {
2061 let settings = snapshot.settings_at(last_position, cx);
2062
2063 if !settings.edit_predictions_disabled_in.is_empty()
2064 && let Some(scope) = snapshot.language_scope_at(last_position)
2065 && let Some(scope_name) = scope.override_name()
2066 && settings
2067 .edit_predictions_disabled_in
2068 .iter()
2069 .any(|s| s == scope_name)
2070 {
2071 return false;
2072 }
2073 }
2074
2075 true
2076 }
2077
2078 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2079}
2080
2081fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2082 let Some(app_state) = AppState::try_global(cx) else {
2083 return false;
2084 };
2085
2086 app_state
2087 .workspace_store
2088 .read(cx)
2089 .workspaces()
2090 .filter_map(|workspace| workspace.upgrade())
2091 .any(|workspace| {
2092 workspace.read(cx).project().entity_id() == project.entity_id()
2093 && workspace
2094 .read(cx)
2095 .leader_for_pane(workspace.read(cx).active_pane())
2096 .is_some()
2097 })
2098}
2099
2100fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2101 match provider {
2102 EditPredictionProvider::Zed
2103 | EditPredictionProvider::Mercury
2104 | EditPredictionProvider::Ollama
2105 | EditPredictionProvider::OpenAiCompatibleApi
2106 | EditPredictionProvider::Experimental(_) => true,
2107 EditPredictionProvider::None
2108 | EditPredictionProvider::Copilot
2109 | EditPredictionProvider::Codestral => false,
2110 }
2111}
2112
2113impl EditPredictionStore {
2114 fn queue_prediction_refresh(
2115 &mut self,
2116 project: Entity<Project>,
2117 request_trigger: PredictEditsRequestTrigger,
2118 throttle_entity: EntityId,
2119 cx: &mut Context<Self>,
2120 do_refresh: impl FnOnce(
2121 WeakEntity<Self>,
2122 &mut AsyncApp,
2123 )
2124 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2125 + 'static,
2126 ) {
2127 fn select_throttle(
2128 project_state: &mut ProjectState,
2129 request_trigger: PredictEditsRequestTrigger,
2130 ) -> &mut Option<(EntityId, Instant)> {
2131 match request_trigger {
2132 PredictEditsRequestTrigger::Diagnostics => {
2133 &mut project_state.last_jump_prediction_refresh
2134 }
2135 _ => &mut project_state.last_edit_prediction_refresh,
2136 }
2137 }
2138
2139 let (needs_acceptance_tracking, max_pending_predictions) =
2140 match all_language_settings(None, cx).edit_predictions.provider {
2141 EditPredictionProvider::Zed
2142 | EditPredictionProvider::Mercury
2143 | EditPredictionProvider::Experimental(_) => (true, 2),
2144 EditPredictionProvider::Ollama => (false, 1),
2145 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2146 EditPredictionProvider::None
2147 | EditPredictionProvider::Copilot
2148 | EditPredictionProvider::Codestral => {
2149 log::error!("queue_prediction_refresh called with non-store provider");
2150 return;
2151 }
2152 };
2153
2154 let drop_on_cancel = !needs_acceptance_tracking;
2155 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2156 let project_state = self.get_or_init_project(&project, cx);
2157 let pending_prediction_id = project_state.next_pending_prediction_id;
2158 project_state.next_pending_prediction_id += 1;
2159 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2160
2161 let task = cx.spawn(async move |this, cx| {
2162 let throttle_wait = this
2163 .update(cx, |this, cx| {
2164 let project_state = this.get_or_init_project(&project, cx);
2165 let throttle = *select_throttle(project_state, request_trigger);
2166
2167 let now = cx.background_executor().now();
2168 throttle.and_then(|(last_entity, last_timestamp)| {
2169 if throttle_entity != last_entity {
2170 return None;
2171 }
2172 (last_timestamp + throttle_timeout).checked_duration_since(now)
2173 })
2174 })
2175 .ok()
2176 .flatten();
2177
2178 if let Some(timeout) = throttle_wait {
2179 cx.background_executor().timer(timeout).await;
2180 }
2181
2182 // If this task was cancelled before the throttle timeout expired,
2183 // do not perform a request. Also skip if another task already
2184 // proceeded since we were enqueued (duplicate).
2185 let mut is_cancelled = true;
2186 this.update(cx, |this, cx| {
2187 let project_state = this.get_or_init_project(&project, cx);
2188 let was_cancelled = project_state
2189 .cancelled_predictions
2190 .remove(&pending_prediction_id);
2191 if was_cancelled {
2192 return;
2193 }
2194
2195 // Another request has been already sent since this was enqueued
2196 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2197 return;
2198 }
2199
2200 let new_refresh = (throttle_entity, cx.background_executor().now());
2201 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2202 is_cancelled = false;
2203 })
2204 .ok();
2205 if is_cancelled {
2206 return None;
2207 }
2208
2209 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2210 let new_prediction_id = new_prediction_result
2211 .as_ref()
2212 .map(|(prediction, _)| prediction.id.clone());
2213
2214 // When a prediction completes, remove it from the pending list, and cancel
2215 // any pending predictions that were enqueued before it.
2216 this.update(cx, |this, cx| {
2217 let project_state = this.get_or_init_project(&project, cx);
2218
2219 let is_cancelled = project_state
2220 .cancelled_predictions
2221 .remove(&pending_prediction_id);
2222
2223 let new_current_prediction = if !is_cancelled
2224 && let Some((prediction_result, requested_by)) = new_prediction_result
2225 {
2226 match prediction_result.prediction {
2227 Ok(prediction) => {
2228 let new_prediction = CurrentEditPrediction {
2229 requested_by,
2230 prediction,
2231 was_shown: false,
2232 shown_with: None,
2233 e2e_latency: prediction_result.e2e_latency,
2234 };
2235
2236 if let Some(current_prediction) =
2237 project_state.current_prediction.as_ref()
2238 {
2239 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2240 {
2241 this.reject_current_prediction(
2242 EditPredictionRejectReason::Replaced,
2243 &project,
2244 cx,
2245 );
2246
2247 Some(new_prediction)
2248 } else {
2249 this.reject_prediction(
2250 new_prediction.prediction.id,
2251 EditPredictionRejectReason::CurrentPreferred,
2252 false,
2253 new_prediction.prediction.model_version,
2254 Some(new_prediction.e2e_latency),
2255 cx,
2256 );
2257 None
2258 }
2259 } else {
2260 Some(new_prediction)
2261 }
2262 }
2263 Err(reject_reason) => {
2264 this.reject_prediction(
2265 prediction_result.id,
2266 reject_reason,
2267 false,
2268 None,
2269 Some(prediction_result.e2e_latency),
2270 cx,
2271 );
2272 None
2273 }
2274 }
2275 } else {
2276 None
2277 };
2278
2279 let project_state = this.get_or_init_project(&project, cx);
2280
2281 if let Some(new_prediction) = new_current_prediction {
2282 project_state.current_prediction = Some(new_prediction);
2283 }
2284
2285 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2286 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2287 if pending_prediction.id == pending_prediction_id {
2288 pending_predictions.remove(ix);
2289 for pending_prediction in pending_predictions.drain(0..ix) {
2290 project_state.cancel_pending_prediction(pending_prediction, cx)
2291 }
2292 break;
2293 }
2294 }
2295 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2296 cx.notify();
2297 })
2298 .ok();
2299
2300 new_prediction_id
2301 });
2302
2303 if project_state.pending_predictions.len() < max_pending_predictions {
2304 project_state
2305 .pending_predictions
2306 .push(PendingPrediction {
2307 id: pending_prediction_id,
2308 task,
2309 drop_on_cancel,
2310 })
2311 .unwrap();
2312 } else {
2313 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2314 project_state
2315 .pending_predictions
2316 .push(PendingPrediction {
2317 id: pending_prediction_id,
2318 task,
2319 drop_on_cancel,
2320 })
2321 .unwrap();
2322 project_state.cancel_pending_prediction(pending_prediction, cx);
2323 }
2324 }
2325
2326 pub fn request_prediction(
2327 &mut self,
2328 project: &Entity<Project>,
2329 active_buffer: &Entity<Buffer>,
2330 position: language::Anchor,
2331 trigger: PredictEditsRequestTrigger,
2332 cx: &mut Context<Self>,
2333 ) -> Task<Result<Option<EditPredictionResult>>> {
2334 self.request_prediction_internal(
2335 project.clone(),
2336 active_buffer.clone(),
2337 position,
2338 trigger,
2339 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2340 cx,
2341 )
2342 }
2343
2344 fn request_prediction_internal(
2345 &mut self,
2346 project: Entity<Project>,
2347 active_buffer: Entity<Buffer>,
2348 position: language::Anchor,
2349 trigger: PredictEditsRequestTrigger,
2350 allow_jump: bool,
2351 cx: &mut Context<Self>,
2352 ) -> Task<Result<Option<EditPredictionResult>>> {
2353 self.get_or_init_project(&project, cx);
2354 let project_state = self.projects.get(&project.entity_id()).unwrap();
2355 let stored_events = project_state.events(cx);
2356 let has_events = !stored_events.is_empty();
2357 let events: Vec<Arc<zeta_prompt::Event>> =
2358 stored_events.iter().map(|e| e.event.clone()).collect();
2359 let debug_tx = project_state.debug_tx.clone();
2360
2361 let snapshot = active_buffer.read(cx).snapshot();
2362 let cursor_point = position.to_point(&snapshot);
2363 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2364 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2365 let diagnostic_search_range =
2366 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2367
2368 let related_files = self.context_for_project(&project, cx);
2369
2370 let is_open_source = snapshot
2371 .file()
2372 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2373 && events.iter().all(|event| event.in_open_source_repo())
2374 && related_files.iter().all(|file| file.in_open_source_repo);
2375
2376 let can_collect_data = !cfg!(test)
2377 && is_open_source
2378 && self.is_data_collection_enabled(cx)
2379 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2380
2381 let inputs = EditPredictionModelInput {
2382 project: project.clone(),
2383 buffer: active_buffer,
2384 snapshot,
2385 position,
2386 events,
2387 related_files,
2388 trigger,
2389 diagnostic_search_range: diagnostic_search_range,
2390 debug_tx,
2391 can_collect_data,
2392 is_open_source,
2393 };
2394
2395 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2396
2397 let task = match self.edit_prediction_model {
2398 EditPredictionModel::Zeta => {
2399 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2400 }
2401 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2402 EditPredictionModel::Mercury => {
2403 self.mercury
2404 .request_prediction(inputs, self.credentials_provider.clone(), cx)
2405 }
2406 };
2407
2408 cx.spawn(async move |this, cx| {
2409 let prediction = task.await?;
2410
2411 // Only fall back to diagnostics-based prediction if we got a
2412 // the model had nothing to suggest for the buffer
2413 if prediction.is_none()
2414 && allow_jump
2415 && has_events
2416 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2417 {
2418 this.update(cx, |this, cx| {
2419 this.refresh_prediction_from_diagnostics(
2420 project,
2421 DiagnosticSearchScope::Local,
2422 cx,
2423 );
2424 })?;
2425 return anyhow::Ok(None);
2426 }
2427
2428 Ok(prediction)
2429 })
2430 }
2431
    /// Finds a diagnostic location that an edit prediction could jump to.
    ///
    /// First searches the active buffer: for each diagnostic group it takes the
    /// start of the group's primary entry, skipping groups whose range overlaps
    /// `active_buffer_diagnostic_search_range`, and skipping groups within
    /// `DIAGNOSTIC_LINES_RANGE` rows of a collaborator cursor unless they are
    /// also that close to `active_buffer_cursor_point`. The remaining candidate
    /// nearest (by row) to the cursor wins.
    ///
    /// If nothing qualifies, it falls back to other project paths with diagnostic
    /// summaries, ordered by how many leading path components they share with the
    /// active buffer's path. Each candidate is opened in turn. Buffers reporting
    /// any selections are skipped, and the first remaining buffer with diagnostics
    /// yields the start of the diagnostic with the lowest severity value.
    ///
    /// Returns `Ok(None)` when no location is found. Errors from opening a
    /// candidate buffer are propagated and end the search.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported by `selections_in_range` over the
        // whole buffer; treated below as collaborator cursor positions.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(
                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
                false,
            )
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        // Pass 1: diagnostics in the active buffer, nearest to the cursor by row.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search range around the cursor are skipped.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Leave diagnostics near a collaborator alone unless they are
                // also near the local cursor.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Pass 2: nothing suitable in the active buffer; try other paths that have
        // diagnostic summaries.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path with the number of leading path components it
            // shares with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first (stable sort keeps ties in original order).
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        // Any selection reported here is treated as a collaborator
                        // being present in the buffer.
                        let has_collaborators = snapshot
                            .selections_in_range(
                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
                                false,
                            )
                            .next()
                            .is_some();
                        // Lowest severity value first; presumably the most severe
                        // diagnostic (LSP numbers ERROR as 1).
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2547
2548 async fn send_raw_llm_request(
2549 request: RawCompletionRequest,
2550 client: Arc<Client>,
2551 custom_url: Option<Arc<Url>>,
2552 llm_token: LlmApiToken,
2553 organization_id: Option<OrganizationId>,
2554 app_version: Version,
2555 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2556 let url = if let Some(custom_url) = custom_url {
2557 custom_url.as_ref().clone()
2558 } else {
2559 client
2560 .http_client()
2561 .build_zed_llm_url("/predict_edits/raw", &[])?
2562 };
2563
2564 Self::send_api_request(
2565 |builder| {
2566 let req = builder
2567 .uri(url.as_ref())
2568 .body(serde_json::to_string(&request)?.into());
2569 Ok(req?)
2570 },
2571 client,
2572 llm_token,
2573 organization_id,
2574 app_version,
2575 true,
2576 )
2577 .await
2578 }
2579
2580 pub(crate) async fn send_v3_request(
2581 input: ZetaPromptInput,
2582 client: Arc<Client>,
2583 llm_token: LlmApiToken,
2584 organization_id: Option<OrganizationId>,
2585 app_version: Version,
2586 trigger: PredictEditsRequestTrigger,
2587 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2588 let url = client
2589 .http_client()
2590 .build_zed_llm_url("/predict_edits/v3", &[])?;
2591
2592 let request = PredictEditsV3Request { input, trigger };
2593
2594 let json_bytes = serde_json::to_vec(&request)?;
2595 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2596
2597 Self::send_api_request(
2598 |builder| {
2599 let req = builder
2600 .uri(url.as_ref())
2601 .header("Content-Encoding", "zstd")
2602 .body(compressed.clone().into());
2603 Ok(req?)
2604 },
2605 client,
2606 llm_token,
2607 organization_id,
2608 app_version,
2609 true,
2610 )
2611 .await
2612 }
2613
2614 async fn send_api_request<Res>(
2615 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2616 client: Arc<Client>,
2617 llm_token: LlmApiToken,
2618 organization_id: Option<OrganizationId>,
2619 app_version: Version,
2620 require_auth: bool,
2621 ) -> Result<(Res, Option<EditPredictionUsage>)>
2622 where
2623 Res: DeserializeOwned,
2624 {
2625 let http_client = client.http_client();
2626 let mut token = if require_auth {
2627 Some(
2628 client
2629 .acquire_llm_token(&llm_token, organization_id.clone())
2630 .await?,
2631 )
2632 } else {
2633 client
2634 .acquire_llm_token(&llm_token, organization_id.clone())
2635 .await
2636 .ok()
2637 };
2638 let mut did_retry = false;
2639
2640 loop {
2641 let request_builder = http_client::Request::builder().method(Method::POST);
2642
2643 let mut request_builder = request_builder
2644 .header("Content-Type", "application/json")
2645 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2646
2647 // Only add Authorization header if we have a token
2648 if let Some(ref token_value) = token {
2649 request_builder =
2650 request_builder.header("Authorization", format!("Bearer {}", token_value));
2651 }
2652
2653 let request = build(request_builder)?;
2654
2655 let mut response = http_client.send(request).await?;
2656
2657 if let Some(minimum_required_version) = response
2658 .headers()
2659 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2660 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2661 {
2662 anyhow::ensure!(
2663 app_version >= minimum_required_version,
2664 ZedUpdateRequiredError {
2665 minimum_version: minimum_required_version
2666 }
2667 );
2668 }
2669
2670 if response.status().is_success() {
2671 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2672
2673 let mut body = Vec::new();
2674 response.body_mut().read_to_end(&mut body).await?;
2675 return Ok((serde_json::from_slice(&body)?, usage));
2676 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2677 did_retry = true;
2678 token = Some(
2679 client
2680 .refresh_llm_token(&llm_token, organization_id.clone())
2681 .await?,
2682 );
2683 } else {
2684 let mut body = String::new();
2685 response.body_mut().read_to_string(&mut body).await?;
2686 anyhow::bail!(
2687 "Request failed with status: {:?}\nBody: {}",
2688 response.status(),
2689 body
2690 );
2691 }
2692 }
2693 }
2694
2695 pub fn refresh_context(
2696 &mut self,
2697 project: &Entity<Project>,
2698 buffer: &Entity<language::Buffer>,
2699 cursor_position: language::Anchor,
2700 cx: &mut Context<Self>,
2701 ) {
2702 self.get_or_init_project(project, cx)
2703 .context
2704 .update(cx, |store, cx| {
2705 store.refresh(buffer.clone(), cursor_position, cx);
2706 });
2707 }
2708
2709 #[cfg(feature = "cli-support")]
2710 pub fn set_context_for_buffer(
2711 &mut self,
2712 project: &Entity<Project>,
2713 related_files: Vec<RelatedFile>,
2714 cx: &mut Context<Self>,
2715 ) {
2716 self.get_or_init_project(project, cx)
2717 .context
2718 .update(cx, |store, cx| {
2719 store.set_related_files(related_files, cx);
2720 });
2721 }
2722
2723 #[cfg(feature = "cli-support")]
2724 pub fn set_recent_paths_for_project(
2725 &mut self,
2726 project: &Entity<Project>,
2727 paths: impl IntoIterator<Item = project::ProjectPath>,
2728 cx: &mut Context<Self>,
2729 ) {
2730 let project_state = self.get_or_init_project(project, cx);
2731 project_state.recent_paths = paths.into_iter().collect();
2732 }
2733
2734 fn is_file_open_source(
2735 &self,
2736 project: &Entity<Project>,
2737 file: &Arc<dyn File>,
2738 cx: &App,
2739 ) -> bool {
2740 if !file.is_local() || file.is_private() {
2741 return false;
2742 }
2743 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2744 return false;
2745 };
2746 project_state
2747 .license_detection_watchers
2748 .get(&file.worktree_id(cx))
2749 .as_ref()
2750 .is_some_and(|watcher| watcher.is_project_open_source())
2751 }
2752
    /// Returns whether data collection is in effect, as decided by
    /// `DataCollectionChoice::is_enabled` for the current stored choice.
    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx)
    }
2756
2757 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2758 let choice = KeyValueStore::global(cx)
2759 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2760 .log_err()
2761 .flatten();
2762
2763 match choice.as_deref() {
2764 Some("true") => DataCollectionChoice::Enabled,
2765 Some("false") => DataCollectionChoice::Disabled,
2766 Some(_) => {
2767 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2768 DataCollectionChoice::NotAnswered
2769 }
2770 None => DataCollectionChoice::NotAnswered,
2771 }
2772 }
2773
2774 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2775 self.data_collection_choice = self.data_collection_choice.toggle();
2776 let new_choice = self.data_collection_choice;
2777 let is_enabled = new_choice.is_enabled(cx);
2778 let kvp = KeyValueStore::global(cx);
2779 db::write_and_log(cx, move || async move {
2780 kvp.write_kvp(
2781 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2782 is_enabled.to_string(),
2783 )
2784 .await
2785 });
2786 }
2787
    /// Iterates over the predictions recorded as shown.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Returns how many predictions are recorded as shown.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Returns whether `id` is in the set of rated predictions (which
    /// `rate_prediction` adds to).
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2799
2800 pub fn rate_prediction(
2801 &mut self,
2802 prediction: &EditPrediction,
2803 rating: EditPredictionRating,
2804 feedback: String,
2805 cx: &mut Context<Self>,
2806 ) {
2807 let organization = self.user_store.read(cx).current_organization();
2808
2809 self.rated_predictions.insert(prediction.id.clone());
2810
2811 cx.background_spawn({
2812 let client = self.client.clone();
2813 let prediction_id = prediction.id.to_string();
2814 let inputs = serde_json::to_value(&prediction.inputs);
2815 let output = prediction
2816 .edit_preview
2817 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2818 async move {
2819 client
2820 .cloud_client()
2821 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2822 organization_id: organization.map(|organization| organization.id.clone()),
2823 request_id: prediction_id,
2824 rating: match rating {
2825 EditPredictionRating::Positive => "positive".to_string(),
2826 EditPredictionRating::Negative => "negative".to_string(),
2827 },
2828 inputs: inputs?,
2829 output,
2830 feedback,
2831 })
2832 .await?;
2833
2834 anyhow::Ok(())
2835 }
2836 })
2837 .detach_and_log_err(cx);
2838
2839 cx.notify();
2840 }
2841}
2842
2843fn collaborator_edit_overlaps_locality_region(
2844 project_state: &ProjectState,
2845 project: &Entity<Project>,
2846 buffer: &Entity<Buffer>,
2847 snapshot: &BufferSnapshot,
2848 edit_range: &Range<Anchor>,
2849 cx: &App,
2850) -> bool {
2851 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2852 return false;
2853 };
2854
2855 if active_buffer.entity_id() != buffer.entity_id() {
2856 return false;
2857 }
2858
2859 let locality_point_range = expand_context_syntactically_then_linewise(
2860 snapshot,
2861 (position..position).to_point(snapshot),
2862 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2863 );
2864 let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2865
2866 edit_range.overlaps(&locality_anchor_range, snapshot)
2867}
2868
/// Collapses the trailing run of mergeable `events` into a single event.
///
/// Nothing happens unless the newest event belongs to the same buffer as
/// `latest_snapshot` (by `remote_id`) and `latest_snapshot` has observed that
/// event's resulting version. The run is found by walking backwards while each
/// older event `can_merge` with its newer neighbour (judged against
/// `latest_snapshot` and `latest_edit_range`).
///
/// If the run has at least two events, their edit ranges are unioned, and a diff
/// is computed from the oldest event's `old_snapshot` to `end_snapshot`. If that
/// produces a diff, the run is replaced by a single `BufferChange` event. That
/// event is marked `predicted` only if every merged event was.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest event is for this buffer and is already
    // reflected in `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk newest-to-oldest, counting how many trailing events chain together.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union the edit ranges of every event in the run.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek` did not advance, so this covers the oldest event too.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the whole trailing run with the merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2951
2952fn merge_anchor_ranges(
2953 left: &Range<Anchor>,
2954 right: &Range<Anchor>,
2955 snapshot: &TextBufferSnapshot,
2956) -> Range<Anchor> {
2957 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2958 left.start
2959 } else {
2960 right.start
2961 };
2962 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2963 left.end
2964 } else {
2965 right.end
2966 };
2967 start..end
2968}
2969
/// Error raised when the server reports, via `MINIMUM_REQUIRED_VERSION_HEADER_NAME`,
/// a minimum Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: Version,
}
2977
2978#[derive(Debug, Clone, Copy)]
2979pub enum DataCollectionChoice {
2980 NotAnswered,
2981 Enabled,
2982 Disabled,
2983}
2984
2985impl DataCollectionChoice {
2986 pub fn is_enabled(self, cx: &App) -> bool {
2987 if cx.is_staff() {
2988 return true;
2989 }
2990 match self {
2991 Self::Enabled => true,
2992 Self::NotAnswered | Self::Disabled => false,
2993 }
2994 }
2995
2996 #[must_use]
2997 pub fn toggle(&self) -> DataCollectionChoice {
2998 match self {
2999 Self::Enabled => Self::Disabled,
3000 Self::Disabled => Self::Enabled,
3001 Self::NotAnswered => Self::Enabled,
3002 }
3003 }
3004}
3005
3006impl From<bool> for DataCollectionChoice {
3007 fn from(value: bool) -> Self {
3008 match value {
3009 true => DataCollectionChoice::Enabled,
3010 false => DataCollectionChoice::Disabled,
3011 }
3012 }
3013}
3014
3015struct ZedPredictUpsell;
3016
3017impl Dismissable for ZedPredictUpsell {
3018 const KEY: &'static str = "dismissed-edit-predict-upsell";
3019
3020 fn dismissed(cx: &App) -> bool {
3021 // To make this backwards compatible with older versions of Zed, we
3022 // check if the user has seen the previous Edit Prediction Onboarding
3023 // before, by checking the data collection choice which was written to
3024 // the database once the user clicked on "Accept and Enable"
3025 let kvp = KeyValueStore::global(cx);
3026 if kvp
3027 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3028 .log_err()
3029 .is_some_and(|s| s.is_some())
3030 {
3031 return true;
3032 }
3033
3034 kvp.read_kvp(Self::KEY)
3035 .log_err()
3036 .is_some_and(|s| s.is_some())
3037 }
3038}
3039
/// Returns whether the edit-prediction upsell modal should be shown, i.e. whether
/// `ZedPredictUpsell` has not been dismissed yet.
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !ZedPredictUpsell::dismissed(cx)
}
3043
3044pub fn init(cx: &mut App) {
3045 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3046 workspace.register_action(
3047 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3048 ZedPredictModal::toggle(
3049 workspace,
3050 workspace.user_store().clone(),
3051 workspace.client().clone(),
3052 window,
3053 cx,
3054 )
3055 },
3056 );
3057
3058 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3059 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3060 settings
3061 .project
3062 .all_languages
3063 .edit_predictions
3064 .get_or_insert_default()
3065 .provider = Some(EditPredictionProvider::None)
3066 });
3067 });
3068 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3069 EditPredictionStore::try_global(cx).and_then(|store| {
3070 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3071 })
3072 }
3073
3074 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3075 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3076 copilot_ui::initiate_sign_in(copilot, window, cx);
3077 }
3078 });
3079 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3080 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3081 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3082 }
3083 });
3084 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3085 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3086 copilot_ui::initiate_sign_out(copilot, window, cx);
3087 }
3088 });
3089 })
3090 .detach();
3091}