1use anyhow::Result;
2use client::{Client, EditPredictionUsage, UserStore};
3use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KeyValueStore};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use heapless::Vec as ArrayVec;
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::{AppState, Workspace};
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use thiserror::Error;
56use util::{RangeExt as _, ResultExt as _};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
79use crate::example_spec::ExampleSpec;
80use crate::license_detection::LicenseDetectionWatcher;
81use crate::mercury::Mercury;
82use crate::onboarding_modal::ZedPredictModal;
83pub use crate::prediction::EditPrediction;
84pub use crate::prediction::EditPredictionId;
85use crate::prediction::EditPredictionResult;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Two edits within this many lines of each other may be grouped into a
/// single history event (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not referenced in this chunk; presumably a token budget for
// ranking collaborator edits by proximity — confirm against usage elsewhere.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// Presumably the max pause between edits that still coalesce into one event
// — confirm where `last_edit_time` is compared against it.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// Debounce interval for batching rejection reports to the server
// (consumed by `handle_rejected_predictions`, defined outside this chunk).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Telemetry event name emitted when a prediction "settles".
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): consumed by the settled-predictions worker (not in view);
// names suggest a 5-minute TTL and 10-second quiescence window — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
112
/// Feature flag gating "jump" edit predictions (see `BufferEditPrediction::Jump`).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// App-level `Global` holding the singleton `EditPredictionStore` entity so it
/// can be retrieved from anywhere via `cx.try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model to request; `None` falls back to the server default.
    pub model_id: Option<String>,
    /// Deployment environment override; `None` falls back to the default.
    pub environment: Option<String>,
    /// Prompt format the client should build (parsed from `ZED_ZETA_FORMAT`
    /// when configured via the environment).
    pub format: ZetaFormat,
}
133
/// Central store coordinating edit predictions: per-project edit history,
/// model selection, experiment configuration, and bookkeeping for shown,
/// rated, rejected, and settled predictions. Installed as an app global
/// (see `EditPredictionStoreGlobal`).
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Waits for a signed-in user, then fetches experiments for staff.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the `Project` entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): set/read outside this chunk; presumably flags that the
    // server demanded a newer Zed version — confirm.
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint override (from env vars or `set_zeta2_raw_config`).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly selected by the user, if any.
    preferred_experiment: Option<String>,
    /// Experiments fetched from the server (staff only).
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Queue feeding the background rejection-reporting worker.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Queue feeding the settled-predictions worker.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}

/// A rejection queued for reporting to the server, together with the
/// organization it should be attributed to (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}

/// Which prediction backend is active.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle style model with a configurable prompt format.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
166
/// Everything a prediction backend needs to produce a prediction for one
/// buffer position.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is anchored to.
    position: Anchor,
    /// Recent edit-history events to include in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    /// Range around the cursor in which diagnostics are considered.
    diagnostic_search_range: Range<Point>,
    /// When present, debug events are emitted here (see `DebugEvent`).
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}

/// Events emitted for the prediction debug view.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value pairs describing the retrieval outcome.
    pub metadata: Vec<(&'static str, SharedString)>,
}

#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text sent to the model, when available.
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Version vector of the buffer after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Span covering all of the event's edits, anchored in the post-edit buffer.
    pub total_edit_range: Range<Anchor>,
}
226
impl StoredEvent {
    /// Whether `self` can be merged with the chronologically-next event
    /// `next_old_event` when compacting the edit history. `latest_snapshot`
    /// and `latest_edit_range` describe the buffer's most recent state/edit.
    fn can_merge(
        &self,
        next_old_event: &StoredEvent,
        latest_snapshot: &TextBufferSnapshot,
        latest_edit_range: &Range<Anchor>,
    ) -> bool {
        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
            return false;
        }
        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return false;
        }
        // `self`'s post-edit state must be exactly the next event's pre-edit state...
        if self.new_snapshot_version != next_old_event.old_snapshot.version {
            return false;
        }
        // ...and the latest snapshot must have observed everything the next
        // event produced, so anchor resolution below is meaningful.
        if !latest_snapshot
            .version
            .observed_all(&next_old_event.new_snapshot_version)
        {
            return false;
        }

        let a_is_predicted = matches!(
            self.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );
        let b_is_predicted = matches!(
            next_old_event.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );

        // If events come from the same source (both predicted or both manual) then
        // we would have coalesced them already.
        if a_is_predicted == b_is_predicted {
            return false;
        }

        let left_range = self.total_edit_range.to_point(latest_snapshot);
        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
        let latest_range = latest_edit_range.to_point(latest_snapshot);

        // Events near to the latest edit are not merged if their sources differ.
        if lines_between_ranges(&left_range, &latest_range)
            .min(lines_between_ranges(&right_range, &latest_range))
            <= CHANGE_GROUPING_LINE_SPAN
        {
            return false;
        }

        // Events that are distant from each other are not merged.
        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
            return false;
        }

        true
    }
}
292
293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
294 if left.start > right.end {
295 return left.start.row - right.end.row;
296 }
297 if right.start > left.end {
298 return right.start.row - left.end.row;
299 }
300 0
301}
302
/// Per-project state tracked by the `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit-history events, oldest first.
    events: VecDeque<StoredEvent>,
    /// In-progress event still accumulating edits; finalized lazily in `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers we observe for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    /// When present, debug events for this project are emitted here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    /// Retrieval store providing related files/excerpts for prompts.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree watchers used to decide open-source data-collection eligibility.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance, when started for this project.
    copilot: Option<Entity<Copilot>>,
}
320
impl ProjectState {
    /// All history events: the finalized ones, followed by the still-open
    /// `last_event`, which is split at its last editing pause and finalized
    /// on the fly.
    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    /// Cancels an in-flight prediction. Depending on `drop_on_cancel`, either
    /// aborts the request immediately (by dropping the task) or lets it finish
    /// so the resulting prediction can be reported as rejected.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        if pending_prediction.drop_on_cancel {
            drop(pending_prediction.task);
        } else {
            cx.spawn(async move |this, cx| {
                let Some(prediction_id) = pending_prediction.task.await else {
                    return;
                };

                this.update(cx, |this, cx| {
                    this.reject_prediction(
                        prediction_id,
                        EditPredictionRejectReason::Canceled,
                        false,
                        None,
                        None,
                        cx,
                    );
                })
                .ok();
            })
            .detach()
        }
    }

    /// The project's currently-active buffer, if it is registered with us,
    /// together with the last known cursor position in it.
    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}
378
/// The prediction currently held for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been presented to the user.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency from request to prediction availability.
    pub e2e_latency: std::time::Duration,
}
387
impl CurrentEditPrediction {
    /// Whether `self` (a newly-arrived prediction) should replace the
    /// currently-displayed `old_prediction`.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, never replace.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // A prediction for a different buffer always replaces the old one.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction has become stale, replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace a same-buffer, single-edit prediction when the new
            // edit targets the same range and merely extends the old text.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
426
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update rather than a specific buffer edit.
    DiagnosticsUpdate,
    /// Triggered by activity in the buffer with this entity id.
    Buffer(EntityId),
}
432
433impl PredictionRequestedBy {
434 pub fn buffer_id(&self) -> Option<EntityId> {
435 match self {
436 PredictionRequestedBy::DiagnosticsUpdate => None,
437 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
438 }
439 }
440}
441
// NOTE(review): not referenced in this chunk; presumably the line radius used
// when building `diagnostic_search_range` — confirm against usage elsewhere.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope for diagnostic searches driving predictions: near the cursor
/// (`Local`) or across the whole project (`Global`) — exact semantics are
/// defined by the consumers outside this chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
449
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Id used to track cancellation (see `ProjectState::cancelled_predictions`).
    id: usize,
    /// Resolves to the prediction's id, or `None` if no prediction resulted.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// A prediction targeting this buffer itself.
    Local { prediction: &'a EditPrediction },
    /// A prediction that requires jumping elsewhere (gated by
    /// `EditPredictionJumpsFeatureFlag`).
    Jump { prediction: &'a EditPrediction },
}
465
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Test convenience: both variants wrap the same payload, so deref
    /// straight through to it.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Local { prediction } | Self::Jump { prediction } => prediction,
        }
    }
}
477
/// A prediction awaiting "settled" reporting: after it is applied/shown, we
/// wait for editing to quiesce before emitting the settled telemetry event
/// (see `EDIT_PREDICTION_SETTLED_EVENT`).
#[derive(Clone)]

struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region of the buffer the prediction was allowed to edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When this entry was queued.
    enqueued_at: Instant,
    /// Time of the most recent edit observed for this entry.
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}

/// Bookkeeping for a buffer registered for edit tracking.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Last snapshot we diffed against; updated as edits arrive.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position in this buffer, if any.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
496
/// The still-accumulating most-recent edit event for a project. Converted to
/// a `StoredEvent` by `finalize`, possibly after being split at the last
/// editing pause by `split_by_pause`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer state before this event's first edit.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state after this event's latest edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Span of the most recent edit, anchored in `new_snapshot`.
    latest_edit_range: Range<Anchor>,
    /// Span covering all edits in this event, anchored in `new_snapshot`.
    total_edit_range: Range<Anchor>,
    /// `total_edit_range` as of the last pause boundary, used by `split_by_pause`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether these edits came from an accepted prediction (vs. manual typing).
    predicted: bool,
    /// Snapshot taken at the last editing pause; `Some` enables splitting.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
510
impl LastEvent {
    /// Converts this in-progress event into a `StoredEvent` by diffing the
    /// old and new snapshots. Returns `None` when the diff is empty (and the
    /// path did not change) or when the diff exceeds the size limit.
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open-source only when both the old and new file locations are in
        // worktrees whose license watcher says the project is open source.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
            &self.old_snapshot,
            &self.new_snapshot,
            &self.total_edit_range,
        )?;

        // A rename with no content change still produces an event.
        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                old_snapshot: self.old_snapshot.clone(),
                new_snapshot_version: self.new_snapshot.version.clone(),
                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
            })
        }
    }

    /// Splits this event at the last editing pause, yielding the part before
    /// the pause and (when a pause boundary exists and edits followed it) the
    /// part after. Falls back to `(self, None)` when no split is possible.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let total_edit_range_before_pause = self
            .total_edit_range_at_last_pause_boundary
            .clone()
            .unwrap_or_else(|| self.total_edit_range.clone());

        // No edits after the boundary means there is nothing to split off.
        let Some(total_edit_range_after_pause) =
            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
        else {
            return (self.clone(), None);
        };

        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();

        // Before-pause half: old snapshot up to the boundary snapshot.
        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_before_pause,
            total_edit_range: total_edit_range_before_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        // After-pause half: boundary snapshot up to the current snapshot.
        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_after_pause,
            total_edit_range: total_edit_range_after_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
604
605fn compute_total_edit_range_between_snapshots(
606 old_snapshot: &TextBufferSnapshot,
607 new_snapshot: &TextBufferSnapshot,
608) -> Option<Range<Anchor>> {
609 let edits: Vec<Edit<usize>> = new_snapshot
610 .edits_since::<usize>(&old_snapshot.version)
611 .collect();
612
613 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
614 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
615 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
616
617 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
618}
619
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to the
/// corresponding point range in `old_snapshot` by walking the edits between
/// the snapshots and tracking the running length delta. Returns `None` only
/// if a mapped offset would underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Accumulated (new_len - old_len) over the edits preceding the current one.
    let mut delta: isize = 0;

    for edit in &edits {
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Before this edit: undo the accumulated delta.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Inside this edit: snap to the edit's old start.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                // Before this edit: undo the accumulated delta.
                new_end_offset.checked_add_signed(-delta)?
            } else {
                // Inside this edit: snap to the edit's old end.
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Offsets past the last edit map back by undoing the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
665
/// Produces a unified diff (with 3 context lines) between the regions of the
/// two snapshots covered by `total_edit_range`, along with the edited point
/// range in `new_snapshot`. Returns `None` when the region on either side
/// exceeds `EDIT_HISTORY_DIFF_SIZE_LIMIT` or the old range can't be mapped.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Widen both regions by the context lines, clamped to each buffer's extent.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Convert the row spans to byte offsets; the end offset extends to the
    // start of the following line (or end of buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Guard the prompt budget: skip oversized diffs entirely.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
713
714fn buffer_path_with_id_fallback(
715 file: Option<&Arc<dyn File>>,
716 snapshot: &TextBufferSnapshot,
717 cx: &App,
718) -> Arc<Path> {
719 if let Some(file) = file {
720 file.full_path(cx).into()
721 } else {
722 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
723 }
724}
725
726impl EditPredictionStore {
727 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
728 cx.try_global::<EditPredictionStoreGlobal>()
729 .map(|global| global.0.clone())
730 }
731
732 pub fn global(
733 client: &Arc<Client>,
734 user_store: &Entity<UserStore>,
735 cx: &mut App,
736 ) -> Entity<Self> {
737 cx.try_global::<EditPredictionStoreGlobal>()
738 .map(|global| global.0.clone())
739 .unwrap_or_else(|| {
740 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
741 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
742 ep_store
743 })
744 }
745
746 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
747 let data_collection_choice = Self::load_data_collection_choice(cx);
748
749 let llm_token = LlmApiToken::global(cx);
750
751 let (reject_tx, reject_rx) = mpsc::unbounded();
752 cx.background_spawn({
753 let client = client.clone();
754 let llm_token = llm_token.clone();
755 let app_version = AppVersion::global(cx);
756 let background_executor = cx.background_executor().clone();
757 async move {
758 Self::handle_rejected_predictions(
759 reject_rx,
760 client,
761 llm_token,
762 app_version,
763 background_executor,
764 )
765 .await
766 }
767 })
768 .detach();
769
770 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
771 cx.spawn(async move |this, cx| {
772 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
773 })
774 .detach();
775
776 let mut current_user = user_store.read(cx).watch_current_user();
777 let fetch_experiments_task = cx.spawn(async move |this, cx| {
778 while current_user.borrow().is_none() {
779 current_user.next().await;
780 }
781
782 this.update(cx, |this, cx| {
783 if cx.is_staff() {
784 this.refresh_available_experiments(cx);
785 }
786 })
787 .log_err();
788 });
789
790 let this = Self {
791 projects: HashMap::default(),
792 client,
793 user_store,
794 llm_token,
795 _fetch_experiments_task: fetch_experiments_task,
796 update_required: false,
797 edit_prediction_model: EditPredictionModel::Zeta,
798 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
799 preferred_experiment: None,
800 available_experiments: Vec::new(),
801 mercury: Mercury::new(cx),
802
803 data_collection_choice,
804 reject_predictions_tx: reject_tx,
805 settled_predictions_tx,
806 rated_predictions: Default::default(),
807 shown_predictions: Default::default(),
808 #[cfg(test)]
809 settled_event_callback: None,
810 };
811
812 this
813 }
814
815 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
816 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
817 let format = ZetaFormat::parse(&version_str).ok()?;
818 let model_id = env::var("ZED_ZETA_MODEL").ok();
819 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
820 Some(Zeta2RawConfig {
821 model_id,
822 environment,
823 format,
824 })
825 }
826
    /// Selects which prediction backend subsequent requests use.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Overrides the raw Zeta2 endpoint configuration.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// The raw Zeta2 endpoint configuration, if set explicitly or via
    /// `ZED_ZETA_*` environment variables.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// The experiment explicitly chosen by the user, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Experiments fetched from the server (see `refresh_available_experiments`).
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
850
851 pub fn active_experiment(&self) -> Option<&str> {
852 self.preferred_experiment.as_deref().or_else(|| {
853 self.shown_predictions
854 .iter()
855 .find_map(|p| p.model_version.as_ref())
856 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
857 })
858 }
859
    /// Fetches the list of edit-prediction experiments from the LLM service
    /// and stores it in `available_experiments`, notifying observers.
    /// Errors are logged, not surfaced.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            // Perform the HTTP round-trip off the main thread.
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    let token = llm_token.acquire(&client, organization_id).await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        // Body is a JSON array of experiment names.
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        // Include the response body in the error for debugging.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
907
    /// The icon set shown in the UI for the active prediction backend.
    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
        use ui::IconName;
        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
            }
            EditPredictionModel::Zeta => {
                // Zeta gets the full set of state-specific icons.
                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
                    .with_disabled(IconName::ZedPredictDisabled)
                    .with_up(IconName::ZedPredictUp)
                    .with_down(IconName::ZedPredictDown)
                    .with_error(IconName::ZedPredictError)
            }
            EditPredictionModel::Fim { .. } => {
                // FIM icon depends on the configured provider.
                let settings = &all_language_settings(None, cx).edit_predictions;
                match settings.provider {
                    EditPredictionProvider::Ollama => {
                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
                    }
                    _ => {
                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
                    }
                }
            }
        }
    }
934
    /// Whether a Mercury API token is currently configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Whether the last Mercury request failed with a payment-required error.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
942
943 pub fn clear_history(&mut self) {
944 for project_state in self.projects.values_mut() {
945 project_state.events.clear();
946 project_state.last_event.take();
947 }
948 }
949
950 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
951 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
952 project_state.events.clear();
953 project_state.last_event.take();
954 }
955 }
956
957 pub fn edit_history_for_project(
958 &self,
959 project: &Entity<Project>,
960 cx: &App,
961 ) -> Vec<StoredEvent> {
962 self.projects
963 .get(&project.entity_id())
964 .map(|project_state| project_state.events(cx))
965 .unwrap_or_default()
966 }
967
968 pub fn context_for_project<'a>(
969 &'a self,
970 project: &Entity<Project>,
971 cx: &'a mut App,
972 ) -> Vec<RelatedFile> {
973 self.projects
974 .get(&project.entity_id())
975 .map(|project_state| {
976 project_state.context.update(cx, |context, cx| {
977 context
978 .related_files_with_buffers(cx)
979 .map(|(mut related_file, buffer)| {
980 related_file.in_open_source_repo = buffer
981 .read(cx)
982 .file()
983 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
984 related_file
985 })
986 .collect()
987 })
988 })
989 .unwrap_or_default()
990 }
991
992 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
993 self.projects
994 .get(&project.entity_id())
995 .and_then(|project| project.copilot.clone())
996 }
997
    /// Starts Copilot for this project (or returns the already-running
    /// instance). Returns `None` when AI is disabled or the project has no
    /// node runtime.
    pub fn start_copilot_for_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Option<Entity<Copilot>> {
        if DisableAiSettings::get(None, cx).disable_ai {
            return None;
        }
        let state = self.get_or_init_project(project, cx);

        if state.copilot.is_some() {
            return state.copilot.clone();
        }
        // Keep an owned handle for `Copilot::new`, since `project` is about
        // to be shadowed by its read guard below.
        let _project = project.clone();
        let project = project.read(cx);

        let node = project.node_runtime().cloned();
        if let Some(node) = node {
            let next_id = project.languages().next_language_server_id();
            let fs = project.fs().clone();

            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
            state.copilot = Some(copilot.clone());
            Some(copilot)
        } else {
            None
        }
    }
1026
1027 pub fn context_for_project_with_buffers<'a>(
1028 &'a self,
1029 project: &Entity<Project>,
1030 cx: &'a mut App,
1031 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1032 self.projects
1033 .get(&project.entity_id())
1034 .map(|project| {
1035 project.context.update(cx, |context, cx| {
1036 context.related_files_with_buffers(cx).collect()
1037 })
1038 })
1039 .unwrap_or_default()
1040 }
1041
1042 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1043 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1044 self.user_store.read(cx).edit_prediction_usage()
1045 } else {
1046 None
1047 }
1048 }
1049
    /// Ensures per-project state (subscriptions, context store) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1053
    /// Starts tracking `buffer` within `project`, subscribing to its edits and
    /// lazily setting up license detection for its worktree.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1063
    /// Returns the state for `project`, creating it on first use.
    ///
    /// Initialization wires up a [`RelatedExcerptStore`] (whose refresh events
    /// are forwarded to any attached debug listener), a subscription to
    /// project events, and a release observer that drops the state when the
    /// project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1103
    /// Drops all tracked state for `project`.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
1107
1108 fn handle_excerpt_store_event(
1109 &mut self,
1110 project_entity_id: EntityId,
1111 event: &RelatedExcerptStoreEvent,
1112 ) {
1113 if let Some(project_state) = self.projects.get(&project_entity_id) {
1114 if let Some(debug_tx) = project_state.debug_tx.clone() {
1115 match event {
1116 RelatedExcerptStoreEvent::StartedRefresh => {
1117 debug_tx
1118 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1119 ContextRetrievalStartedDebugEvent {
1120 project_entity_id: project_entity_id,
1121 timestamp: Instant::now(),
1122 search_prompt: String::new(),
1123 },
1124 ))
1125 .ok();
1126 }
1127 RelatedExcerptStoreEvent::FinishedRefresh {
1128 cache_hit_count,
1129 cache_miss_count,
1130 mean_definition_latency,
1131 max_definition_latency,
1132 } => {
1133 debug_tx
1134 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1135 ContextRetrievalFinishedDebugEvent {
1136 project_entity_id: project_entity_id,
1137 timestamp: Instant::now(),
1138 metadata: vec![
1139 (
1140 "Cache Hits",
1141 format!(
1142 "{}/{}",
1143 cache_hit_count,
1144 cache_hit_count + cache_miss_count
1145 )
1146 .into(),
1147 ),
1148 (
1149 "Max LSP Time",
1150 format!("{} ms", max_definition_latency.as_millis())
1151 .into(),
1152 ),
1153 (
1154 "Mean LSP Time",
1155 format!("{} ms", mean_definition_latency.as_millis())
1156 .into(),
1157 ),
1158 ],
1159 },
1160 ))
1161 .ok();
1162 }
1163 }
1164 }
1165 }
1166 }
1167
1168 pub fn debug_info(
1169 &mut self,
1170 project: &Entity<Project>,
1171 cx: &mut Context<Self>,
1172 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1173 let project_state = self.get_or_init_project(project, cx);
1174 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1175 project_state.debug_tx = Some(debug_watch_tx);
1176 debug_watch_rx
1177 }
1178
    /// Reacts to project events: maintains the most-recently-active paths and,
    /// when diagnostics change, may request a "jump" prediction.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // Only react when the configured provider is served by this store.
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, deduplicating (most recent first).
                    // NOTE(review): no size cap is visible here — confirm
                    // recent_paths is bounded elsewhere.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                // Jump predictions from diagnostics are feature-flagged.
                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(
                        project,
                        DiagnosticSearchScope::Global,
                        cx,
                    );
                }
            }
            _ => (),
        }
    }
1218
    /// Returns the [`RegisteredBuffer`] entry for `buffer`, creating it on
    /// first use.
    ///
    /// Registration subscribes to buffer edits (feeding
    /// [`Self::report_changes_for_buffer`]) and removes the entry when the
    /// buffer entity is released. It also lazily starts a
    /// [`LicenseDetectionWatcher`] for the buffer's worktree.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Feed every edit into the history tracker.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister once the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1287
    /// Ingests a buffer change into the edit-history stream for `project`.
    ///
    /// Computes the range covering all edits since the last recorded snapshot,
    /// resets the settle timer of overlapping pending predictions, and either
    /// coalesces the change into the in-progress [`LastEvent`] or finalizes
    /// that event and starts a new one. `is_predicted` marks edits that came
    /// from accepting a prediction; `is_local` distinguishes the user's own
    /// edits from collaborators'.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse all edits since the old snapshot into one covering range.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits inside a pending prediction's editable region reset its
        // settle timer (see run_settled_predictions_worker).
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Remote edits only count when they land near local activity.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // No recordable diff in the range: finalize any in-progress event
        // (bounding the queue) without starting a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only consecutive edits to the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...from the same source (manual vs. predicted)...
            let prediction_source_changed = is_predicted != last_event.predicted;

            // ...that are close (in lines) to the previous edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a typing pause, remember the pre-pause state so the
                // finalized event can expose it.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Couldn't coalesce: finalize the previous event (bounding the queue)
        // and begin a new one for this edit.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1420
    /// Looks up the current prediction as it applies to `buffer`.
    ///
    /// Returns `Local` when the prediction edits this buffer, `Jump` when it
    /// targets another buffer but should still be surfaced here (the request
    /// originated from this buffer, or from a diagnostics update), and `None`
    /// otherwise. Also records `position` as the buffer's last cursor position.
    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // The prediction targets another buffer; decide whether to offer
            // a jump from this one.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
1460
    /// Finalizes acceptance of the current prediction for `project`: records
    /// the predicted edit into history, cancels in-flight prediction requests,
    /// and notifies the backing model.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        // Record the accepted edit as a predicted history event.
        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // can't hold &mut project_state ref across report_changes_for_buffer_call
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Acceptance supersedes any predictions still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Local providers (Ollama / OpenAI-compatible) don't receive
                // acceptance reports.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1507
    /// Background worker that batches rejected predictions and reports them to
    /// the LLM service.
    ///
    /// Rejections accumulate until either half the per-request maximum is
    /// reached or the stream stays quiet for `REJECT_REQUEST_DEBOUNCE`.
    /// Successfully flushed entries are dropped; on failure they are retained
    /// and retried with the next batch.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Keep batching while more items arrive promptly and the batch is
            // below half the request limit.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Flush only the newest `flush_count` entries; older ones wait for retry.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            // NOTE(review): the organization_id of the most recently received
            // payload is applied to the entire batch — confirm that's intended.
            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the flushed entries only on success.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1571
    /// Background worker that emits a "settled" telemetry event for each
    /// pending prediction once its editable region has stayed quiet for
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`, or silently drops it after
    /// `EDIT_PREDICTION_SETTLED_TTL`.
    ///
    /// `rx` carries the enqueue time of each new pending prediction; the loop
    /// alternates between waiting for new entries and waking at the next
    /// quiescence deadline.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // No deadline pending: block until a prediction is enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Drain any notifications that piled up meanwhile.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled region.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
                                    );

                                    return false;
                                }

                                // Still active: track the earliest last-edit
                                // time to compute the next wake-up.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1657
1658 pub(crate) fn enqueue_settled_prediction(
1659 &mut self,
1660 request_id: EditPredictionId,
1661 project: &Entity<Project>,
1662 edited_buffer: &Entity<Buffer>,
1663 edited_buffer_snapshot: &BufferSnapshot,
1664 editable_offset_range: Range<usize>,
1665 example: Option<ExampleSpec>,
1666 e2e_latency: std::time::Duration,
1667 cx: &mut Context<Self>,
1668 ) {
1669 let this = &mut *self;
1670 let project_state = this.get_or_init_project(project, cx);
1671 if let Some(buffer) = project_state
1672 .registered_buffers
1673 .get_mut(&edited_buffer.entity_id())
1674 {
1675 let now = cx.background_executor().now();
1676 buffer.pending_predictions.push(PendingSettledPrediction {
1677 request_id: request_id,
1678 editable_anchor_range: edited_buffer_snapshot
1679 .anchor_range_around(editable_offset_range),
1680 example,
1681 e2e_latency,
1682 enqueued_at: now,
1683 last_edit_at: now,
1684 });
1685 this.settled_predictions_tx.unbounded_send(now).ok();
1686 }
1687 }
1688
    /// Discards the current prediction (and any pending ones) for `project`,
    /// reporting the rejection `reason` to the backing model.
    fn reject_current_prediction(
        &mut self,
        reason: EditPredictionRejectReason,
        project: &Entity<Project>,
        cx: &App,
    ) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            // NOTE(review): pending predictions are cleared here without being
            // individually cancelled, unlike in accept_current_prediction —
            // confirm that's intended.
            project_state.pending_predictions.clear();
            if let Some(prediction) = project_state.current_prediction.take() {
                let model_version = prediction.prediction.model_version.clone();
                self.reject_prediction(
                    prediction.prediction.id,
                    reason,
                    prediction.was_shown,
                    model_version,
                    Some(prediction.e2e_latency),
                    cx,
                );
            }
        };
    }
1710
1711 fn did_show_current_prediction(
1712 &mut self,
1713 project: &Entity<Project>,
1714 display_type: edit_prediction_types::SuggestionDisplayType,
1715 _cx: &mut Context<Self>,
1716 ) {
1717 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1718 return;
1719 };
1720
1721 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1722 return;
1723 };
1724
1725 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1726 let previous_shown_with = current_prediction.shown_with;
1727
1728 if previous_shown_with.is_none() || !is_jump {
1729 current_prediction.shown_with = Some(display_type);
1730 }
1731
1732 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1733
1734 if is_first_non_jump_show {
1735 current_prediction.was_shown = true;
1736 }
1737
1738 if is_first_non_jump_show {
1739 self.shown_predictions
1740 .push_front(current_prediction.prediction.clone());
1741 if self.shown_predictions.len() > 50 {
1742 let completion = self.shown_predictions.pop_back().unwrap();
1743 self.rated_predictions.remove(&completion.id);
1744 }
1745 }
1746 }
1747
    /// Reports a rejected prediction to whichever backend produced it.
    ///
    /// For cloud Zeta the rejection is queued onto `reject_predictions_tx`
    /// and batched by the rejection worker; local providers (Ollama /
    /// OpenAI-compatible) report nothing. Mercury reports directly; FIM
    /// models have no rejection reporting.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        model_version: Option<String>,
        e2e_latency: Option<std::time::Duration>,
        cx: &App,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta => {
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );

                if is_cloud {
                    let organization_id = self
                        .user_store
                        .read(cx)
                        .current_organization()
                        .map(|organization| organization.id.clone());

                    self.reject_predictions_tx
                        .unbounded_send(EditPredictionRejectionPayload {
                            rejection: EditPredictionRejection {
                                request_id: prediction_id.to_string(),
                                reason,
                                was_shown,
                                model_version,
                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
                            },
                            organization_id,
                        })
                        .log_err();
                }
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_rejected(
                    prediction_id,
                    was_shown,
                    reason,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1797
1798 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1799 self.projects
1800 .get(&project.entity_id())
1801 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1802 }
1803
    /// Requests a fresh prediction for a cursor `position` in `buffer`,
    /// throttled per-buffer via `queue_prediction_refresh`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                // Tag the result with the requesting buffer so jump
                // affordances know where the request originated.
                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1845
    /// Requests a "jump" prediction at the next relevant diagnostic location.
    ///
    /// Does nothing when the configured provider isn't served by this store,
    /// when the user is following a collaborator, or when a prediction is
    /// already available (buffer-triggered predictions take priority).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Locate the active buffer and cursor; bail out if predictions
                // are disabled at that position.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope searches a window of lines around the cursor.
                    // Global passes the default (empty) range — presumably
                    // interpreted as "no bound" by next_diagnostic_location;
                    // TODO confirm.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    this.update(cx, |this, cx| {
                        Some((
                            // A buffer-triggered prediction may have landed
                            // while we awaited; keep it and mark this one as
                            // rejected instead.
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1962
1963 fn predictions_enabled_at(
1964 snapshot: &BufferSnapshot,
1965 position: Option<language::Anchor>,
1966 cx: &App,
1967 ) -> bool {
1968 let file = snapshot.file();
1969 let all_settings = all_language_settings(file, cx);
1970 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1971 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1972 {
1973 return false;
1974 }
1975
1976 if let Some(last_position) = position {
1977 let settings = snapshot.settings_at(last_position, cx);
1978
1979 if !settings.edit_predictions_disabled_in.is_empty()
1980 && let Some(scope) = snapshot.language_scope_at(last_position)
1981 && let Some(scope_name) = scope.override_name()
1982 && settings
1983 .edit_predictions_disabled_in
1984 .iter()
1985 .any(|s| s == scope_name)
1986 {
1987 return false;
1988 }
1989 }
1990
1991 true
1992 }
1993
    /// Minimum spacing enforced between prediction refreshes triggered by the
    /// same entity (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1995}
1996
1997fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
1998 let Some(app_state) = AppState::try_global(cx) else {
1999 return false;
2000 };
2001
2002 app_state
2003 .workspace_store
2004 .read(cx)
2005 .workspaces()
2006 .filter_map(|workspace| workspace.upgrade())
2007 .any(|workspace| {
2008 workspace.read(cx).project().entity_id() == project.entity_id()
2009 && workspace
2010 .read(cx)
2011 .leader_for_pane(workspace.read(cx).active_pane())
2012 .is_some()
2013 })
2014}
2015
2016fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2017 match provider {
2018 EditPredictionProvider::Zed
2019 | EditPredictionProvider::Mercury
2020 | EditPredictionProvider::Ollama
2021 | EditPredictionProvider::OpenAiCompatibleApi
2022 | EditPredictionProvider::Experimental(_) => true,
2023 EditPredictionProvider::None
2024 | EditPredictionProvider::Copilot
2025 | EditPredictionProvider::Codestral => false,
2026 }
2027}
2028
2029impl EditPredictionStore {
2030 fn queue_prediction_refresh(
2031 &mut self,
2032 project: Entity<Project>,
2033 request_trigger: PredictEditsRequestTrigger,
2034 throttle_entity: EntityId,
2035 cx: &mut Context<Self>,
2036 do_refresh: impl FnOnce(
2037 WeakEntity<Self>,
2038 &mut AsyncApp,
2039 )
2040 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2041 + 'static,
2042 ) {
2043 fn select_throttle(
2044 project_state: &mut ProjectState,
2045 request_trigger: PredictEditsRequestTrigger,
2046 ) -> &mut Option<(EntityId, Instant)> {
2047 match request_trigger {
2048 PredictEditsRequestTrigger::Diagnostics => {
2049 &mut project_state.last_jump_prediction_refresh
2050 }
2051 _ => &mut project_state.last_edit_prediction_refresh,
2052 }
2053 }
2054
2055 let (needs_acceptance_tracking, max_pending_predictions) =
2056 match all_language_settings(None, cx).edit_predictions.provider {
2057 EditPredictionProvider::Zed
2058 | EditPredictionProvider::Mercury
2059 | EditPredictionProvider::Experimental(_) => (true, 2),
2060 EditPredictionProvider::Ollama => (false, 1),
2061 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2062 EditPredictionProvider::None
2063 | EditPredictionProvider::Copilot
2064 | EditPredictionProvider::Codestral => {
2065 log::error!("queue_prediction_refresh called with non-store provider");
2066 return;
2067 }
2068 };
2069
2070 let drop_on_cancel = !needs_acceptance_tracking;
2071 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2072 let project_state = self.get_or_init_project(&project, cx);
2073 let pending_prediction_id = project_state.next_pending_prediction_id;
2074 project_state.next_pending_prediction_id += 1;
2075 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2076
2077 let task = cx.spawn(async move |this, cx| {
2078 let throttle_wait = this
2079 .update(cx, |this, cx| {
2080 let project_state = this.get_or_init_project(&project, cx);
2081 let throttle = *select_throttle(project_state, request_trigger);
2082
2083 let now = cx.background_executor().now();
2084 throttle.and_then(|(last_entity, last_timestamp)| {
2085 if throttle_entity != last_entity {
2086 return None;
2087 }
2088 (last_timestamp + throttle_timeout).checked_duration_since(now)
2089 })
2090 })
2091 .ok()
2092 .flatten();
2093
2094 if let Some(timeout) = throttle_wait {
2095 cx.background_executor().timer(timeout).await;
2096 }
2097
2098 // If this task was cancelled before the throttle timeout expired,
2099 // do not perform a request. Also skip if another task already
2100 // proceeded since we were enqueued (duplicate).
2101 let mut is_cancelled = true;
2102 this.update(cx, |this, cx| {
2103 let project_state = this.get_or_init_project(&project, cx);
2104 let was_cancelled = project_state
2105 .cancelled_predictions
2106 .remove(&pending_prediction_id);
2107 if was_cancelled {
2108 return;
2109 }
2110
2111 // Another request has been already sent since this was enqueued
2112 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2113 return;
2114 }
2115
2116 let new_refresh = (throttle_entity, cx.background_executor().now());
2117 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2118 is_cancelled = false;
2119 })
2120 .ok();
2121 if is_cancelled {
2122 return None;
2123 }
2124
2125 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2126 let new_prediction_id = new_prediction_result
2127 .as_ref()
2128 .map(|(prediction, _)| prediction.id.clone());
2129
2130 // When a prediction completes, remove it from the pending list, and cancel
2131 // any pending predictions that were enqueued before it.
2132 this.update(cx, |this, cx| {
2133 let project_state = this.get_or_init_project(&project, cx);
2134
2135 let is_cancelled = project_state
2136 .cancelled_predictions
2137 .remove(&pending_prediction_id);
2138
2139 let new_current_prediction = if !is_cancelled
2140 && let Some((prediction_result, requested_by)) = new_prediction_result
2141 {
2142 match prediction_result.prediction {
2143 Ok(prediction) => {
2144 let new_prediction = CurrentEditPrediction {
2145 requested_by,
2146 prediction,
2147 was_shown: false,
2148 shown_with: None,
2149 e2e_latency: prediction_result.e2e_latency,
2150 };
2151
2152 if let Some(current_prediction) =
2153 project_state.current_prediction.as_ref()
2154 {
2155 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2156 {
2157 this.reject_current_prediction(
2158 EditPredictionRejectReason::Replaced,
2159 &project,
2160 cx,
2161 );
2162
2163 Some(new_prediction)
2164 } else {
2165 this.reject_prediction(
2166 new_prediction.prediction.id,
2167 EditPredictionRejectReason::CurrentPreferred,
2168 false,
2169 new_prediction.prediction.model_version,
2170 Some(new_prediction.e2e_latency),
2171 cx,
2172 );
2173 None
2174 }
2175 } else {
2176 Some(new_prediction)
2177 }
2178 }
2179 Err(reject_reason) => {
2180 this.reject_prediction(
2181 prediction_result.id,
2182 reject_reason,
2183 false,
2184 None,
2185 Some(prediction_result.e2e_latency),
2186 cx,
2187 );
2188 None
2189 }
2190 }
2191 } else {
2192 None
2193 };
2194
2195 let project_state = this.get_or_init_project(&project, cx);
2196
2197 if let Some(new_prediction) = new_current_prediction {
2198 project_state.current_prediction = Some(new_prediction);
2199 }
2200
2201 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2202 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2203 if pending_prediction.id == pending_prediction_id {
2204 pending_predictions.remove(ix);
2205 for pending_prediction in pending_predictions.drain(0..ix) {
2206 project_state.cancel_pending_prediction(pending_prediction, cx)
2207 }
2208 break;
2209 }
2210 }
2211 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2212 cx.notify();
2213 })
2214 .ok();
2215
2216 new_prediction_id
2217 });
2218
2219 if project_state.pending_predictions.len() < max_pending_predictions {
2220 project_state
2221 .pending_predictions
2222 .push(PendingPrediction {
2223 id: pending_prediction_id,
2224 task,
2225 drop_on_cancel,
2226 })
2227 .unwrap();
2228 } else {
2229 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2230 project_state
2231 .pending_predictions
2232 .push(PendingPrediction {
2233 id: pending_prediction_id,
2234 task,
2235 drop_on_cancel,
2236 })
2237 .unwrap();
2238 project_state.cancel_pending_prediction(pending_prediction, cx);
2239 }
2240 }
2241
2242 pub fn request_prediction(
2243 &mut self,
2244 project: &Entity<Project>,
2245 active_buffer: &Entity<Buffer>,
2246 position: language::Anchor,
2247 trigger: PredictEditsRequestTrigger,
2248 cx: &mut Context<Self>,
2249 ) -> Task<Result<Option<EditPredictionResult>>> {
2250 self.request_prediction_internal(
2251 project.clone(),
2252 active_buffer.clone(),
2253 position,
2254 trigger,
2255 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2256 cx,
2257 )
2258 }
2259
2260 fn request_prediction_internal(
2261 &mut self,
2262 project: Entity<Project>,
2263 active_buffer: Entity<Buffer>,
2264 position: language::Anchor,
2265 trigger: PredictEditsRequestTrigger,
2266 allow_jump: bool,
2267 cx: &mut Context<Self>,
2268 ) -> Task<Result<Option<EditPredictionResult>>> {
2269 self.get_or_init_project(&project, cx);
2270 let project_state = self.projects.get(&project.entity_id()).unwrap();
2271 let stored_events = project_state.events(cx);
2272 let has_events = !stored_events.is_empty();
2273 let events: Vec<Arc<zeta_prompt::Event>> =
2274 stored_events.iter().map(|e| e.event.clone()).collect();
2275 let debug_tx = project_state.debug_tx.clone();
2276
2277 let snapshot = active_buffer.read(cx).snapshot();
2278 let cursor_point = position.to_point(&snapshot);
2279 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2280 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2281 let diagnostic_search_range =
2282 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2283
2284 let related_files = self.context_for_project(&project, cx);
2285
2286 let is_open_source = snapshot
2287 .file()
2288 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2289 && events.iter().all(|event| event.in_open_source_repo())
2290 && related_files.iter().all(|file| file.in_open_source_repo);
2291
2292 let can_collect_data = !cfg!(test)
2293 && is_open_source
2294 && self.is_data_collection_enabled(cx)
2295 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2296
2297 let inputs = EditPredictionModelInput {
2298 project: project.clone(),
2299 buffer: active_buffer,
2300 snapshot,
2301 position,
2302 events,
2303 related_files,
2304 trigger,
2305 diagnostic_search_range: diagnostic_search_range,
2306 debug_tx,
2307 can_collect_data,
2308 is_open_source,
2309 };
2310
2311 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2312
2313 let task = match self.edit_prediction_model {
2314 EditPredictionModel::Zeta => {
2315 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2316 }
2317 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2318 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2319 };
2320
2321 cx.spawn(async move |this, cx| {
2322 let prediction = task.await?;
2323
2324 // Only fall back to diagnostics-based prediction if we got a
2325 // the model had nothing to suggest for the buffer
2326 if prediction.is_none()
2327 && allow_jump
2328 && has_events
2329 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2330 {
2331 this.update(cx, |this, cx| {
2332 this.refresh_prediction_from_diagnostics(
2333 project,
2334 DiagnosticSearchScope::Local,
2335 cx,
2336 );
2337 })?;
2338 return anyhow::Ok(None);
2339 }
2340
2341 Ok(prediction)
2342 })
2343 }
2344
    /// Finds the next diagnostic location to jump to for a diagnostics-based
    /// prediction.
    ///
    /// First searches the active buffer for the diagnostic nearest to the
    /// cursor that lies outside the already-searched range, skipping
    /// diagnostics that are near a collaborator's cursor unless they are also
    /// near the local cursor. If none qualifies, falls back to opening other
    /// project files that have diagnostics, preferring paths sharing the
    /// longest component prefix with the active buffer's path and skipping
    /// buffers with collaborators present.
    ///
    /// Returns `Ok(None)` when no suitable diagnostic exists anywhere.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows where collaborators currently have a cursor in the active buffer.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics inside the range we already searched.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Don't jump onto a collaborator's work area unless the local
                // cursor is also nearby.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Prefer the diagnostic closest (by row) to the local cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Rank other files with diagnostics by how many leading path
            // components they share with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Pick the most severe diagnostic in the buffer.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2454
2455 async fn send_raw_llm_request(
2456 request: RawCompletionRequest,
2457 client: Arc<Client>,
2458 custom_url: Option<Arc<Url>>,
2459 llm_token: LlmApiToken,
2460 organization_id: Option<OrganizationId>,
2461 app_version: Version,
2462 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2463 let url = if let Some(custom_url) = custom_url {
2464 custom_url.as_ref().clone()
2465 } else {
2466 client
2467 .http_client()
2468 .build_zed_llm_url("/predict_edits/raw", &[])?
2469 };
2470
2471 Self::send_api_request(
2472 |builder| {
2473 let req = builder
2474 .uri(url.as_ref())
2475 .body(serde_json::to_string(&request)?.into());
2476 Ok(req?)
2477 },
2478 client,
2479 llm_token,
2480 organization_id,
2481 app_version,
2482 true,
2483 )
2484 .await
2485 }
2486
2487 pub(crate) async fn send_v3_request(
2488 input: ZetaPromptInput,
2489 client: Arc<Client>,
2490 llm_token: LlmApiToken,
2491 organization_id: Option<OrganizationId>,
2492 app_version: Version,
2493 trigger: PredictEditsRequestTrigger,
2494 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2495 let url = client
2496 .http_client()
2497 .build_zed_llm_url("/predict_edits/v3", &[])?;
2498
2499 let request = PredictEditsV3Request { input, trigger };
2500
2501 let json_bytes = serde_json::to_vec(&request)?;
2502 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2503
2504 Self::send_api_request(
2505 |builder| {
2506 let req = builder
2507 .uri(url.as_ref())
2508 .header("Content-Encoding", "zstd")
2509 .body(compressed.clone().into());
2510 Ok(req?)
2511 },
2512 client,
2513 llm_token,
2514 organization_id,
2515 app_version,
2516 true,
2517 )
2518 .await
2519 }
2520
    /// Sends a POST request built by `build`, deserializing a successful JSON
    /// response body into `Res` and extracting usage info from headers.
    ///
    /// When `require_auth` is true, failure to acquire an LLM token is an
    /// error; otherwise the request is sent without an Authorization header
    /// on a best-effort basis. On an auth failure from the server, the token
    /// is refreshed once and the request retried.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client, organization_id.clone()).await?)
        } else {
            // Best-effort: tolerate token acquisition failure and proceed
            // unauthenticated.
            llm_token
                .acquire(&client, organization_id.clone())
                .await
                .ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server may signal a minimum required client version via a
            // response header; fail with an update-required error if too old.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Stale token: refresh once, then retry the request.
                did_retry = true;
                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2594
2595 pub fn refresh_context(
2596 &mut self,
2597 project: &Entity<Project>,
2598 buffer: &Entity<language::Buffer>,
2599 cursor_position: language::Anchor,
2600 cx: &mut Context<Self>,
2601 ) {
2602 self.get_or_init_project(project, cx)
2603 .context
2604 .update(cx, |store, cx| {
2605 store.refresh(buffer.clone(), cursor_position, cx);
2606 });
2607 }
2608
2609 #[cfg(feature = "cli-support")]
2610 pub fn set_context_for_buffer(
2611 &mut self,
2612 project: &Entity<Project>,
2613 related_files: Vec<RelatedFile>,
2614 cx: &mut Context<Self>,
2615 ) {
2616 self.get_or_init_project(project, cx)
2617 .context
2618 .update(cx, |store, cx| {
2619 store.set_related_files(related_files, cx);
2620 });
2621 }
2622
2623 #[cfg(feature = "cli-support")]
2624 pub fn set_recent_paths_for_project(
2625 &mut self,
2626 project: &Entity<Project>,
2627 paths: impl IntoIterator<Item = project::ProjectPath>,
2628 cx: &mut Context<Self>,
2629 ) {
2630 let project_state = self.get_or_init_project(project, cx);
2631 project_state.recent_paths = paths.into_iter().collect();
2632 }
2633
2634 fn is_file_open_source(
2635 &self,
2636 project: &Entity<Project>,
2637 file: &Arc<dyn File>,
2638 cx: &App,
2639 ) -> bool {
2640 if !file.is_local() || file.is_private() {
2641 return false;
2642 }
2643 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2644 return false;
2645 };
2646 project_state
2647 .license_detection_watchers
2648 .get(&file.worktree_id(cx))
2649 .as_ref()
2650 .is_some_and(|watcher| watcher.is_project_open_source())
2651 }
2652
    /// Whether edit prediction data collection is enabled for the current user.
    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx)
    }
2656
2657 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2658 let choice = KeyValueStore::global(cx)
2659 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2660 .log_err()
2661 .flatten();
2662
2663 match choice.as_deref() {
2664 Some("true") => DataCollectionChoice::Enabled,
2665 Some("false") => DataCollectionChoice::Disabled,
2666 Some(_) => {
2667 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2668 DataCollectionChoice::NotAnswered
2669 }
2670 None => DataCollectionChoice::NotAnswered,
2671 }
2672 }
2673
2674 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2675 self.data_collection_choice = self.data_collection_choice.toggle();
2676 let new_choice = self.data_collection_choice;
2677 let is_enabled = new_choice.is_enabled(cx);
2678 let kvp = KeyValueStore::global(cx);
2679 db::write_and_log(cx, move || async move {
2680 kvp.write_kvp(
2681 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2682 is_enabled.to_string(),
2683 )
2684 .await
2685 });
2686 }
2687
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2691
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2695
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2699
2700 pub fn rate_prediction(
2701 &mut self,
2702 prediction: &EditPrediction,
2703 rating: EditPredictionRating,
2704 feedback: String,
2705 cx: &mut Context<Self>,
2706 ) {
2707 let organization = self.user_store.read(cx).current_organization();
2708
2709 self.rated_predictions.insert(prediction.id.clone());
2710
2711 cx.background_spawn({
2712 let client = self.client.clone();
2713 let prediction_id = prediction.id.to_string();
2714 let inputs = serde_json::to_value(&prediction.inputs);
2715 let output = prediction
2716 .edit_preview
2717 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2718 async move {
2719 client
2720 .cloud_client()
2721 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2722 organization_id: organization.map(|organization| organization.id.clone()),
2723 request_id: prediction_id,
2724 rating: match rating {
2725 EditPredictionRating::Positive => "positive".to_string(),
2726 EditPredictionRating::Negative => "negative".to_string(),
2727 },
2728 inputs: inputs?,
2729 output,
2730 feedback,
2731 })
2732 .await?;
2733
2734 anyhow::Ok(())
2735 }
2736 })
2737 .detach_and_log_err(cx);
2738
2739 cx.notify();
2740 }
2741}
2742
2743fn collaborator_edit_overlaps_locality_region(
2744 project_state: &ProjectState,
2745 project: &Entity<Project>,
2746 buffer: &Entity<Buffer>,
2747 snapshot: &BufferSnapshot,
2748 edit_range: &Range<Anchor>,
2749 cx: &App,
2750) -> bool {
2751 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2752 return false;
2753 };
2754
2755 if active_buffer.entity_id() != buffer.entity_id() {
2756 return false;
2757 }
2758
2759 let locality_point_range = expand_context_syntactically_then_linewise(
2760 snapshot,
2761 (position..position).to_point(snapshot),
2762 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2763 );
2764 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2765
2766 edit_range.overlaps(&locality_anchor_range, snapshot)
2767}
2768
/// Collapses a run of mergeable trailing buffer-change events into a single
/// event whose diff spans from the oldest merged snapshot to `end_snapshot`.
///
/// Merging only happens when the trailing events all target the same buffer as
/// `latest_snapshot`, the latest snapshot has observed all of the last event's
/// edits, and adjacent events report themselves mergeable via `can_merge`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail if the last event is for a different buffer, or if the latest
    // snapshot hasn't caught up with the last event's edits yet.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards from the newest event, counting how many consecutive
    // trailing events are pairwise mergeable.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Nothing to do unless at least two events can be merged.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union the edit ranges of all events being merged.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // constituent event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2851
2852fn merge_anchor_ranges(
2853 left: &Range<Anchor>,
2854 right: &Range<Anchor>,
2855 snapshot: &TextBufferSnapshot,
2856) -> Range<Anchor> {
2857 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2858 left.start
2859 } else {
2860 right.start
2861 };
2862 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2863 left.end
2864 } else {
2865 right.end
2866 };
2867 start..end
2868}
2869
/// Error returned when the server demands a newer Zed version (signalled via
/// the minimum-required-version response header) than the one running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2877
/// The user's stored choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet; treated as disabled for
    /// non-staff users.
    NotAnswered,
    Enabled,
    Disabled,
}
2884
2885impl DataCollectionChoice {
2886 pub fn is_enabled(self, cx: &App) -> bool {
2887 if cx.is_staff() {
2888 return true;
2889 }
2890 match self {
2891 Self::Enabled => true,
2892 Self::NotAnswered | Self::Disabled => false,
2893 }
2894 }
2895
2896 #[must_use]
2897 pub fn toggle(&self) -> DataCollectionChoice {
2898 match self {
2899 Self::Enabled => Self::Disabled,
2900 Self::Disabled => Self::Enabled,
2901 Self::NotAnswered => Self::Enabled,
2902 }
2903 }
2904}
2905
2906impl From<bool> for DataCollectionChoice {
2907 fn from(value: bool) -> Self {
2908 match value {
2909 true => DataCollectionChoice::Enabled,
2910 false => DataCollectionChoice::Disabled,
2911 }
2912 }
2913}
2914
2915struct ZedPredictUpsell;
2916
2917impl Dismissable for ZedPredictUpsell {
2918 const KEY: &'static str = "dismissed-edit-predict-upsell";
2919
2920 fn dismissed(cx: &App) -> bool {
2921 // To make this backwards compatible with older versions of Zed, we
2922 // check if the user has seen the previous Edit Prediction Onboarding
2923 // before, by checking the data collection choice which was written to
2924 // the database once the user clicked on "Accept and Enable"
2925 let kvp = KeyValueStore::global(cx);
2926 if kvp
2927 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2928 .log_err()
2929 .is_some_and(|s| s.is_some())
2930 {
2931 return true;
2932 }
2933
2934 kvp.read_kvp(Self::KEY)
2935 .log_err()
2936 .is_some_and(|s| s.is_some())
2937 }
2938}
2939
/// Whether the edit-prediction upsell modal should be shown, i.e. the user has
/// neither dismissed it nor completed the older onboarding flow.
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !ZedPredictUpsell::dismissed(cx)
}
2943
/// Registers workspace actions for edit-prediction onboarding and Copilot
/// sign-in, sign-out, and reinstall.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Opens the edit-prediction onboarding modal.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): despite the action name, this writes provider `None`
        // into settings rather than restoring a default — confirm intended.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Resolves (starting if necessary) the Copilot instance for the
        // workspace's project via the global edit prediction store.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}