1use anyhow::Result;
2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
3use cloud_api_client::LlmApiToken;
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use credentials_provider::CredentialsProvider;
16use db::kvp::{Dismissable, KeyValueStore};
17use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
19use futures::{
20 AsyncReadExt as _, FutureExt as _, StreamExt as _,
21 channel::mpsc::{self, UnboundedReceiver},
22 select_biased,
23};
24use gpui::BackgroundExecutor;
25use gpui::http_client::Url;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use heapless::Vec as ArrayVec;
32use language::language_settings::all_language_settings;
33use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
34use language::{BufferSnapshot, OffsetRangeExt};
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{
40 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
41};
42use std::collections::{VecDeque, hash_map};
43use std::env;
44use text::{AnchorRangeExt, Edit};
45use workspace::{AppState, Workspace};
46use zeta_prompt::{ZetaFormat, ZetaPromptInput};
47
48use std::mem;
49use std::ops::Range;
50use std::path::Path;
51use std::rc::Rc;
52use std::str::FromStr as _;
53use std::sync::Arc;
54use std::time::{Duration, Instant};
55
56use thiserror::Error;
57use util::{RangeExt as _, ResultExt as _};
58
59pub mod cursor_excerpt;
60pub mod example_spec;
61pub mod fim;
62mod license_detection;
63pub mod mercury;
64pub mod ollama;
65mod onboarding_modal;
66pub mod open_ai_response;
67mod prediction;
68
69pub mod udiff;
70
71mod capture_example;
72pub mod open_ai_compatible;
73mod zed_edit_prediction_delegate;
74pub mod zeta;
75
76#[cfg(test)]
77mod edit_prediction_tests;
78
79use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
80use crate::example_spec::ExampleSpec;
81use crate::license_detection::LicenseDetectionWatcher;
82use crate::mercury::Mercury;
83use crate::onboarding_modal::ZedPredictModal;
84pub use crate::prediction::EditPrediction;
85pub use crate::prediction::EditPredictionId;
86use crate::prediction::EditPredictionResult;
87pub use capture_example::capture_example;
88pub use language_model::ApiKeyState;
89pub use telemetry_events::EditPredictionRating;
90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
91
92actions!(
93 edit_prediction,
94 [
95 /// Resets the edit prediction onboarding state.
96 ResetOnboarding,
97 /// Clears the edit prediction history.
98 ClearHistory,
99 ]
100);
101
102/// Maximum number of events to track.
103const EVENT_COUNT_MAX: usize = 10;
104const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
105const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
106const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
107const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
108const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
109const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
110const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
111const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
112const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
113
114pub struct EditPredictionJumpsFeatureFlag;
115
116impl FeatureFlag for EditPredictionJumpsFeatureFlag {
117 const NAME: &'static str = "edit_prediction_jumps";
118}
119
120#[derive(Clone)]
121struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
122
123impl Global for EditPredictionStoreGlobal {}
124
125/// Configuration for using the raw Zeta2 endpoint.
126/// When set, the client uses the raw endpoint and constructs the prompt itself.
127/// The version is also used as the Baseten environment name (lowercased).
128#[derive(Clone)]
129pub struct Zeta2RawConfig {
130 pub model_id: Option<String>,
131 pub environment: Option<String>,
132 pub format: ZetaFormat,
133}
134
135pub struct EditPredictionStore {
136 client: Arc<Client>,
137 user_store: Entity<UserStore>,
138 llm_token: LlmApiToken,
139 _fetch_experiments_task: Task<()>,
140 projects: HashMap<EntityId, ProjectState>,
141 update_required: bool,
142 edit_prediction_model: EditPredictionModel,
143 zeta2_raw_config: Option<Zeta2RawConfig>,
144 preferred_experiment: Option<String>,
145 available_experiments: Vec<String>,
146 pub mercury: Mercury,
147 data_collection_choice: DataCollectionChoice,
148 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
149 settled_predictions_tx: mpsc::UnboundedSender<Instant>,
150 shown_predictions: VecDeque<EditPrediction>,
151 rated_predictions: HashSet<EditPredictionId>,
152 #[cfg(test)]
153 settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
154 credentials_provider: Arc<dyn CredentialsProvider>,
155}
156
157pub(crate) struct EditPredictionRejectionPayload {
158 rejection: EditPredictionRejection,
159 organization_id: Option<OrganizationId>,
160}
161
162#[derive(Copy, Clone, PartialEq, Eq)]
163pub enum EditPredictionModel {
164 Zeta,
165 Fim { format: EditPredictionPromptFormat },
166 Mercury,
167}
168
169#[derive(Clone)]
170pub struct EditPredictionModelInput {
171 project: Entity<Project>,
172 buffer: Entity<Buffer>,
173 snapshot: BufferSnapshot,
174 position: Anchor,
175 events: Vec<Arc<zeta_prompt::Event>>,
176 related_files: Vec<RelatedFile>,
177 trigger: PredictEditsRequestTrigger,
178 diagnostic_search_range: Range<Point>,
179 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
180 can_collect_data: bool,
181 is_open_source: bool,
182}
183
184#[derive(Debug)]
185pub enum DebugEvent {
186 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
187 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
188 EditPredictionStarted(EditPredictionStartedDebugEvent),
189 EditPredictionFinished(EditPredictionFinishedDebugEvent),
190}
191
192#[derive(Debug)]
193pub struct ContextRetrievalStartedDebugEvent {
194 pub project_entity_id: EntityId,
195 pub timestamp: Instant,
196 pub search_prompt: String,
197}
198
199#[derive(Debug)]
200pub struct ContextRetrievalFinishedDebugEvent {
201 pub project_entity_id: EntityId,
202 pub timestamp: Instant,
203 pub metadata: Vec<(&'static str, SharedString)>,
204}
205
206#[derive(Debug)]
207pub struct EditPredictionStartedDebugEvent {
208 pub buffer: WeakEntity<Buffer>,
209 pub position: Anchor,
210 pub prompt: Option<String>,
211}
212
213#[derive(Debug)]
214pub struct EditPredictionFinishedDebugEvent {
215 pub buffer: WeakEntity<Buffer>,
216 pub position: Anchor,
217 pub model_output: Option<String>,
218}
219
220/// An event with associated metadata for reconstructing buffer state.
221#[derive(Clone)]
222pub struct StoredEvent {
223 pub event: Arc<zeta_prompt::Event>,
224 pub old_snapshot: TextBufferSnapshot,
225 pub new_snapshot_version: clock::Global,
226 pub total_edit_range: Range<Anchor>,
227}
228
229impl StoredEvent {
230 fn can_merge(
231 &self,
232 next_old_event: &StoredEvent,
233 latest_snapshot: &TextBufferSnapshot,
234 latest_edit_range: &Range<Anchor>,
235 ) -> bool {
236 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
237 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
238 return false;
239 }
240 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
241 return false;
242 }
243 if self.new_snapshot_version != next_old_event.old_snapshot.version {
244 return false;
245 }
246 if !latest_snapshot
247 .version
248 .observed_all(&next_old_event.new_snapshot_version)
249 {
250 return false;
251 }
252
253 let a_is_predicted = matches!(
254 self.event.as_ref(),
255 zeta_prompt::Event::BufferChange {
256 predicted: true,
257 ..
258 }
259 );
260 let b_is_predicted = matches!(
261 next_old_event.event.as_ref(),
262 zeta_prompt::Event::BufferChange {
263 predicted: true,
264 ..
265 }
266 );
267
268 // If events come from the same source (both predicted or both manual) then
269 // we would have coalesced them already.
270 if a_is_predicted == b_is_predicted {
271 return false;
272 }
273
274 let left_range = self.total_edit_range.to_point(latest_snapshot);
275 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
276 let latest_range = latest_edit_range.to_point(latest_snapshot);
277
278 // Events near to the latest edit are not merged if their sources differ.
279 if lines_between_ranges(&left_range, &latest_range)
280 .min(lines_between_ranges(&right_range, &latest_range))
281 <= CHANGE_GROUPING_LINE_SPAN
282 {
283 return false;
284 }
285
286 // Events that are distant from each other are not merged.
287 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
288 return false;
289 }
290
291 true
292 }
293}
294
295fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
296 if left.start > right.end {
297 return left.start.row - right.end.row;
298 }
299 if right.start > left.end {
300 return right.start.row - left.end.row;
301 }
302 0
303}
304
305struct ProjectState {
306 events: VecDeque<StoredEvent>,
307 last_event: Option<LastEvent>,
308 recent_paths: VecDeque<ProjectPath>,
309 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
310 current_prediction: Option<CurrentEditPrediction>,
311 next_pending_prediction_id: usize,
312 pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
313 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
314 last_edit_prediction_refresh: Option<(EntityId, Instant)>,
315 last_jump_prediction_refresh: Option<(EntityId, Instant)>,
316 cancelled_predictions: HashSet<usize>,
317 context: Entity<RelatedExcerptStore>,
318 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
319 _subscriptions: [gpui::Subscription; 2],
320 copilot: Option<Entity<Copilot>>,
321}
322
323impl ProjectState {
324 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
325 self.events
326 .iter()
327 .cloned()
328 .chain(self.last_event.as_ref().iter().flat_map(|event| {
329 let (one, two) = event.split_by_pause();
330 let one = one.finalize(&self.license_detection_watchers, cx);
331 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
332 one.into_iter().chain(two)
333 }))
334 .collect()
335 }
336
337 fn cancel_pending_prediction(
338 &mut self,
339 pending_prediction: PendingPrediction,
340 cx: &mut Context<EditPredictionStore>,
341 ) {
342 self.cancelled_predictions.insert(pending_prediction.id);
343
344 if pending_prediction.drop_on_cancel {
345 drop(pending_prediction.task);
346 } else {
347 cx.spawn(async move |this, cx| {
348 let Some(prediction_id) = pending_prediction.task.await else {
349 return;
350 };
351
352 this.update(cx, |this, cx| {
353 this.reject_prediction(
354 prediction_id,
355 EditPredictionRejectReason::Canceled,
356 false,
357 None,
358 None,
359 cx,
360 );
361 })
362 .ok();
363 })
364 .detach()
365 }
366 }
367
368 fn active_buffer(
369 &self,
370 project: &Entity<Project>,
371 cx: &App,
372 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
373 let project = project.read(cx);
374 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
375 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
376 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
377 Some((active_buffer, registered_buffer.last_position))
378 }
379}
380
381#[derive(Debug, Clone)]
382struct CurrentEditPrediction {
383 pub requested_by: PredictionRequestedBy,
384 pub prediction: EditPrediction,
385 pub was_shown: bool,
386 pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
387 pub e2e_latency: std::time::Duration,
388}
389
390impl CurrentEditPrediction {
391 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
392 let Some(new_edits) = self
393 .prediction
394 .interpolate(&self.prediction.buffer.read(cx))
395 else {
396 return false;
397 };
398
399 if self.prediction.buffer != old_prediction.prediction.buffer {
400 return true;
401 }
402
403 let Some(old_edits) = old_prediction
404 .prediction
405 .interpolate(&old_prediction.prediction.buffer.read(cx))
406 else {
407 return true;
408 };
409
410 let requested_by_buffer_id = self.requested_by.buffer_id();
411
412 // This reduces the occurrence of UI thrash from replacing edits
413 //
414 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
415 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
416 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
417 && old_edits.len() == 1
418 && new_edits.len() == 1
419 {
420 let (old_range, old_text) = &old_edits[0];
421 let (new_range, new_text) = &new_edits[0];
422 new_range == old_range && new_text.starts_with(old_text.as_ref())
423 } else {
424 true
425 }
426 }
427}
428
429#[derive(Debug, Clone)]
430enum PredictionRequestedBy {
431 DiagnosticsUpdate,
432 Buffer(EntityId),
433}
434
435impl PredictionRequestedBy {
436 pub fn buffer_id(&self) -> Option<EntityId> {
437 match self {
438 PredictionRequestedBy::DiagnosticsUpdate => None,
439 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
440 }
441 }
442}
443
444const DIAGNOSTIC_LINES_RANGE: u32 = 20;
445
446#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
447pub enum DiagnosticSearchScope {
448 Local,
449 Global,
450}
451
452#[derive(Debug)]
453struct PendingPrediction {
454 id: usize,
455 task: Task<Option<EditPredictionId>>,
456 /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
457 /// If false, the task is awaited to completion so rejection can be reported.
458 drop_on_cancel: bool,
459}
460
461/// A prediction from the perspective of a buffer.
462#[derive(Debug)]
463enum BufferEditPrediction<'a> {
464 Local { prediction: &'a EditPrediction },
465 Jump { prediction: &'a EditPrediction },
466}
467
468#[cfg(test)]
469impl std::ops::Deref for BufferEditPrediction<'_> {
470 type Target = EditPrediction;
471
472 fn deref(&self) -> &Self::Target {
473 match self {
474 BufferEditPrediction::Local { prediction } => prediction,
475 BufferEditPrediction::Jump { prediction } => prediction,
476 }
477 }
478}
479
480#[derive(Clone)]
481
482struct PendingSettledPrediction {
483 request_id: EditPredictionId,
484 editable_anchor_range: Range<Anchor>,
485 example: Option<ExampleSpec>,
486 enqueued_at: Instant,
487 last_edit_at: Instant,
488 e2e_latency: std::time::Duration,
489}
490
491struct RegisteredBuffer {
492 file: Option<Arc<dyn File>>,
493 snapshot: TextBufferSnapshot,
494 pending_predictions: Vec<PendingSettledPrediction>,
495 last_position: Option<Anchor>,
496 _subscriptions: [gpui::Subscription; 2],
497}
498
499#[derive(Clone)]
500struct LastEvent {
501 old_snapshot: TextBufferSnapshot,
502 new_snapshot: TextBufferSnapshot,
503 old_file: Option<Arc<dyn File>>,
504 new_file: Option<Arc<dyn File>>,
505 latest_edit_range: Range<Anchor>,
506 total_edit_range: Range<Anchor>,
507 total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
508 predicted: bool,
509 snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
510 last_edit_time: Option<Instant>,
511}
512
513impl LastEvent {
514 pub fn finalize(
515 &self,
516 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
517 cx: &App,
518 ) -> Option<StoredEvent> {
519 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
520 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
521
522 let in_open_source_repo =
523 [self.new_file.as_ref(), self.old_file.as_ref()]
524 .iter()
525 .all(|file| {
526 file.is_some_and(|file| {
527 license_detection_watchers
528 .get(&file.worktree_id(cx))
529 .is_some_and(|watcher| watcher.is_project_open_source())
530 })
531 });
532
533 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
534 &self.old_snapshot,
535 &self.new_snapshot,
536 &self.total_edit_range,
537 )?;
538
539 if path == old_path && diff.is_empty() {
540 None
541 } else {
542 Some(StoredEvent {
543 event: Arc::new(zeta_prompt::Event::BufferChange {
544 old_path,
545 path,
546 diff,
547 in_open_source_repo,
548 predicted: self.predicted,
549 }),
550 old_snapshot: self.old_snapshot.clone(),
551 new_snapshot_version: self.new_snapshot.version.clone(),
552 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
553 ..self.new_snapshot.anchor_before(edit_range.end),
554 })
555 }
556 }
557
558 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
559 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
560 return (self.clone(), None);
561 };
562
563 let total_edit_range_before_pause = self
564 .total_edit_range_at_last_pause_boundary
565 .clone()
566 .unwrap_or_else(|| self.total_edit_range.clone());
567
568 let Some(total_edit_range_after_pause) =
569 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
570 else {
571 return (self.clone(), None);
572 };
573
574 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
575 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
576
577 let before = LastEvent {
578 old_snapshot: self.old_snapshot.clone(),
579 new_snapshot: boundary_snapshot.clone(),
580 old_file: self.old_file.clone(),
581 new_file: self.new_file.clone(),
582 latest_edit_range: latest_edit_range_before_pause,
583 total_edit_range: total_edit_range_before_pause,
584 total_edit_range_at_last_pause_boundary: None,
585 predicted: self.predicted,
586 snapshot_after_last_editing_pause: None,
587 last_edit_time: self.last_edit_time,
588 };
589
590 let after = LastEvent {
591 old_snapshot: boundary_snapshot.clone(),
592 new_snapshot: self.new_snapshot.clone(),
593 old_file: self.old_file.clone(),
594 new_file: self.new_file.clone(),
595 latest_edit_range: latest_edit_range_after_pause,
596 total_edit_range: total_edit_range_after_pause,
597 total_edit_range_at_last_pause_boundary: None,
598 predicted: self.predicted,
599 snapshot_after_last_editing_pause: None,
600 last_edit_time: self.last_edit_time,
601 };
602
603 (before, Some(after))
604 }
605}
606
607fn compute_total_edit_range_between_snapshots(
608 old_snapshot: &TextBufferSnapshot,
609 new_snapshot: &TextBufferSnapshot,
610) -> Option<Range<Anchor>> {
611 let edits: Vec<Edit<usize>> = new_snapshot
612 .edits_since::<usize>(&old_snapshot.version)
613 .collect();
614
615 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
616 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
617 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
618
619 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
620}
621
622fn compute_old_range_for_new_range(
623 old_snapshot: &TextBufferSnapshot,
624 new_snapshot: &TextBufferSnapshot,
625 total_edit_range: &Range<Anchor>,
626) -> Option<Range<Point>> {
627 let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
628 let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
629
630 let edits: Vec<Edit<usize>> = new_snapshot
631 .edits_since::<usize>(&old_snapshot.version)
632 .collect();
633 let mut old_start_offset = None;
634 let mut old_end_offset = None;
635 let mut delta: isize = 0;
636
637 for edit in &edits {
638 if old_start_offset.is_none() && new_start_offset <= edit.new.end {
639 old_start_offset = Some(if new_start_offset < edit.new.start {
640 new_start_offset.checked_add_signed(-delta)?
641 } else {
642 edit.old.start
643 });
644 }
645
646 if old_end_offset.is_none() && new_end_offset <= edit.new.end {
647 old_end_offset = Some(if new_end_offset < edit.new.start {
648 new_end_offset.checked_add_signed(-delta)?
649 } else {
650 edit.old.end
651 });
652 }
653
654 delta += edit.new.len() as isize - edit.old.len() as isize;
655 }
656
657 let old_start_offset =
658 old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
659 let old_end_offset =
660 old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
661
662 Some(
663 old_snapshot.offset_to_point(old_start_offset)
664 ..old_snapshot.offset_to_point(old_end_offset),
665 )
666}
667
668fn compute_diff_between_snapshots_in_range(
669 old_snapshot: &TextBufferSnapshot,
670 new_snapshot: &TextBufferSnapshot,
671 total_edit_range: &Range<Anchor>,
672) -> Option<(String, Range<Point>)> {
673 let new_start_point = total_edit_range.start.to_point(new_snapshot);
674 let new_end_point = total_edit_range.end.to_point(new_snapshot);
675 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
676 let old_start_point = old_range.start;
677 let old_end_point = old_range.end;
678
679 const CONTEXT_LINES: u32 = 3;
680
681 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
682 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
683 let old_context_end_row =
684 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
685 let new_context_end_row =
686 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
687
688 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
689 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
690 let old_end_line_offset = old_snapshot
691 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
692 let new_end_line_offset = new_snapshot
693 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
694 let old_edit_range = old_start_line_offset..old_end_line_offset;
695 let new_edit_range = new_start_line_offset..new_end_line_offset;
696
697 if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
698 || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
699 {
700 return None;
701 }
702
703 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
704 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
705
706 let diff = language::unified_diff_with_offsets(
707 &old_region_text,
708 &new_region_text,
709 old_context_start_row,
710 new_context_start_row,
711 );
712
713 Some((diff, new_start_point..new_end_point))
714}
715
716fn buffer_path_with_id_fallback(
717 file: Option<&Arc<dyn File>>,
718 snapshot: &TextBufferSnapshot,
719 cx: &App,
720) -> Arc<Path> {
721 if let Some(file) = file {
722 file.full_path(cx).into()
723 } else {
724 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
725 }
726}
727
728impl EditPredictionStore {
729 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
730 cx.try_global::<EditPredictionStoreGlobal>()
731 .map(|global| global.0.clone())
732 }
733
734 pub fn global(
735 client: &Arc<Client>,
736 user_store: &Entity<UserStore>,
737 cx: &mut App,
738 ) -> Entity<Self> {
739 cx.try_global::<EditPredictionStoreGlobal>()
740 .map(|global| global.0.clone())
741 .unwrap_or_else(|| {
742 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
743 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
744 ep_store
745 })
746 }
747
748 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
749 let data_collection_choice = Self::load_data_collection_choice(cx);
750
751 let llm_token = global_llm_token(cx);
752
753 let (reject_tx, reject_rx) = mpsc::unbounded();
754 cx.background_spawn({
755 let client = client.clone();
756 let llm_token = llm_token.clone();
757 let app_version = AppVersion::global(cx);
758 let background_executor = cx.background_executor().clone();
759 async move {
760 Self::handle_rejected_predictions(
761 reject_rx,
762 client,
763 llm_token,
764 app_version,
765 background_executor,
766 )
767 .await
768 }
769 })
770 .detach();
771
772 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
773 cx.spawn(async move |this, cx| {
774 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
775 })
776 .detach();
777
778 let mut current_user = user_store.read(cx).watch_current_user();
779 let fetch_experiments_task = cx.spawn(async move |this, cx| {
780 while current_user.borrow().is_none() {
781 current_user.next().await;
782 }
783
784 this.update(cx, |this, cx| {
785 if cx.is_staff() {
786 this.refresh_available_experiments(cx);
787 }
788 })
789 .log_err();
790 });
791
792 let credentials_provider = zed_credentials_provider::global(cx);
793
794 let this = Self {
795 projects: HashMap::default(),
796 client,
797 user_store,
798 llm_token,
799 _fetch_experiments_task: fetch_experiments_task,
800 update_required: false,
801 edit_prediction_model: EditPredictionModel::Zeta,
802 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
803 preferred_experiment: None,
804 available_experiments: Vec::new(),
805 mercury: Mercury::new(cx),
806
807 data_collection_choice,
808 reject_predictions_tx: reject_tx,
809 settled_predictions_tx,
810 rated_predictions: Default::default(),
811 shown_predictions: Default::default(),
812 #[cfg(test)]
813 settled_event_callback: None,
814
815 credentials_provider,
816 };
817
818 this
819 }
820
821 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
822 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
823 let format = ZetaFormat::parse(&version_str).ok()?;
824 let model_id = env::var("ZED_ZETA_MODEL").ok();
825 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
826 Some(Zeta2RawConfig {
827 model_id,
828 environment,
829 format,
830 })
831 }
832
833 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
834 self.edit_prediction_model = model;
835 }
836
837 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
838 self.zeta2_raw_config = Some(config);
839 }
840
841 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
842 self.zeta2_raw_config.as_ref()
843 }
844
845 pub fn preferred_experiment(&self) -> Option<&str> {
846 self.preferred_experiment.as_deref()
847 }
848
849 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
850 self.preferred_experiment = experiment;
851 }
852
853 pub fn available_experiments(&self) -> &[String] {
854 &self.available_experiments
855 }
856
857 pub fn active_experiment(&self) -> Option<&str> {
858 self.preferred_experiment.as_deref().or_else(|| {
859 self.shown_predictions
860 .iter()
861 .find_map(|p| p.model_version.as_ref())
862 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
863 })
864 }
865
866 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
867 let client = self.client.clone();
868 let llm_token = self.llm_token.clone();
869 let app_version = AppVersion::global(cx);
870 let organization_id = self
871 .user_store
872 .read(cx)
873 .current_organization()
874 .map(|organization| organization.id.clone());
875
876 cx.spawn(async move |this, cx| {
877 let experiments = cx
878 .background_spawn(async move {
879 let http_client = client.http_client();
880 let token = client
881 .acquire_llm_token(&llm_token, organization_id.clone())
882 .await?;
883 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
884 let request = http_client::Request::builder()
885 .method(Method::GET)
886 .uri(url.as_ref())
887 .header("Authorization", format!("Bearer {}", token))
888 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
889 .body(Default::default())?;
890 let mut response = http_client.send(request).await?;
891 if response.status().is_success() {
892 let mut body = Vec::new();
893 response.body_mut().read_to_end(&mut body).await?;
894 let experiments: Vec<String> = serde_json::from_slice(&body)?;
895 Ok(experiments)
896 } else {
897 let mut body = String::new();
898 response.body_mut().read_to_string(&mut body).await?;
899 anyhow::bail!(
900 "Failed to fetch experiments: {:?}\nBody: {}",
901 response.status(),
902 body
903 );
904 }
905 })
906 .await?;
907 this.update(cx, |this, cx| {
908 this.available_experiments = experiments;
909 cx.notify();
910 })?;
911 anyhow::Ok(())
912 })
913 .detach_and_log_err(cx);
914 }
915
916 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
917 use ui::IconName;
918 match self.edit_prediction_model {
919 EditPredictionModel::Mercury => {
920 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
921 }
922 EditPredictionModel::Zeta => {
923 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
924 .with_disabled(IconName::ZedPredictDisabled)
925 .with_up(IconName::ZedPredictUp)
926 .with_down(IconName::ZedPredictDown)
927 .with_error(IconName::ZedPredictError)
928 }
929 EditPredictionModel::Fim { .. } => {
930 let settings = &all_language_settings(None, cx).edit_predictions;
931 match settings.provider {
932 EditPredictionProvider::Ollama => {
933 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
934 }
935 _ => {
936 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
937 }
938 }
939 }
940 }
941 }
942
943 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
944 self.mercury.api_token.read(cx).has_key()
945 }
946
947 pub fn mercury_has_payment_required_error(&self) -> bool {
948 self.mercury.has_payment_required_error()
949 }
950
951 pub fn clear_history(&mut self) {
952 for project_state in self.projects.values_mut() {
953 project_state.events.clear();
954 project_state.last_event.take();
955 }
956 }
957
958 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
959 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
960 project_state.events.clear();
961 project_state.last_event.take();
962 }
963 }
964
965 pub fn edit_history_for_project(
966 &self,
967 project: &Entity<Project>,
968 cx: &App,
969 ) -> Vec<StoredEvent> {
970 self.projects
971 .get(&project.entity_id())
972 .map(|project_state| project_state.events(cx))
973 .unwrap_or_default()
974 }
975
976 pub fn context_for_project<'a>(
977 &'a self,
978 project: &Entity<Project>,
979 cx: &'a mut App,
980 ) -> Vec<RelatedFile> {
981 self.projects
982 .get(&project.entity_id())
983 .map(|project_state| {
984 project_state.context.update(cx, |context, cx| {
985 context
986 .related_files_with_buffers(cx)
987 .map(|(mut related_file, buffer)| {
988 related_file.in_open_source_repo = buffer
989 .read(cx)
990 .file()
991 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
992 related_file
993 })
994 .collect()
995 })
996 })
997 .unwrap_or_default()
998 }
999
1000 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1001 self.projects
1002 .get(&project.entity_id())
1003 .and_then(|project| project.copilot.clone())
1004 }
1005
1006 pub fn start_copilot_for_project(
1007 &mut self,
1008 project: &Entity<Project>,
1009 cx: &mut Context<Self>,
1010 ) -> Option<Entity<Copilot>> {
1011 if DisableAiSettings::get(None, cx).disable_ai {
1012 return None;
1013 }
1014 let state = self.get_or_init_project(project, cx);
1015
1016 if state.copilot.is_some() {
1017 return state.copilot.clone();
1018 }
1019 let _project = project.clone();
1020 let project = project.read(cx);
1021
1022 let node = project.node_runtime().cloned();
1023 if let Some(node) = node {
1024 let next_id = project.languages().next_language_server_id();
1025 let fs = project.fs().clone();
1026
1027 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1028 state.copilot = Some(copilot.clone());
1029 Some(copilot)
1030 } else {
1031 None
1032 }
1033 }
1034
1035 pub fn context_for_project_with_buffers<'a>(
1036 &'a self,
1037 project: &Entity<Project>,
1038 cx: &'a mut App,
1039 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1040 self.projects
1041 .get(&project.entity_id())
1042 .map(|project| {
1043 project.context.update(cx, |context, cx| {
1044 context.related_files_with_buffers(cx).collect()
1045 })
1046 })
1047 .unwrap_or_default()
1048 }
1049
1050 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1051 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1052 self.user_store.read(cx).edit_prediction_usage()
1053 } else {
1054 None
1055 }
1056 }
1057
1058 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1059 self.get_or_init_project(project, cx);
1060 }
1061
1062 pub fn register_buffer(
1063 &mut self,
1064 buffer: &Entity<Buffer>,
1065 project: &Entity<Project>,
1066 cx: &mut Context<Self>,
1067 ) {
1068 let project_state = self.get_or_init_project(project, cx);
1069 Self::register_buffer_impl(project_state, buffer, project, cx);
1070 }
1071
1072 fn get_or_init_project(
1073 &mut self,
1074 project: &Entity<Project>,
1075 cx: &mut Context<Self>,
1076 ) -> &mut ProjectState {
1077 let entity_id = project.entity_id();
1078 self.projects
1079 .entry(entity_id)
1080 .or_insert_with(|| ProjectState {
1081 context: {
1082 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1083 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1084 this.handle_excerpt_store_event(entity_id, event);
1085 })
1086 .detach();
1087 related_excerpt_store
1088 },
1089 events: VecDeque::new(),
1090 last_event: None,
1091 recent_paths: VecDeque::new(),
1092 debug_tx: None,
1093 registered_buffers: HashMap::default(),
1094 current_prediction: None,
1095 cancelled_predictions: HashSet::default(),
1096 pending_predictions: ArrayVec::new(),
1097 next_pending_prediction_id: 0,
1098 last_edit_prediction_refresh: None,
1099 last_jump_prediction_refresh: None,
1100 license_detection_watchers: HashMap::default(),
1101 _subscriptions: [
1102 cx.subscribe(&project, Self::handle_project_event),
1103 cx.observe_release(&project, move |this, _, cx| {
1104 this.projects.remove(&entity_id);
1105 cx.notify();
1106 }),
1107 ],
1108 copilot: None,
1109 })
1110 }
1111
1112 pub fn remove_project(&mut self, project: &Entity<Project>) {
1113 self.projects.remove(&project.entity_id());
1114 }
1115
1116 fn handle_excerpt_store_event(
1117 &mut self,
1118 project_entity_id: EntityId,
1119 event: &RelatedExcerptStoreEvent,
1120 ) {
1121 if let Some(project_state) = self.projects.get(&project_entity_id) {
1122 if let Some(debug_tx) = project_state.debug_tx.clone() {
1123 match event {
1124 RelatedExcerptStoreEvent::StartedRefresh => {
1125 debug_tx
1126 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1127 ContextRetrievalStartedDebugEvent {
1128 project_entity_id: project_entity_id,
1129 timestamp: Instant::now(),
1130 search_prompt: String::new(),
1131 },
1132 ))
1133 .ok();
1134 }
1135 RelatedExcerptStoreEvent::FinishedRefresh {
1136 cache_hit_count,
1137 cache_miss_count,
1138 mean_definition_latency,
1139 max_definition_latency,
1140 } => {
1141 debug_tx
1142 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1143 ContextRetrievalFinishedDebugEvent {
1144 project_entity_id: project_entity_id,
1145 timestamp: Instant::now(),
1146 metadata: vec![
1147 (
1148 "Cache Hits",
1149 format!(
1150 "{}/{}",
1151 cache_hit_count,
1152 cache_hit_count + cache_miss_count
1153 )
1154 .into(),
1155 ),
1156 (
1157 "Max LSP Time",
1158 format!("{} ms", max_definition_latency.as_millis())
1159 .into(),
1160 ),
1161 (
1162 "Mean LSP Time",
1163 format!("{} ms", mean_definition_latency.as_millis())
1164 .into(),
1165 ),
1166 ],
1167 },
1168 ))
1169 .ok();
1170 }
1171 }
1172 }
1173 }
1174 }
1175
1176 pub fn debug_info(
1177 &mut self,
1178 project: &Entity<Project>,
1179 cx: &mut Context<Self>,
1180 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1181 let project_state = self.get_or_init_project(project, cx);
1182 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1183 project_state.debug_tx = Some(debug_watch_tx);
1184 debug_watch_rx
1185 }
1186
1187 fn handle_project_event(
1188 &mut self,
1189 project: Entity<Project>,
1190 event: &project::Event,
1191 cx: &mut Context<Self>,
1192 ) {
1193 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1194 return;
1195 }
1196 // TODO [zeta2] init with recent paths
1197 match event {
1198 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1199 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1200 return;
1201 };
1202 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1203 if let Some(path) = path {
1204 if let Some(ix) = project_state
1205 .recent_paths
1206 .iter()
1207 .position(|probe| probe == &path)
1208 {
1209 project_state.recent_paths.remove(ix);
1210 }
1211 project_state.recent_paths.push_front(path);
1212 }
1213 }
1214 project::Event::DiagnosticsUpdated { .. } => {
1215 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1216 self.refresh_prediction_from_diagnostics(
1217 project,
1218 DiagnosticSearchScope::Global,
1219 cx,
1220 );
1221 }
1222 }
1223 _ => (),
1224 }
1225 }
1226
1227 fn register_buffer_impl<'a>(
1228 project_state: &'a mut ProjectState,
1229 buffer: &Entity<Buffer>,
1230 project: &Entity<Project>,
1231 cx: &mut Context<Self>,
1232 ) -> &'a mut RegisteredBuffer {
1233 let buffer_id = buffer.entity_id();
1234
1235 if let Some(file) = buffer.read(cx).file() {
1236 let worktree_id = file.worktree_id(cx);
1237 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1238 project_state
1239 .license_detection_watchers
1240 .entry(worktree_id)
1241 .or_insert_with(|| {
1242 let project_entity_id = project.entity_id();
1243 cx.observe_release(&worktree, move |this, _worktree, _cx| {
1244 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1245 else {
1246 return;
1247 };
1248 project_state
1249 .license_detection_watchers
1250 .remove(&worktree_id);
1251 })
1252 .detach();
1253 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1254 });
1255 }
1256 }
1257
1258 match project_state.registered_buffers.entry(buffer_id) {
1259 hash_map::Entry::Occupied(entry) => entry.into_mut(),
1260 hash_map::Entry::Vacant(entry) => {
1261 let buf = buffer.read(cx);
1262 let snapshot = buf.text_snapshot();
1263 let file = buf.file().cloned();
1264 let project_entity_id = project.entity_id();
1265 entry.insert(RegisteredBuffer {
1266 snapshot,
1267 file,
1268 last_position: None,
1269 pending_predictions: Vec::new(),
1270 _subscriptions: [
1271 cx.subscribe(buffer, {
1272 let project = project.downgrade();
1273 move |this, buffer, event, cx| {
1274 if let language::BufferEvent::Edited { is_local } = event
1275 && let Some(project) = project.upgrade()
1276 {
1277 this.report_changes_for_buffer(
1278 &buffer, &project, false, *is_local, cx,
1279 );
1280 }
1281 }
1282 }),
1283 cx.observe_release(buffer, move |this, _buffer, _cx| {
1284 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1285 else {
1286 return;
1287 };
1288 project_state.registered_buffers.remove(&buffer_id);
1289 }),
1290 ],
1291 })
1292 }
1293 }
1294 }
1295
1296 fn report_changes_for_buffer(
1297 &mut self,
1298 buffer: &Entity<Buffer>,
1299 project: &Entity<Project>,
1300 is_predicted: bool,
1301 is_local: bool,
1302 cx: &mut Context<Self>,
1303 ) {
1304 let project_state = self.get_or_init_project(project, cx);
1305 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1306
1307 let buf = buffer.read(cx);
1308 let new_file = buf.file().cloned();
1309 let new_snapshot = buf.text_snapshot();
1310 if new_snapshot.version == registered_buffer.snapshot.version {
1311 return;
1312 }
1313 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1314 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1315 let mut edit_range: Option<Range<Anchor>> = None;
1316 let now = cx.background_executor().now();
1317
1318 for (_edit, anchor_range) in
1319 new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1320 {
1321 edit_range = Some(match edit_range {
1322 None => anchor_range,
1323 Some(acc) => acc.start..anchor_range.end,
1324 });
1325 }
1326
1327 let Some(edit_range) = edit_range else {
1328 return;
1329 };
1330
1331 for pending_prediction in &mut registered_buffer.pending_predictions {
1332 if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
1333 pending_prediction.last_edit_at = now;
1334 }
1335 }
1336
1337 let include_in_history = is_local
1338 || collaborator_edit_overlaps_locality_region(
1339 project_state,
1340 project,
1341 buffer,
1342 &buf.snapshot(),
1343 &edit_range,
1344 cx,
1345 );
1346
1347 if !include_in_history {
1348 return;
1349 }
1350
1351 let is_recordable_history_edit =
1352 compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
1353 .is_some();
1354
1355 let events = &mut project_state.events;
1356
1357 if !is_recordable_history_edit {
1358 if let Some(event) = project_state.last_event.take() {
1359 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1360 if events.len() + 1 >= EVENT_COUNT_MAX {
1361 events.pop_front();
1362 }
1363 events.push_back(event);
1364 }
1365 }
1366 return;
1367 }
1368
1369 if let Some(last_event) = project_state.last_event.as_mut() {
1370 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1371 == last_event.new_snapshot.remote_id()
1372 && old_snapshot.version == last_event.new_snapshot.version;
1373
1374 let prediction_source_changed = is_predicted != last_event.predicted;
1375
1376 let should_coalesce = is_next_snapshot_of_same_buffer
1377 && !prediction_source_changed
1378 && lines_between_ranges(
1379 &edit_range.to_point(&new_snapshot),
1380 &last_event.latest_edit_range.to_point(&new_snapshot),
1381 ) <= CHANGE_GROUPING_LINE_SPAN;
1382
1383 if should_coalesce {
1384 let pause_elapsed = last_event
1385 .last_edit_time
1386 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1387 .unwrap_or(false);
1388 if pause_elapsed {
1389 last_event.snapshot_after_last_editing_pause =
1390 Some(last_event.new_snapshot.clone());
1391 last_event.total_edit_range_at_last_pause_boundary =
1392 Some(last_event.total_edit_range.clone());
1393 }
1394
1395 last_event.latest_edit_range = edit_range.clone();
1396 last_event.total_edit_range =
1397 merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
1398 last_event.new_snapshot = new_snapshot;
1399 last_event.last_edit_time = Some(now);
1400 return;
1401 }
1402 }
1403
1404 if let Some(event) = project_state.last_event.take() {
1405 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1406 if events.len() + 1 >= EVENT_COUNT_MAX {
1407 events.pop_front();
1408 }
1409 events.push_back(event);
1410 }
1411 }
1412
1413 merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1414
1415 project_state.last_event = Some(LastEvent {
1416 old_file,
1417 new_file,
1418 old_snapshot,
1419 new_snapshot,
1420 latest_edit_range: edit_range.clone(),
1421 total_edit_range: edit_range,
1422 total_edit_range_at_last_pause_boundary: None,
1423 predicted: is_predicted,
1424 snapshot_after_last_editing_pause: None,
1425 last_edit_time: Some(now),
1426 });
1427 }
1428
1429 fn prediction_at(
1430 &mut self,
1431 buffer: &Entity<Buffer>,
1432 position: Option<language::Anchor>,
1433 project: &Entity<Project>,
1434 cx: &App,
1435 ) -> Option<BufferEditPrediction<'_>> {
1436 let project_state = self.projects.get_mut(&project.entity_id())?;
1437 if let Some(position) = position
1438 && let Some(buffer) = project_state
1439 .registered_buffers
1440 .get_mut(&buffer.entity_id())
1441 {
1442 buffer.last_position = Some(position);
1443 }
1444
1445 let CurrentEditPrediction {
1446 requested_by,
1447 prediction,
1448 ..
1449 } = project_state.current_prediction.as_ref()?;
1450
1451 if prediction.targets_buffer(buffer.read(cx)) {
1452 Some(BufferEditPrediction::Local { prediction })
1453 } else {
1454 let show_jump = match requested_by {
1455 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1456 requested_by_buffer_id == &buffer.entity_id()
1457 }
1458 PredictionRequestedBy::DiagnosticsUpdate => true,
1459 };
1460
1461 if show_jump {
1462 Some(BufferEditPrediction::Jump { prediction })
1463 } else {
1464 None
1465 }
1466 }
1467 }
1468
1469 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1470 let Some(current_prediction) = self
1471 .projects
1472 .get_mut(&project.entity_id())
1473 .and_then(|project_state| project_state.current_prediction.take())
1474 else {
1475 return;
1476 };
1477
1478 self.report_changes_for_buffer(
1479 ¤t_prediction.prediction.buffer,
1480 project,
1481 true,
1482 true,
1483 cx,
1484 );
1485
1486 // can't hold &mut project_state ref across report_changes_for_buffer_call
1487 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1488 return;
1489 };
1490
1491 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1492 project_state.cancel_pending_prediction(pending_prediction, cx);
1493 }
1494
1495 match self.edit_prediction_model {
1496 EditPredictionModel::Mercury => {
1497 mercury::edit_prediction_accepted(
1498 current_prediction.prediction.id,
1499 self.client.http_client(),
1500 cx,
1501 );
1502 }
1503 EditPredictionModel::Zeta => {
1504 let is_cloud = !matches!(
1505 all_language_settings(None, cx).edit_predictions.provider,
1506 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1507 );
1508 if is_cloud {
1509 zeta::edit_prediction_accepted(self, current_prediction, cx)
1510 }
1511 }
1512 EditPredictionModel::Fim { .. } => {}
1513 }
1514 }
1515
1516 async fn handle_rejected_predictions(
1517 rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1518 client: Arc<Client>,
1519 llm_token: LlmApiToken,
1520 app_version: Version,
1521 background_executor: BackgroundExecutor,
1522 ) {
1523 let mut rx = std::pin::pin!(rx.peekable());
1524 let mut batched = Vec::new();
1525
1526 while let Some(EditPredictionRejectionPayload {
1527 rejection,
1528 organization_id,
1529 }) = rx.next().await
1530 {
1531 batched.push(rejection);
1532
1533 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1534 select_biased! {
1535 next = rx.as_mut().peek().fuse() => {
1536 if next.is_some() {
1537 continue;
1538 }
1539 }
1540 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1541 }
1542 }
1543
1544 let url = client
1545 .http_client()
1546 .build_zed_llm_url("/predict_edits/reject", &[])
1547 .unwrap();
1548
1549 let flush_count = batched
1550 .len()
1551 // in case items have accumulated after failure
1552 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1553 let start = batched.len() - flush_count;
1554
1555 let body = RejectEditPredictionsBodyRef {
1556 rejections: &batched[start..],
1557 };
1558
1559 let result = Self::send_api_request::<()>(
1560 |builder| {
1561 let req = builder
1562 .uri(url.as_ref())
1563 .body(serde_json::to_string(&body)?.into());
1564 anyhow::Ok(req?)
1565 },
1566 client.clone(),
1567 llm_token.clone(),
1568 organization_id,
1569 app_version.clone(),
1570 true,
1571 )
1572 .await;
1573
1574 if result.log_err().is_some() {
1575 batched.drain(start..);
1576 }
1577 }
1578 }
1579
1580 async fn run_settled_predictions_worker(
1581 this: WeakEntity<Self>,
1582 mut rx: UnboundedReceiver<Instant>,
1583 cx: &mut AsyncApp,
1584 ) {
1585 let mut next_wake_time: Option<Instant> = None;
1586 loop {
1587 let now = cx.background_executor().now();
1588 if let Some(wake_time) = next_wake_time.take() {
1589 cx.background_executor()
1590 .timer(wake_time.duration_since(now))
1591 .await;
1592 } else {
1593 let Some(new_enqueue_time) = rx.next().await else {
1594 break;
1595 };
1596 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1597 while rx.next().now_or_never().flatten().is_some() {}
1598 continue;
1599 }
1600
1601 let Some(this) = this.upgrade() else {
1602 break;
1603 };
1604
1605 let now = cx.background_executor().now();
1606
1607 let mut oldest_edited_at = None;
1608
1609 this.update(cx, |this, _| {
1610 for (_, project_state) in this.projects.iter_mut() {
1611 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1612 registered_buffer
1613 .pending_predictions
1614 .retain_mut(|pending_prediction| {
1615 let age =
1616 now.saturating_duration_since(pending_prediction.enqueued_at);
1617 if age >= EDIT_PREDICTION_SETTLED_TTL {
1618 return false;
1619 }
1620
1621 let quiet_for =
1622 now.saturating_duration_since(pending_prediction.last_edit_at);
1623 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1624 let settled_editable_region = registered_buffer
1625 .snapshot
1626 .text_for_range(
1627 pending_prediction.editable_anchor_range.clone(),
1628 )
1629 .collect::<String>();
1630
1631 #[cfg(test)]
1632 if let Some(callback) = &this.settled_event_callback {
1633 callback(
1634 pending_prediction.request_id.clone(),
1635 settled_editable_region.clone(),
1636 );
1637 }
1638
1639 telemetry::event!(
1640 EDIT_PREDICTION_SETTLED_EVENT,
1641 request_id = pending_prediction.request_id.0.clone(),
1642 settled_editable_region,
1643 example = pending_prediction.example.take(),
1644 e2e_latency = pending_prediction.e2e_latency.as_millis(),
1645 );
1646
1647 return false;
1648 }
1649
1650 if oldest_edited_at
1651 .is_none_or(|t| pending_prediction.last_edit_at < t)
1652 {
1653 oldest_edited_at = Some(pending_prediction.last_edit_at);
1654 }
1655
1656 true
1657 });
1658 }
1659 }
1660 });
1661
1662 next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1663 }
1664 }
1665
1666 pub(crate) fn enqueue_settled_prediction(
1667 &mut self,
1668 request_id: EditPredictionId,
1669 project: &Entity<Project>,
1670 edited_buffer: &Entity<Buffer>,
1671 edited_buffer_snapshot: &BufferSnapshot,
1672 editable_offset_range: Range<usize>,
1673 example: Option<ExampleSpec>,
1674 e2e_latency: std::time::Duration,
1675 cx: &mut Context<Self>,
1676 ) {
1677 let this = &mut *self;
1678 let project_state = this.get_or_init_project(project, cx);
1679 if let Some(buffer) = project_state
1680 .registered_buffers
1681 .get_mut(&edited_buffer.entity_id())
1682 {
1683 let now = cx.background_executor().now();
1684 buffer.pending_predictions.push(PendingSettledPrediction {
1685 request_id: request_id,
1686 editable_anchor_range: edited_buffer_snapshot
1687 .anchor_range_inside(editable_offset_range),
1688 example,
1689 e2e_latency,
1690 enqueued_at: now,
1691 last_edit_at: now,
1692 });
1693 this.settled_predictions_tx.unbounded_send(now).ok();
1694 }
1695 }
1696
1697 fn reject_current_prediction(
1698 &mut self,
1699 reason: EditPredictionRejectReason,
1700 project: &Entity<Project>,
1701 cx: &App,
1702 ) {
1703 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1704 project_state.pending_predictions.clear();
1705 if let Some(prediction) = project_state.current_prediction.take() {
1706 let model_version = prediction.prediction.model_version.clone();
1707 self.reject_prediction(
1708 prediction.prediction.id,
1709 reason,
1710 prediction.was_shown,
1711 model_version,
1712 Some(prediction.e2e_latency),
1713 cx,
1714 );
1715 }
1716 };
1717 }
1718
1719 fn did_show_current_prediction(
1720 &mut self,
1721 project: &Entity<Project>,
1722 display_type: edit_prediction_types::SuggestionDisplayType,
1723 _cx: &mut Context<Self>,
1724 ) {
1725 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1726 return;
1727 };
1728
1729 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1730 return;
1731 };
1732
1733 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1734 let previous_shown_with = current_prediction.shown_with;
1735
1736 if previous_shown_with.is_none() || !is_jump {
1737 current_prediction.shown_with = Some(display_type);
1738 }
1739
1740 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1741
1742 if is_first_non_jump_show {
1743 current_prediction.was_shown = true;
1744 }
1745
1746 if is_first_non_jump_show {
1747 self.shown_predictions
1748 .push_front(current_prediction.prediction.clone());
1749 if self.shown_predictions.len() > 50 {
1750 let completion = self.shown_predictions.pop_back().unwrap();
1751 self.rated_predictions.remove(&completion.id);
1752 }
1753 }
1754 }
1755
1756 fn reject_prediction(
1757 &mut self,
1758 prediction_id: EditPredictionId,
1759 reason: EditPredictionRejectReason,
1760 was_shown: bool,
1761 model_version: Option<String>,
1762 e2e_latency: Option<std::time::Duration>,
1763 cx: &App,
1764 ) {
1765 match self.edit_prediction_model {
1766 EditPredictionModel::Zeta => {
1767 let is_cloud = !matches!(
1768 all_language_settings(None, cx).edit_predictions.provider,
1769 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1770 );
1771
1772 if is_cloud {
1773 let organization_id = self
1774 .user_store
1775 .read(cx)
1776 .current_organization()
1777 .map(|organization| organization.id.clone());
1778
1779 self.reject_predictions_tx
1780 .unbounded_send(EditPredictionRejectionPayload {
1781 rejection: EditPredictionRejection {
1782 request_id: prediction_id.to_string(),
1783 reason,
1784 was_shown,
1785 model_version,
1786 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1787 },
1788 organization_id,
1789 })
1790 .log_err();
1791 }
1792 }
1793 EditPredictionModel::Mercury => {
1794 mercury::edit_prediction_rejected(
1795 prediction_id,
1796 was_shown,
1797 reason,
1798 self.client.http_client(),
1799 cx,
1800 );
1801 }
1802 EditPredictionModel::Fim { .. } => {}
1803 }
1804 }
1805
1806 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1807 self.projects
1808 .get(&project.entity_id())
1809 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1810 }
1811
1812 pub fn refresh_prediction_from_buffer(
1813 &mut self,
1814 project: Entity<Project>,
1815 buffer: Entity<Buffer>,
1816 position: language::Anchor,
1817 cx: &mut Context<Self>,
1818 ) {
1819 self.queue_prediction_refresh(
1820 project.clone(),
1821 PredictEditsRequestTrigger::Other,
1822 buffer.entity_id(),
1823 cx,
1824 move |this, cx| {
1825 let Some(request_task) = this
1826 .update(cx, |this, cx| {
1827 this.request_prediction(
1828 &project,
1829 &buffer,
1830 position,
1831 PredictEditsRequestTrigger::Other,
1832 cx,
1833 )
1834 })
1835 .log_err()
1836 else {
1837 return Task::ready(anyhow::Ok(None));
1838 };
1839
1840 cx.spawn(async move |_cx| {
1841 request_task.await.map(|prediction_result| {
1842 prediction_result.map(|prediction_result| {
1843 (
1844 prediction_result,
1845 PredictionRequestedBy::Buffer(buffer.entity_id()),
1846 )
1847 })
1848 })
1849 })
1850 },
1851 )
1852 }
1853
1854 pub fn refresh_prediction_from_diagnostics(
1855 &mut self,
1856 project: Entity<Project>,
1857 scope: DiagnosticSearchScope,
1858 cx: &mut Context<Self>,
1859 ) {
1860 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1861 return;
1862 }
1863
1864 if currently_following(&project, cx) {
1865 return;
1866 }
1867
1868 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1869 return;
1870 };
1871
1872 // Prefer predictions from buffer
1873 if project_state.current_prediction.is_some() {
1874 log::debug!(
1875 "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1876 );
1877 return;
1878 }
1879
1880 self.queue_prediction_refresh(
1881 project.clone(),
1882 PredictEditsRequestTrigger::Diagnostics,
1883 project.entity_id(),
1884 cx,
1885 move |this, cx| {
1886 let Some((active_buffer, snapshot, cursor_point)) = this
1887 .read_with(cx, |this, cx| {
1888 let project_state = this.projects.get(&project.entity_id())?;
1889 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1890 let snapshot = buffer.read(cx).snapshot();
1891
1892 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1893 return None;
1894 }
1895
1896 let cursor_point = position
1897 .map(|pos| pos.to_point(&snapshot))
1898 .unwrap_or_default();
1899
1900 Some((buffer, snapshot, cursor_point))
1901 })
1902 .log_err()
1903 .flatten()
1904 else {
1905 return Task::ready(anyhow::Ok(None));
1906 };
1907
1908 cx.spawn(async move |cx| {
1909 let diagnostic_search_range = match scope {
1910 DiagnosticSearchScope::Local => {
1911 let diagnostic_search_start =
1912 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1913 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1914 Point::new(diagnostic_search_start, 0)
1915 ..Point::new(diagnostic_search_end, 0)
1916 }
1917 DiagnosticSearchScope::Global => Default::default(),
1918 };
1919
1920 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1921 active_buffer,
1922 &snapshot,
1923 diagnostic_search_range,
1924 cursor_point,
1925 &project,
1926 cx,
1927 )
1928 .await?
1929 else {
1930 return anyhow::Ok(None);
1931 };
1932
1933 let Some(prediction_result) = this
1934 .update(cx, |this, cx| {
1935 this.request_prediction(
1936 &project,
1937 &jump_buffer,
1938 jump_position,
1939 PredictEditsRequestTrigger::Diagnostics,
1940 cx,
1941 )
1942 })?
1943 .await?
1944 else {
1945 return anyhow::Ok(None);
1946 };
1947
1948 this.update(cx, |this, cx| {
1949 Some((
1950 if this
1951 .get_or_init_project(&project, cx)
1952 .current_prediction
1953 .is_none()
1954 {
1955 prediction_result
1956 } else {
1957 EditPredictionResult {
1958 id: prediction_result.id,
1959 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1960 e2e_latency: prediction_result.e2e_latency,
1961 }
1962 },
1963 PredictionRequestedBy::DiagnosticsUpdate,
1964 ))
1965 })
1966 })
1967 },
1968 );
1969 }
1970
1971 fn predictions_enabled_at(
1972 snapshot: &BufferSnapshot,
1973 position: Option<language::Anchor>,
1974 cx: &App,
1975 ) -> bool {
1976 let file = snapshot.file();
1977 let all_settings = all_language_settings(file, cx);
1978 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1979 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1980 {
1981 return false;
1982 }
1983
1984 if let Some(last_position) = position {
1985 let settings = snapshot.settings_at(last_position, cx);
1986
1987 if !settings.edit_predictions_disabled_in.is_empty()
1988 && let Some(scope) = snapshot.language_scope_at(last_position)
1989 && let Some(scope_name) = scope.override_name()
1990 && settings
1991 .edit_predictions_disabled_in
1992 .iter()
1993 .any(|s| s == scope_name)
1994 {
1995 return false;
1996 }
1997 }
1998
1999 true
2000 }
2001
2002 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2003}
2004
2005fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2006 let Some(app_state) = AppState::try_global(cx) else {
2007 return false;
2008 };
2009
2010 app_state
2011 .workspace_store
2012 .read(cx)
2013 .workspaces()
2014 .filter_map(|workspace| workspace.upgrade())
2015 .any(|workspace| {
2016 workspace.read(cx).project().entity_id() == project.entity_id()
2017 && workspace
2018 .read(cx)
2019 .leader_for_pane(workspace.read(cx).active_pane())
2020 .is_some()
2021 })
2022}
2023
2024fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2025 match provider {
2026 EditPredictionProvider::Zed
2027 | EditPredictionProvider::Mercury
2028 | EditPredictionProvider::Ollama
2029 | EditPredictionProvider::OpenAiCompatibleApi
2030 | EditPredictionProvider::Experimental(_) => true,
2031 EditPredictionProvider::None
2032 | EditPredictionProvider::Copilot
2033 | EditPredictionProvider::Codestral => false,
2034 }
2035}
2036
2037impl EditPredictionStore {
2038 fn queue_prediction_refresh(
2039 &mut self,
2040 project: Entity<Project>,
2041 request_trigger: PredictEditsRequestTrigger,
2042 throttle_entity: EntityId,
2043 cx: &mut Context<Self>,
2044 do_refresh: impl FnOnce(
2045 WeakEntity<Self>,
2046 &mut AsyncApp,
2047 )
2048 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2049 + 'static,
2050 ) {
2051 fn select_throttle(
2052 project_state: &mut ProjectState,
2053 request_trigger: PredictEditsRequestTrigger,
2054 ) -> &mut Option<(EntityId, Instant)> {
2055 match request_trigger {
2056 PredictEditsRequestTrigger::Diagnostics => {
2057 &mut project_state.last_jump_prediction_refresh
2058 }
2059 _ => &mut project_state.last_edit_prediction_refresh,
2060 }
2061 }
2062
2063 let (needs_acceptance_tracking, max_pending_predictions) =
2064 match all_language_settings(None, cx).edit_predictions.provider {
2065 EditPredictionProvider::Zed
2066 | EditPredictionProvider::Mercury
2067 | EditPredictionProvider::Experimental(_) => (true, 2),
2068 EditPredictionProvider::Ollama => (false, 1),
2069 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2070 EditPredictionProvider::None
2071 | EditPredictionProvider::Copilot
2072 | EditPredictionProvider::Codestral => {
2073 log::error!("queue_prediction_refresh called with non-store provider");
2074 return;
2075 }
2076 };
2077
2078 let drop_on_cancel = !needs_acceptance_tracking;
2079 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2080 let project_state = self.get_or_init_project(&project, cx);
2081 let pending_prediction_id = project_state.next_pending_prediction_id;
2082 project_state.next_pending_prediction_id += 1;
2083 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2084
2085 let task = cx.spawn(async move |this, cx| {
2086 let throttle_wait = this
2087 .update(cx, |this, cx| {
2088 let project_state = this.get_or_init_project(&project, cx);
2089 let throttle = *select_throttle(project_state, request_trigger);
2090
2091 let now = cx.background_executor().now();
2092 throttle.and_then(|(last_entity, last_timestamp)| {
2093 if throttle_entity != last_entity {
2094 return None;
2095 }
2096 (last_timestamp + throttle_timeout).checked_duration_since(now)
2097 })
2098 })
2099 .ok()
2100 .flatten();
2101
2102 if let Some(timeout) = throttle_wait {
2103 cx.background_executor().timer(timeout).await;
2104 }
2105
2106 // If this task was cancelled before the throttle timeout expired,
2107 // do not perform a request. Also skip if another task already
2108 // proceeded since we were enqueued (duplicate).
2109 let mut is_cancelled = true;
2110 this.update(cx, |this, cx| {
2111 let project_state = this.get_or_init_project(&project, cx);
2112 let was_cancelled = project_state
2113 .cancelled_predictions
2114 .remove(&pending_prediction_id);
2115 if was_cancelled {
2116 return;
2117 }
2118
2119 // Another request has been already sent since this was enqueued
2120 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2121 return;
2122 }
2123
2124 let new_refresh = (throttle_entity, cx.background_executor().now());
2125 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2126 is_cancelled = false;
2127 })
2128 .ok();
2129 if is_cancelled {
2130 return None;
2131 }
2132
2133 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2134 let new_prediction_id = new_prediction_result
2135 .as_ref()
2136 .map(|(prediction, _)| prediction.id.clone());
2137
2138 // When a prediction completes, remove it from the pending list, and cancel
2139 // any pending predictions that were enqueued before it.
2140 this.update(cx, |this, cx| {
2141 let project_state = this.get_or_init_project(&project, cx);
2142
2143 let is_cancelled = project_state
2144 .cancelled_predictions
2145 .remove(&pending_prediction_id);
2146
2147 let new_current_prediction = if !is_cancelled
2148 && let Some((prediction_result, requested_by)) = new_prediction_result
2149 {
2150 match prediction_result.prediction {
2151 Ok(prediction) => {
2152 let new_prediction = CurrentEditPrediction {
2153 requested_by,
2154 prediction,
2155 was_shown: false,
2156 shown_with: None,
2157 e2e_latency: prediction_result.e2e_latency,
2158 };
2159
2160 if let Some(current_prediction) =
2161 project_state.current_prediction.as_ref()
2162 {
2163 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2164 {
2165 this.reject_current_prediction(
2166 EditPredictionRejectReason::Replaced,
2167 &project,
2168 cx,
2169 );
2170
2171 Some(new_prediction)
2172 } else {
2173 this.reject_prediction(
2174 new_prediction.prediction.id,
2175 EditPredictionRejectReason::CurrentPreferred,
2176 false,
2177 new_prediction.prediction.model_version,
2178 Some(new_prediction.e2e_latency),
2179 cx,
2180 );
2181 None
2182 }
2183 } else {
2184 Some(new_prediction)
2185 }
2186 }
2187 Err(reject_reason) => {
2188 this.reject_prediction(
2189 prediction_result.id,
2190 reject_reason,
2191 false,
2192 None,
2193 Some(prediction_result.e2e_latency),
2194 cx,
2195 );
2196 None
2197 }
2198 }
2199 } else {
2200 None
2201 };
2202
2203 let project_state = this.get_or_init_project(&project, cx);
2204
2205 if let Some(new_prediction) = new_current_prediction {
2206 project_state.current_prediction = Some(new_prediction);
2207 }
2208
2209 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2210 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2211 if pending_prediction.id == pending_prediction_id {
2212 pending_predictions.remove(ix);
2213 for pending_prediction in pending_predictions.drain(0..ix) {
2214 project_state.cancel_pending_prediction(pending_prediction, cx)
2215 }
2216 break;
2217 }
2218 }
2219 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2220 cx.notify();
2221 })
2222 .ok();
2223
2224 new_prediction_id
2225 });
2226
2227 if project_state.pending_predictions.len() < max_pending_predictions {
2228 project_state
2229 .pending_predictions
2230 .push(PendingPrediction {
2231 id: pending_prediction_id,
2232 task,
2233 drop_on_cancel,
2234 })
2235 .unwrap();
2236 } else {
2237 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2238 project_state
2239 .pending_predictions
2240 .push(PendingPrediction {
2241 id: pending_prediction_id,
2242 task,
2243 drop_on_cancel,
2244 })
2245 .unwrap();
2246 project_state.cancel_pending_prediction(pending_prediction, cx);
2247 }
2248 }
2249
2250 pub fn request_prediction(
2251 &mut self,
2252 project: &Entity<Project>,
2253 active_buffer: &Entity<Buffer>,
2254 position: language::Anchor,
2255 trigger: PredictEditsRequestTrigger,
2256 cx: &mut Context<Self>,
2257 ) -> Task<Result<Option<EditPredictionResult>>> {
2258 self.request_prediction_internal(
2259 project.clone(),
2260 active_buffer.clone(),
2261 position,
2262 trigger,
2263 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2264 cx,
2265 )
2266 }
2267
2268 fn request_prediction_internal(
2269 &mut self,
2270 project: Entity<Project>,
2271 active_buffer: Entity<Buffer>,
2272 position: language::Anchor,
2273 trigger: PredictEditsRequestTrigger,
2274 allow_jump: bool,
2275 cx: &mut Context<Self>,
2276 ) -> Task<Result<Option<EditPredictionResult>>> {
2277 self.get_or_init_project(&project, cx);
2278 let project_state = self.projects.get(&project.entity_id()).unwrap();
2279 let stored_events = project_state.events(cx);
2280 let has_events = !stored_events.is_empty();
2281 let events: Vec<Arc<zeta_prompt::Event>> =
2282 stored_events.iter().map(|e| e.event.clone()).collect();
2283 let debug_tx = project_state.debug_tx.clone();
2284
2285 let snapshot = active_buffer.read(cx).snapshot();
2286 let cursor_point = position.to_point(&snapshot);
2287 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2288 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2289 let diagnostic_search_range =
2290 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2291
2292 let related_files = self.context_for_project(&project, cx);
2293
2294 let is_open_source = snapshot
2295 .file()
2296 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2297 && events.iter().all(|event| event.in_open_source_repo())
2298 && related_files.iter().all(|file| file.in_open_source_repo);
2299
2300 let can_collect_data = !cfg!(test)
2301 && is_open_source
2302 && self.is_data_collection_enabled(cx)
2303 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2304
2305 let inputs = EditPredictionModelInput {
2306 project: project.clone(),
2307 buffer: active_buffer,
2308 snapshot,
2309 position,
2310 events,
2311 related_files,
2312 trigger,
2313 diagnostic_search_range: diagnostic_search_range,
2314 debug_tx,
2315 can_collect_data,
2316 is_open_source,
2317 };
2318
2319 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2320
2321 let task = match self.edit_prediction_model {
2322 EditPredictionModel::Zeta => {
2323 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2324 }
2325 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2326 EditPredictionModel::Mercury => {
2327 self.mercury
2328 .request_prediction(inputs, self.credentials_provider.clone(), cx)
2329 }
2330 };
2331
2332 cx.spawn(async move |this, cx| {
2333 let prediction = task.await?;
2334
2335 // Only fall back to diagnostics-based prediction if we got a
2336 // the model had nothing to suggest for the buffer
2337 if prediction.is_none()
2338 && allow_jump
2339 && has_events
2340 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2341 {
2342 this.update(cx, |this, cx| {
2343 this.refresh_prediction_from_diagnostics(
2344 project,
2345 DiagnosticSearchScope::Local,
2346 cx,
2347 );
2348 })?;
2349 return anyhow::Ok(None);
2350 }
2351
2352 Ok(prediction)
2353 })
2354 }
2355
2356 pub(crate) async fn next_diagnostic_location(
2357 active_buffer: Entity<Buffer>,
2358 active_buffer_snapshot: &BufferSnapshot,
2359 active_buffer_diagnostic_search_range: Range<Point>,
2360 active_buffer_cursor_point: Point,
2361 project: &Entity<Project>,
2362 cx: &mut AsyncApp,
2363 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2364 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2365 .selections_in_range(
2366 Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
2367 false,
2368 )
2369 .flat_map(|(_, _, _, selections)| {
2370 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2371 })
2372 .collect();
2373
2374 let mut jump_location = active_buffer_snapshot
2375 .diagnostic_groups(None)
2376 .into_iter()
2377 .filter_map(|(_, group)| {
2378 let range = &group.entries[group.primary_ix]
2379 .range
2380 .to_point(&active_buffer_snapshot);
2381 if range.overlaps(&active_buffer_diagnostic_search_range) {
2382 return None;
2383 }
2384 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2385 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2386 });
2387 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2388 <= DIAGNOSTIC_LINES_RANGE;
2389 if near_collaborator && !near_local {
2390 return None;
2391 }
2392 Some(range.start)
2393 })
2394 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2395 .map(|position| {
2396 (
2397 active_buffer.clone(),
2398 active_buffer_snapshot.anchor_before(position),
2399 )
2400 });
2401
2402 if jump_location.is_none() {
2403 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2404 let file = buffer.file()?;
2405
2406 Some(ProjectPath {
2407 worktree_id: file.worktree_id(cx),
2408 path: file.path().clone(),
2409 })
2410 });
2411
2412 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2413 project
2414 .diagnostic_summaries(false, cx)
2415 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2416 .map(|(path, _, _)| {
2417 let shared_prefix = path
2418 .path
2419 .components()
2420 .zip(
2421 active_buffer_path
2422 .as_ref()
2423 .map(|p| p.path.components())
2424 .unwrap_or_default(),
2425 )
2426 .take_while(|(a, b)| a == b)
2427 .count();
2428 (path, shared_prefix)
2429 })
2430 .collect()
2431 });
2432
2433 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2434
2435 for (path, _) in candidates {
2436 let candidate_buffer = project
2437 .update(cx, |project, cx| project.open_buffer(path, cx))
2438 .await?;
2439
2440 let (has_collaborators, diagnostic_position) =
2441 candidate_buffer.read_with(cx, |buffer, _cx| {
2442 let snapshot = buffer.snapshot();
2443 let has_collaborators = snapshot
2444 .selections_in_range(
2445 Anchor::min_max_range_for_buffer(snapshot.remote_id()),
2446 false,
2447 )
2448 .next()
2449 .is_some();
2450 let position = buffer
2451 .buffer_diagnostics(None)
2452 .into_iter()
2453 .min_by_key(|entry| entry.diagnostic.severity)
2454 .map(|entry| entry.range.start);
2455 (has_collaborators, position)
2456 });
2457
2458 if has_collaborators {
2459 continue;
2460 }
2461
2462 if let Some(position) = diagnostic_position {
2463 jump_location = Some((candidate_buffer, position));
2464 break;
2465 }
2466 }
2467 }
2468
2469 anyhow::Ok(jump_location)
2470 }
2471
2472 async fn send_raw_llm_request(
2473 request: RawCompletionRequest,
2474 client: Arc<Client>,
2475 custom_url: Option<Arc<Url>>,
2476 llm_token: LlmApiToken,
2477 organization_id: Option<OrganizationId>,
2478 app_version: Version,
2479 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2480 let url = if let Some(custom_url) = custom_url {
2481 custom_url.as_ref().clone()
2482 } else {
2483 client
2484 .http_client()
2485 .build_zed_llm_url("/predict_edits/raw", &[])?
2486 };
2487
2488 Self::send_api_request(
2489 |builder| {
2490 let req = builder
2491 .uri(url.as_ref())
2492 .body(serde_json::to_string(&request)?.into());
2493 Ok(req?)
2494 },
2495 client,
2496 llm_token,
2497 organization_id,
2498 app_version,
2499 true,
2500 )
2501 .await
2502 }
2503
2504 pub(crate) async fn send_v3_request(
2505 input: ZetaPromptInput,
2506 client: Arc<Client>,
2507 llm_token: LlmApiToken,
2508 organization_id: Option<OrganizationId>,
2509 app_version: Version,
2510 trigger: PredictEditsRequestTrigger,
2511 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2512 let url = client
2513 .http_client()
2514 .build_zed_llm_url("/predict_edits/v3", &[])?;
2515
2516 let request = PredictEditsV3Request { input, trigger };
2517
2518 let json_bytes = serde_json::to_vec(&request)?;
2519 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2520
2521 Self::send_api_request(
2522 |builder| {
2523 let req = builder
2524 .uri(url.as_ref())
2525 .header("Content-Encoding", "zstd")
2526 .body(compressed.clone().into());
2527 Ok(req?)
2528 },
2529 client,
2530 llm_token,
2531 organization_id,
2532 app_version,
2533 true,
2534 )
2535 .await
2536 }
2537
2538 async fn send_api_request<Res>(
2539 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2540 client: Arc<Client>,
2541 llm_token: LlmApiToken,
2542 organization_id: Option<OrganizationId>,
2543 app_version: Version,
2544 require_auth: bool,
2545 ) -> Result<(Res, Option<EditPredictionUsage>)>
2546 where
2547 Res: DeserializeOwned,
2548 {
2549 let http_client = client.http_client();
2550 let mut token = if require_auth {
2551 Some(
2552 client
2553 .acquire_llm_token(&llm_token, organization_id.clone())
2554 .await?,
2555 )
2556 } else {
2557 client
2558 .acquire_llm_token(&llm_token, organization_id.clone())
2559 .await
2560 .ok()
2561 };
2562 let mut did_retry = false;
2563
2564 loop {
2565 let request_builder = http_client::Request::builder().method(Method::POST);
2566
2567 let mut request_builder = request_builder
2568 .header("Content-Type", "application/json")
2569 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2570
2571 // Only add Authorization header if we have a token
2572 if let Some(ref token_value) = token {
2573 request_builder =
2574 request_builder.header("Authorization", format!("Bearer {}", token_value));
2575 }
2576
2577 let request = build(request_builder)?;
2578
2579 let mut response = http_client.send(request).await?;
2580
2581 if let Some(minimum_required_version) = response
2582 .headers()
2583 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2584 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2585 {
2586 anyhow::ensure!(
2587 app_version >= minimum_required_version,
2588 ZedUpdateRequiredError {
2589 minimum_version: minimum_required_version
2590 }
2591 );
2592 }
2593
2594 if response.status().is_success() {
2595 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2596
2597 let mut body = Vec::new();
2598 response.body_mut().read_to_end(&mut body).await?;
2599 return Ok((serde_json::from_slice(&body)?, usage));
2600 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2601 did_retry = true;
2602 token = Some(
2603 client
2604 .refresh_llm_token(&llm_token, organization_id.clone())
2605 .await?,
2606 );
2607 } else {
2608 let mut body = String::new();
2609 response.body_mut().read_to_string(&mut body).await?;
2610 anyhow::bail!(
2611 "Request failed with status: {:?}\nBody: {}",
2612 response.status(),
2613 body
2614 );
2615 }
2616 }
2617 }
2618
2619 pub fn refresh_context(
2620 &mut self,
2621 project: &Entity<Project>,
2622 buffer: &Entity<language::Buffer>,
2623 cursor_position: language::Anchor,
2624 cx: &mut Context<Self>,
2625 ) {
2626 self.get_or_init_project(project, cx)
2627 .context
2628 .update(cx, |store, cx| {
2629 store.refresh(buffer.clone(), cursor_position, cx);
2630 });
2631 }
2632
2633 #[cfg(feature = "cli-support")]
2634 pub fn set_context_for_buffer(
2635 &mut self,
2636 project: &Entity<Project>,
2637 related_files: Vec<RelatedFile>,
2638 cx: &mut Context<Self>,
2639 ) {
2640 self.get_or_init_project(project, cx)
2641 .context
2642 .update(cx, |store, cx| {
2643 store.set_related_files(related_files, cx);
2644 });
2645 }
2646
2647 #[cfg(feature = "cli-support")]
2648 pub fn set_recent_paths_for_project(
2649 &mut self,
2650 project: &Entity<Project>,
2651 paths: impl IntoIterator<Item = project::ProjectPath>,
2652 cx: &mut Context<Self>,
2653 ) {
2654 let project_state = self.get_or_init_project(project, cx);
2655 project_state.recent_paths = paths.into_iter().collect();
2656 }
2657
2658 fn is_file_open_source(
2659 &self,
2660 project: &Entity<Project>,
2661 file: &Arc<dyn File>,
2662 cx: &App,
2663 ) -> bool {
2664 if !file.is_local() || file.is_private() {
2665 return false;
2666 }
2667 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2668 return false;
2669 };
2670 project_state
2671 .license_detection_watchers
2672 .get(&file.worktree_id(cx))
2673 .as_ref()
2674 .is_some_and(|watcher| watcher.is_project_open_source())
2675 }
2676
2677 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2678 self.data_collection_choice.is_enabled(cx)
2679 }
2680
2681 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2682 let choice = KeyValueStore::global(cx)
2683 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2684 .log_err()
2685 .flatten();
2686
2687 match choice.as_deref() {
2688 Some("true") => DataCollectionChoice::Enabled,
2689 Some("false") => DataCollectionChoice::Disabled,
2690 Some(_) => {
2691 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2692 DataCollectionChoice::NotAnswered
2693 }
2694 None => DataCollectionChoice::NotAnswered,
2695 }
2696 }
2697
2698 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2699 self.data_collection_choice = self.data_collection_choice.toggle();
2700 let new_choice = self.data_collection_choice;
2701 let is_enabled = new_choice.is_enabled(cx);
2702 let kvp = KeyValueStore::global(cx);
2703 db::write_and_log(cx, move || async move {
2704 kvp.write_kvp(
2705 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2706 is_enabled.to_string(),
2707 )
2708 .await
2709 });
2710 }
2711
2712 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2713 self.shown_predictions.iter()
2714 }
2715
2716 pub fn shown_completions_len(&self) -> usize {
2717 self.shown_predictions.len()
2718 }
2719
2720 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2721 self.rated_predictions.contains(id)
2722 }
2723
2724 pub fn rate_prediction(
2725 &mut self,
2726 prediction: &EditPrediction,
2727 rating: EditPredictionRating,
2728 feedback: String,
2729 cx: &mut Context<Self>,
2730 ) {
2731 let organization = self.user_store.read(cx).current_organization();
2732
2733 self.rated_predictions.insert(prediction.id.clone());
2734
2735 cx.background_spawn({
2736 let client = self.client.clone();
2737 let prediction_id = prediction.id.to_string();
2738 let inputs = serde_json::to_value(&prediction.inputs);
2739 let output = prediction
2740 .edit_preview
2741 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2742 async move {
2743 client
2744 .cloud_client()
2745 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2746 organization_id: organization.map(|organization| organization.id.clone()),
2747 request_id: prediction_id,
2748 rating: match rating {
2749 EditPredictionRating::Positive => "positive".to_string(),
2750 EditPredictionRating::Negative => "negative".to_string(),
2751 },
2752 inputs: inputs?,
2753 output,
2754 feedback,
2755 })
2756 .await?;
2757
2758 anyhow::Ok(())
2759 }
2760 })
2761 .detach_and_log_err(cx);
2762
2763 cx.notify();
2764 }
2765}
2766
2767fn collaborator_edit_overlaps_locality_region(
2768 project_state: &ProjectState,
2769 project: &Entity<Project>,
2770 buffer: &Entity<Buffer>,
2771 snapshot: &BufferSnapshot,
2772 edit_range: &Range<Anchor>,
2773 cx: &App,
2774) -> bool {
2775 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2776 return false;
2777 };
2778
2779 if active_buffer.entity_id() != buffer.entity_id() {
2780 return false;
2781 }
2782
2783 let locality_point_range = expand_context_syntactically_then_linewise(
2784 snapshot,
2785 (position..position).to_point(snapshot),
2786 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2787 );
2788 let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2789
2790 edit_range.overlaps(&locality_anchor_range, snapshot)
2791}
2792
2793fn merge_trailing_events_if_needed(
2794 events: &mut VecDeque<StoredEvent>,
2795 end_snapshot: &TextBufferSnapshot,
2796 latest_snapshot: &TextBufferSnapshot,
2797 latest_edit_range: &Range<Anchor>,
2798) {
2799 if let Some(last_event) = events.back() {
2800 if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2801 return;
2802 }
2803 if !latest_snapshot
2804 .version
2805 .observed_all(&last_event.new_snapshot_version)
2806 {
2807 return;
2808 }
2809 }
2810
2811 let mut next_old_event = None;
2812 let mut mergeable_count = 0;
2813 for old_event in events.iter().rev() {
2814 if let Some(next_old_event) = next_old_event
2815 && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
2816 {
2817 break;
2818 }
2819 mergeable_count += 1;
2820 next_old_event = Some(old_event);
2821 }
2822
2823 if mergeable_count <= 1 {
2824 return;
2825 }
2826
2827 let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2828 let oldest_event = events_to_merge.peek().unwrap();
2829 let oldest_snapshot = oldest_event.old_snapshot.clone();
2830 let newest_snapshot = end_snapshot;
2831 let mut merged_edit_range = oldest_event.total_edit_range.clone();
2832
2833 for event in events.range(events.len() - mergeable_count + 1..) {
2834 merged_edit_range =
2835 merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
2836 }
2837
2838 if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
2839 &oldest_snapshot,
2840 newest_snapshot,
2841 &merged_edit_range,
2842 ) {
2843 let merged_event = match oldest_event.event.as_ref() {
2844 zeta_prompt::Event::BufferChange {
2845 old_path,
2846 path,
2847 in_open_source_repo,
2848 ..
2849 } => StoredEvent {
2850 event: Arc::new(zeta_prompt::Event::BufferChange {
2851 old_path: old_path.clone(),
2852 path: path.clone(),
2853 diff,
2854 in_open_source_repo: *in_open_source_repo,
2855 predicted: events_to_merge.all(|e| {
2856 matches!(
2857 e.event.as_ref(),
2858 zeta_prompt::Event::BufferChange {
2859 predicted: true,
2860 ..
2861 }
2862 )
2863 }),
2864 }),
2865 old_snapshot: oldest_snapshot.clone(),
2866 new_snapshot_version: newest_snapshot.version.clone(),
2867 total_edit_range: newest_snapshot.anchor_before(edit_range.start)
2868 ..newest_snapshot.anchor_before(edit_range.end),
2869 },
2870 };
2871 events.truncate(events.len() - mergeable_count);
2872 events.push_back(merged_event);
2873 }
2874}
2875
2876fn merge_anchor_ranges(
2877 left: &Range<Anchor>,
2878 right: &Range<Anchor>,
2879 snapshot: &TextBufferSnapshot,
2880) -> Range<Anchor> {
2881 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2882 left.start
2883 } else {
2884 right.start
2885 };
2886 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2887 left.end
2888 } else {
2889 right.end
2890 };
2891 start..end
2892}
2893
2894#[derive(Error, Debug)]
2895#[error(
2896 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2897)]
2898pub struct ZedUpdateRequiredError {
2899 minimum_version: Version,
2900}
2901
2902#[derive(Debug, Clone, Copy)]
2903pub enum DataCollectionChoice {
2904 NotAnswered,
2905 Enabled,
2906 Disabled,
2907}
2908
2909impl DataCollectionChoice {
2910 pub fn is_enabled(self, cx: &App) -> bool {
2911 if cx.is_staff() {
2912 return true;
2913 }
2914 match self {
2915 Self::Enabled => true,
2916 Self::NotAnswered | Self::Disabled => false,
2917 }
2918 }
2919
2920 #[must_use]
2921 pub fn toggle(&self) -> DataCollectionChoice {
2922 match self {
2923 Self::Enabled => Self::Disabled,
2924 Self::Disabled => Self::Enabled,
2925 Self::NotAnswered => Self::Enabled,
2926 }
2927 }
2928}
2929
2930impl From<bool> for DataCollectionChoice {
2931 fn from(value: bool) -> Self {
2932 match value {
2933 true => DataCollectionChoice::Enabled,
2934 false => DataCollectionChoice::Disabled,
2935 }
2936 }
2937}
2938
2939struct ZedPredictUpsell;
2940
2941impl Dismissable for ZedPredictUpsell {
2942 const KEY: &'static str = "dismissed-edit-predict-upsell";
2943
2944 fn dismissed(cx: &App) -> bool {
2945 // To make this backwards compatible with older versions of Zed, we
2946 // check if the user has seen the previous Edit Prediction Onboarding
2947 // before, by checking the data collection choice which was written to
2948 // the database once the user clicked on "Accept and Enable"
2949 let kvp = KeyValueStore::global(cx);
2950 if kvp
2951 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2952 .log_err()
2953 .is_some_and(|s| s.is_some())
2954 {
2955 return true;
2956 }
2957
2958 kvp.read_kvp(Self::KEY)
2959 .log_err()
2960 .is_some_and(|s| s.is_some())
2961 }
2962}
2963
2964pub fn should_show_upsell_modal(cx: &App) -> bool {
2965 !ZedPredictUpsell::dismissed(cx)
2966}
2967
2968pub fn init(cx: &mut App) {
2969 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2970 workspace.register_action(
2971 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2972 ZedPredictModal::toggle(
2973 workspace,
2974 workspace.user_store().clone(),
2975 workspace.client().clone(),
2976 window,
2977 cx,
2978 )
2979 },
2980 );
2981
2982 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2983 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2984 settings
2985 .project
2986 .all_languages
2987 .edit_predictions
2988 .get_or_insert_default()
2989 .provider = Some(EditPredictionProvider::None)
2990 });
2991 });
2992 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2993 EditPredictionStore::try_global(cx).and_then(|store| {
2994 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2995 })
2996 }
2997
2998 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2999 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3000 copilot_ui::initiate_sign_in(copilot, window, cx);
3001 }
3002 });
3003 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3004 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3005 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3006 }
3007 });
3008 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3009 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3010 copilot_ui::initiate_sign_out(copilot, window, cx);
3011 }
3012 });
3013 })
3014 .detach();
3015}