1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
79use crate::example_spec::ExampleSpec;
80use crate::license_detection::LicenseDetectionWatcher;
81use crate::mercury::Mercury;
82use crate::onboarding_modal::ZedPredictModal;
83pub use crate::prediction::EditPrediction;
84pub use crate::prediction::EditPredictionId;
85use crate::prediction::EditPredictionResult;
86pub use crate::sweep_ai::SweepAi;
87pub use capture_example::capture_example;
88pub use language_model::ApiKeyState;
89pub use telemetry_events::EditPredictionRating;
90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
91
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line-distance threshold used by `StoredEvent::can_merge`: two mixed-source events are only
/// merged when they are within this many lines of each other AND farther than this many lines
/// from the latest edit.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Context budget around collaborator edits.
/// NOTE(review): the consumer is outside this view; the "tokens" unit is inferred from the name.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
/// Time window for grouping consecutive changes.
/// NOTE(review): the consumer is not visible here; presumably edits made within this window of
/// the previous one are coalesced into the same event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the user's data-collection choice is persisted.
/// NOTE(review): presumably a `KEY_VALUE_STORE` key; the read/write sites are outside this view.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for batching rejected-prediction reports.
/// NOTE(review): assumed to be used by `handle_rejected_predictions` (not visible) — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name reported when a prediction settles (presumably a telemetry event — confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five-minute lifetime for settled-prediction tracking.
/// NOTE(review): exact use is outside this view — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten-second quiet period after which a prediction is treated as settled.
/// NOTE(review): inferred from the name; the worker that reads it is not visible — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
112
/// Feature flag named `edit_prediction_jumps`.
/// NOTE(review): presumably gates predictions that require jumping elsewhere
/// (cf. `BufferEditPrediction::Jump`) — confirm at the flag's check sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// GPUI global holding the shared `EditPredictionStore` entity.
/// It is installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds this from three environment
/// variables: `ZED_ZETA_FORMAT` (required), `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT`.
///
/// NOTE(review): a previous version of this comment said "the version is also used as the
/// Baseten environment name (lowercased)", but this struct has no `version` field. The explicit
/// `environment` field appears to have replaced that behavior — confirm against the request
/// builder.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (`ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Optional environment name (`ZED_ZETA_ENVIRONMENT` when loaded from the environment).
    pub environment: Option<String>,
    /// Prompt format used when building the raw prompt.
    pub format: ZetaFormat,
}
133
/// Central state for edit predictions: per-project history and context, model selection,
/// experiment configuration, and the channels used to report rejected/settled predictions.
pub struct EditPredictionStore {
    /// Client used for authenticated HTTP requests to the LLM service.
    client: Arc<Client>,
    /// User store; read for the current organization when acquiring or refreshing the token.
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Waits for a signed-in user, then fetches the available experiments.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false`; NOTE(review): presumably set when the server requires a newer
    /// client version — the writer is outside this view.
    update_required: bool,
    /// Which prediction backend is active.
    edit_prediction_model: EditPredictionModel,
    /// Raw-endpoint configuration (initialized from `ZED_ZETA_*` environment variables).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly chosen by the user; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiments fetched from `/edit_prediction_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Data-collection choice loaded at construction time.
    data_collection_choice: DataCollectionChoice,
    /// Sends rejections to the background reporting task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sends timestamps to the settled-predictions worker spawned in `new`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Recently shown predictions; `active_experiment` reads their `model_version`.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions that have been rated (writer outside this view).
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook.
    /// NOTE(review): presumably invoked with each settled prediction — confirm.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
156
/// Message sent over `EditPredictionStore::reject_predictions_tx` to the background task that
/// reports rejected predictions.
pub(crate) struct EditPredictionRejectionPayload {
    /// The rejection being reported.
    rejection: EditPredictionRejection,
    /// Organization the rejection should be attributed to, if any.
    organization_id: Option<OrganizationId>,
}
161
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's own Zeta model (the default chosen in `EditPredictionStore::new`).
    Zeta,
    /// Fill-in-the-middle model using the given prompt format. The configured
    /// edit-prediction provider (e.g. Ollama) decides which icon is shown.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury (shown with the Inception icon).
    Mercury,
}
169
/// Everything a model backend needs to produce a prediction for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` at request time.
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is requested for.
    position: Anchor,
    /// Edit history events included in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts retrieved as context.
    related_files: Vec<RelatedFile>,
    /// Recently visited project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered the request.
    trigger: PredictEditsRequestTrigger,
    /// Row range searched for diagnostics.
    /// NOTE(review): presumably bounded by `DIAGNOSTIC_LINES_RANGE` — confirm.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for debug-view events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Whether the request may be used for data collection.
    can_collect_data: bool,
    /// Whether the buffer belongs to an open-source project.
    is_open_source: bool,
    /// Recent user actions recorded for this project.
    pub user_actions: Vec<UserActionRecord>,
}
186
/// Events sent on a project's optional `debug_tx` channel to observe the prediction pipeline.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval completes; `metadata` is a list of labeled values.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts; `prompt` is the prompt text, when available.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes; `model_output` is the raw output, when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
222
/// Maximum number of entries kept by `ProjectState::record_user_action`; older entries are evicted.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// NOTE(review): 0- vs 1-based is not established in this view — confirm at the recording site.
    pub line_number: u32,
    /// NOTE(review): presumably a byte offset into the buffer — confirm.
    pub offset: usize,
    /// Wall-clock timestamp in milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action recorded in `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (e.g. a buffer change with its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Version of the buffer right after the change.
    pub new_snapshot_version: clock::Global,
    /// Range covering all edits of this event, anchored in the post-change buffer.
    pub total_edit_range: Range<Anchor>,
}
251
252impl StoredEvent {
253 fn can_merge(
254 &self,
255 next_old_event: &StoredEvent,
256 latest_snapshot: &TextBufferSnapshot,
257 latest_edit_range: &Range<Anchor>,
258 ) -> bool {
259 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
260 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
261 return false;
262 }
263 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
264 return false;
265 }
266 if self.new_snapshot_version != next_old_event.old_snapshot.version {
267 return false;
268 }
269 if !latest_snapshot
270 .version
271 .observed_all(&next_old_event.new_snapshot_version)
272 {
273 return false;
274 }
275
276 let a_is_predicted = matches!(
277 self.event.as_ref(),
278 zeta_prompt::Event::BufferChange {
279 predicted: true,
280 ..
281 }
282 );
283 let b_is_predicted = matches!(
284 next_old_event.event.as_ref(),
285 zeta_prompt::Event::BufferChange {
286 predicted: true,
287 ..
288 }
289 );
290
291 // If events come from the same source (both predicted or both manual) then
292 // we would have coalesced them already.
293 if a_is_predicted == b_is_predicted {
294 return false;
295 }
296
297 let left_range = self.total_edit_range.to_point(latest_snapshot);
298 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
299 let latest_range = latest_edit_range.to_point(latest_snapshot);
300
301 // Events near to the latest edit are not merged if their sources differ.
302 if lines_between_ranges(&left_range, &latest_range)
303 .min(lines_between_ranges(&right_range, &latest_range))
304 <= CHANGE_GROUPING_LINE_SPAN
305 {
306 return false;
307 }
308
309 // Events that are distant from each other are not merged.
310 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
311 return false;
312 }
313
314 true
315 }
316}
317
318fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
319 if left.start > right.end {
320 return left.start.row - right.end.row;
321 }
322 if right.start > left.end {
323 return right.start.row - left.end.row;
324 }
325 0
326}
327
/// Per-project edit prediction state owned by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit-history events (returned first by `ProjectState::events`).
    events: VecDeque<StoredEvent>,
    /// In-progress event; finalized on demand (split at the last editing pause) in `events`.
    last_event: Option<LastEvent>,
    /// Recently visited project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id; `active_buffer` reads `last_position`.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id to assign to the next pending prediction.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two, per the `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel for debug-view events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the buffer and time of the last prediction refresh, used for
    /// throttling — the readers are outside this view.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    /// NOTE(review): as above, but for jump predictions — confirm.
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt context store for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; decide `in_open_source_repo` in `LastEvent::finalize`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history (`USER_ACTION_HISTORY_SIZE`) of recent user actions.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance, created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
346
347impl ProjectState {
348 fn record_user_action(&mut self, action: UserActionRecord) {
349 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
350 self.user_actions.pop_front();
351 }
352 self.user_actions.push_back(action);
353 }
354
355 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
356 self.events
357 .iter()
358 .cloned()
359 .chain(self.last_event.as_ref().iter().flat_map(|event| {
360 let (one, two) = event.split_by_pause();
361 let one = one.finalize(&self.license_detection_watchers, cx);
362 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
363 one.into_iter().chain(two)
364 }))
365 .collect()
366 }
367
368 fn cancel_pending_prediction(
369 &mut self,
370 pending_prediction: PendingPrediction,
371 cx: &mut Context<EditPredictionStore>,
372 ) {
373 self.cancelled_predictions.insert(pending_prediction.id);
374
375 if pending_prediction.drop_on_cancel {
376 drop(pending_prediction.task);
377 } else {
378 cx.spawn(async move |this, cx| {
379 let Some(prediction_id) = pending_prediction.task.await else {
380 return;
381 };
382
383 this.update(cx, |this, cx| {
384 this.reject_prediction(
385 prediction_id,
386 EditPredictionRejectReason::Canceled,
387 false,
388 None,
389 cx,
390 );
391 })
392 .ok();
393 })
394 .detach()
395 }
396 }
397
398 fn active_buffer(
399 &self,
400 project: &Entity<Project>,
401 cx: &App,
402 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
403 let project = project.read(cx);
404 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
405 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
406 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
407 Some((active_buffer, registered_buffer.last_position))
408 }
409}
410
/// The prediction currently offered in a project, plus how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown to the user.
    pub was_shown: bool,
    /// How it was displayed, if it was shown.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
418
419impl CurrentEditPrediction {
420 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
421 let Some(new_edits) = self
422 .prediction
423 .interpolate(&self.prediction.buffer.read(cx))
424 else {
425 return false;
426 };
427
428 if self.prediction.buffer != old_prediction.prediction.buffer {
429 return true;
430 }
431
432 let Some(old_edits) = old_prediction
433 .prediction
434 .interpolate(&old_prediction.prediction.buffer.read(cx))
435 else {
436 return true;
437 };
438
439 let requested_by_buffer_id = self.requested_by.buffer_id();
440
441 // This reduces the occurrence of UI thrash from replacing edits
442 //
443 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
444 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
445 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
446 && old_edits.len() == 1
447 && new_edits.len() == 1
448 {
449 let (old_range, old_text) = &old_edits[0];
450 let (new_range, new_text) = &new_edits[0];
451 new_range == old_range && new_text.starts_with(old_text.as_ref())
452 } else {
453 true
454 }
455 }
456}
457
458#[derive(Debug, Clone)]
459enum PredictionRequestedBy {
460 DiagnosticsUpdate,
461 Buffer(EntityId),
462}
463
464impl PredictionRequestedBy {
465 pub fn buffer_id(&self) -> Option<EntityId> {
466 match self {
467 PredictionRequestedBy::DiagnosticsUpdate => None,
468 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
469 }
470 }
471}
472
/// Number of lines used when searching for nearby diagnostics.
/// NOTE(review): the consumer is outside this view — confirm whether this is a radius around the
/// cursor or a total window size.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Where to look for diagnostics.
/// NOTE(review): presumably `Local` = near the cursor in the current buffer and
/// `Global` = anywhere in the project — confirm at the use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
480
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Id taken from `ProjectState::next_pending_prediction_id`; recorded in
    /// `cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
489
490/// A prediction from the perspective of a buffer.
491#[derive(Debug)]
492enum BufferEditPrediction<'a> {
493 Local { prediction: &'a EditPrediction },
494 Jump { prediction: &'a EditPrediction },
495}
496
497#[cfg(test)]
498impl std::ops::Deref for BufferEditPrediction<'_> {
499 type Target = EditPrediction;
500
501 fn deref(&self) -> &Self::Target {
502 match self {
503 BufferEditPrediction::Local { prediction } => prediction,
504 BufferEditPrediction::Jump { prediction } => prediction,
505 }
506 }
507}
508
/// A shown prediction whose editable region is being watched until it "settles".
/// NOTE(review): the settling logic is outside this view; the field notes below are inferred
/// from their names and types — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    /// Id of the request that produced the prediction.
    request_id: EditPredictionId,
    /// Region of the buffer the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    /// Captured example, when one was recorded.
    example: Option<ExampleSpec>,
    /// When tracking started (presumably compared against `EDIT_PREDICTION_SETTLED_TTL`).
    enqueued_at: Instant,
    /// Time of the last edit inside the region (presumably compared against
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`).
    last_edit_at: Instant,
}
517
/// Per-buffer bookkeeping for a buffer registered with a project's prediction state.
struct RegisteredBuffer {
    /// The buffer's file, if it has one.
    file: Option<Arc<dyn File>>,
    /// Most recently observed snapshot of the buffer.
    snapshot: TextBufferSnapshot,
    /// Shown predictions in this buffer awaiting settlement.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
525
/// The in-progress (not yet finalized) edit event of a project.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range of the most recent edit.
    latest_edit_range: Range<Anchor>,
    /// Range covering every edit in this event; used to compute the diff in `finalize`.
    total_edit_range: Range<Anchor>,
    /// `total_edit_range` as it was at the last editing pause; `split_by_pause` uses it for
    /// the "before" half.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from an accepted prediction rather than manual typing.
    predicted: bool,
    /// Snapshot taken at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
539
540impl LastEvent {
541 pub fn finalize(
542 &self,
543 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
544 cx: &App,
545 ) -> Option<StoredEvent> {
546 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
547 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
548
549 let in_open_source_repo =
550 [self.new_file.as_ref(), self.old_file.as_ref()]
551 .iter()
552 .all(|file| {
553 file.is_some_and(|file| {
554 license_detection_watchers
555 .get(&file.worktree_id(cx))
556 .is_some_and(|watcher| watcher.is_project_open_source())
557 })
558 });
559
560 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
561 &self.old_snapshot,
562 &self.new_snapshot,
563 &self.total_edit_range,
564 )?;
565
566 if path == old_path && diff.is_empty() {
567 None
568 } else {
569 Some(StoredEvent {
570 event: Arc::new(zeta_prompt::Event::BufferChange {
571 old_path,
572 path,
573 diff,
574 in_open_source_repo,
575 predicted: self.predicted,
576 }),
577 old_snapshot: self.old_snapshot.clone(),
578 new_snapshot_version: self.new_snapshot.version.clone(),
579 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
580 ..self.new_snapshot.anchor_before(edit_range.end),
581 })
582 }
583 }
584
585 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
586 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
587 return (self.clone(), None);
588 };
589
590 let total_edit_range_before_pause = self
591 .total_edit_range_at_last_pause_boundary
592 .clone()
593 .unwrap_or_else(|| self.total_edit_range.clone());
594
595 let Some(total_edit_range_after_pause) =
596 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
597 else {
598 return (self.clone(), None);
599 };
600
601 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
602 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
603
604 let before = LastEvent {
605 old_snapshot: self.old_snapshot.clone(),
606 new_snapshot: boundary_snapshot.clone(),
607 old_file: self.old_file.clone(),
608 new_file: self.new_file.clone(),
609 latest_edit_range: latest_edit_range_before_pause,
610 total_edit_range: total_edit_range_before_pause,
611 total_edit_range_at_last_pause_boundary: None,
612 predicted: self.predicted,
613 snapshot_after_last_editing_pause: None,
614 last_edit_time: self.last_edit_time,
615 };
616
617 let after = LastEvent {
618 old_snapshot: boundary_snapshot.clone(),
619 new_snapshot: self.new_snapshot.clone(),
620 old_file: self.old_file.clone(),
621 new_file: self.new_file.clone(),
622 latest_edit_range: latest_edit_range_after_pause,
623 total_edit_range: total_edit_range_after_pause,
624 total_edit_range_at_last_pause_boundary: None,
625 predicted: self.predicted,
626 snapshot_after_last_editing_pause: None,
627 last_edit_time: self.last_edit_time,
628 };
629
630 (before, Some(after))
631 }
632}
633
634fn compute_total_edit_range_between_snapshots(
635 old_snapshot: &TextBufferSnapshot,
636 new_snapshot: &TextBufferSnapshot,
637) -> Option<Range<Anchor>> {
638 let edits: Vec<Edit<usize>> = new_snapshot
639 .edits_since::<usize>(&old_snapshot.version)
640 .collect();
641
642 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
643 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
644 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
645
646 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
647}
648
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to the corresponding `Point` range
/// in `old_snapshot`.
///
/// Walks the edits between the two snapshots in order while tracking `delta`, the cumulative
/// length change (`new.len() - old.len()`) of the edits already passed:
/// - an endpoint lying before an edit's new range is shifted back by `delta`;
/// - an endpoint lying within an edit's new range (both ends inclusive) snaps to that edit's
///   old start (for the range start) or old end (for the range end);
/// - an endpoint past every edit is shifted back by the total `delta`, saturating at 0.
///
/// Returns `None` if shifting an endpoint that precedes an edit would underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net length change of all edits processed so far (new minus old).
    let mut delta: isize = 0;

    for edit in &edits {
        // The first edit whose new range ends at or after the start decides the start mapping.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Unedited text before this edit: undo the shifts of earlier edits.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Inside the edit: snap to where the edit began in the old text.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                // Inside the edit: snap to where the edit ended in the old text.
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints after every edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
694
695fn compute_diff_between_snapshots_in_range(
696 old_snapshot: &TextBufferSnapshot,
697 new_snapshot: &TextBufferSnapshot,
698 total_edit_range: &Range<Anchor>,
699) -> Option<(String, Range<Point>)> {
700 let new_start_point = total_edit_range.start.to_point(new_snapshot);
701 let new_end_point = total_edit_range.end.to_point(new_snapshot);
702 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
703 let old_start_point = old_range.start;
704 let old_end_point = old_range.end;
705
706 const CONTEXT_LINES: u32 = 3;
707
708 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
709 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
710 let old_context_end_row =
711 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
712 let new_context_end_row =
713 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
714
715 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
716 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
717 let old_end_line_offset = old_snapshot
718 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
719 let new_end_line_offset = new_snapshot
720 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
721 let old_edit_range = old_start_line_offset..old_end_line_offset;
722 let new_edit_range = new_start_line_offset..new_end_line_offset;
723
724 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
725 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
726
727 let diff = language::unified_diff_with_offsets(
728 &old_region_text,
729 &new_region_text,
730 old_context_start_row,
731 new_context_start_row,
732 );
733
734 Some((diff, new_start_point..new_end_point))
735}
736
737fn buffer_path_with_id_fallback(
738 file: Option<&Arc<dyn File>>,
739 snapshot: &TextBufferSnapshot,
740 cx: &App,
741) -> Arc<Path> {
742 if let Some(file) = file {
743 file.full_path(cx).into()
744 } else {
745 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
746 }
747}
748
749impl EditPredictionStore {
750 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
751 cx.try_global::<EditPredictionStoreGlobal>()
752 .map(|global| global.0.clone())
753 }
754
755 pub fn global(
756 client: &Arc<Client>,
757 user_store: &Entity<UserStore>,
758 cx: &mut App,
759 ) -> Entity<Self> {
760 cx.try_global::<EditPredictionStoreGlobal>()
761 .map(|global| global.0.clone())
762 .unwrap_or_else(|| {
763 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
764 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
765 ep_store
766 })
767 }
768
769 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
770 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
771 let data_collection_choice = Self::load_data_collection_choice();
772
773 let llm_token = LlmApiToken::default();
774
775 let (reject_tx, reject_rx) = mpsc::unbounded();
776 cx.background_spawn({
777 let client = client.clone();
778 let llm_token = llm_token.clone();
779 let app_version = AppVersion::global(cx);
780 let background_executor = cx.background_executor().clone();
781 async move {
782 Self::handle_rejected_predictions(
783 reject_rx,
784 client,
785 llm_token,
786 app_version,
787 background_executor,
788 )
789 .await
790 }
791 })
792 .detach();
793
794 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
795 cx.spawn(async move |this, cx| {
796 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
797 })
798 .detach();
799
800 let mut current_user = user_store.read(cx).watch_current_user();
801 let fetch_experiments_task = cx.spawn(async move |this, cx| {
802 while current_user.borrow().is_none() {
803 current_user.next().await;
804 }
805 this.update(cx, |this, cx| {
806 this.refresh_available_experiments(cx);
807 })
808 .log_err();
809 });
810
811 let this = Self {
812 projects: HashMap::default(),
813 client,
814 user_store,
815 llm_token,
816 _fetch_experiments_task: fetch_experiments_task,
817 _llm_token_subscription: cx.subscribe(
818 &refresh_llm_token_listener,
819 |this, _listener, _event, cx| {
820 let client = this.client.clone();
821 let llm_token = this.llm_token.clone();
822 let organization_id = this
823 .user_store
824 .read(cx)
825 .current_organization()
826 .map(|organization| organization.id.clone());
827 cx.spawn(async move |_this, _cx| {
828 llm_token.refresh(&client, organization_id).await?;
829 anyhow::Ok(())
830 })
831 .detach_and_log_err(cx);
832 },
833 ),
834 update_required: false,
835 edit_prediction_model: EditPredictionModel::Zeta,
836 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
837 preferred_experiment: None,
838 available_experiments: Vec::new(),
839 sweep_ai: SweepAi::new(cx),
840 mercury: Mercury::new(cx),
841
842 data_collection_choice,
843 reject_predictions_tx: reject_tx,
844 settled_predictions_tx,
845 rated_predictions: Default::default(),
846 shown_predictions: Default::default(),
847 #[cfg(test)]
848 settled_event_callback: None,
849 };
850
851 this
852 }
853
854 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
855 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
856 let format = ZetaFormat::parse(&version_str).ok()?;
857 let model_id = env::var("ZED_ZETA_MODEL").ok();
858 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
859 Some(Zeta2RawConfig {
860 model_id,
861 environment,
862 format,
863 })
864 }
865
866 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
867 self.edit_prediction_model = model;
868 }
869
870 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
871 self.zeta2_raw_config = Some(config);
872 }
873
874 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
875 self.zeta2_raw_config.as_ref()
876 }
877
878 pub fn preferred_experiment(&self) -> Option<&str> {
879 self.preferred_experiment.as_deref()
880 }
881
882 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
883 self.preferred_experiment = experiment;
884 }
885
886 pub fn available_experiments(&self) -> &[String] {
887 &self.available_experiments
888 }
889
890 pub fn active_experiment(&self) -> Option<&str> {
891 self.preferred_experiment.as_deref().or_else(|| {
892 self.shown_predictions
893 .iter()
894 .find_map(|p| p.model_version.as_ref())
895 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
896 })
897 }
898
899 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
900 let client = self.client.clone();
901 let llm_token = self.llm_token.clone();
902 let app_version = AppVersion::global(cx);
903 let organization_id = self
904 .user_store
905 .read(cx)
906 .current_organization()
907 .map(|organization| organization.id.clone());
908
909 cx.spawn(async move |this, cx| {
910 let experiments = cx
911 .background_spawn(async move {
912 let http_client = client.http_client();
913 let token = llm_token.acquire(&client, organization_id).await?;
914 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
915 let request = http_client::Request::builder()
916 .method(Method::GET)
917 .uri(url.as_ref())
918 .header("Authorization", format!("Bearer {}", token))
919 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
920 .body(Default::default())?;
921 let mut response = http_client.send(request).await?;
922 if response.status().is_success() {
923 let mut body = Vec::new();
924 response.body_mut().read_to_end(&mut body).await?;
925 let experiments: Vec<String> = serde_json::from_slice(&body)?;
926 Ok(experiments)
927 } else {
928 let mut body = String::new();
929 response.body_mut().read_to_string(&mut body).await?;
930 anyhow::bail!(
931 "Failed to fetch experiments: {:?}\nBody: {}",
932 response.status(),
933 body
934 );
935 }
936 })
937 .await?;
938 this.update(cx, |this, cx| {
939 this.available_experiments = experiments;
940 cx.notify();
941 })?;
942 anyhow::Ok(())
943 })
944 .detach_and_log_err(cx);
945 }
946
947 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
948 use ui::IconName;
949 match self.edit_prediction_model {
950 EditPredictionModel::Sweep => {
951 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
952 .with_disabled(IconName::SweepAiDisabled)
953 .with_up(IconName::SweepAiUp)
954 .with_down(IconName::SweepAiDown)
955 .with_error(IconName::SweepAiError)
956 }
957 EditPredictionModel::Mercury => {
958 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
959 }
960 EditPredictionModel::Zeta => {
961 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
962 .with_disabled(IconName::ZedPredictDisabled)
963 .with_up(IconName::ZedPredictUp)
964 .with_down(IconName::ZedPredictDown)
965 .with_error(IconName::ZedPredictError)
966 }
967 EditPredictionModel::Fim { .. } => {
968 let settings = &all_language_settings(None, cx).edit_predictions;
969 match settings.provider {
970 EditPredictionProvider::Ollama => {
971 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
972 }
973 _ => {
974 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
975 }
976 }
977 }
978 }
979 }
980
981 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
982 self.sweep_ai.api_token.read(cx).has_key()
983 }
984
985 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
986 self.mercury.api_token.read(cx).has_key()
987 }
988
989 pub fn clear_history(&mut self) {
990 for project_state in self.projects.values_mut() {
991 project_state.events.clear();
992 project_state.last_event.take();
993 }
994 }
995
996 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
997 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
998 project_state.events.clear();
999 project_state.last_event.take();
1000 }
1001 }
1002
1003 pub fn edit_history_for_project(
1004 &self,
1005 project: &Entity<Project>,
1006 cx: &App,
1007 ) -> Vec<StoredEvent> {
1008 self.projects
1009 .get(&project.entity_id())
1010 .map(|project_state| project_state.events(cx))
1011 .unwrap_or_default()
1012 }
1013
1014 pub fn context_for_project<'a>(
1015 &'a self,
1016 project: &Entity<Project>,
1017 cx: &'a mut App,
1018 ) -> Vec<RelatedFile> {
1019 self.projects
1020 .get(&project.entity_id())
1021 .map(|project_state| {
1022 project_state.context.update(cx, |context, cx| {
1023 context
1024 .related_files_with_buffers(cx)
1025 .map(|(mut related_file, buffer)| {
1026 related_file.in_open_source_repo = buffer
1027 .read(cx)
1028 .file()
1029 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1030 related_file
1031 })
1032 .collect()
1033 })
1034 })
1035 .unwrap_or_default()
1036 }
1037
1038 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1039 self.projects
1040 .get(&project.entity_id())
1041 .and_then(|project| project.copilot.clone())
1042 }
1043
1044 pub fn start_copilot_for_project(
1045 &mut self,
1046 project: &Entity<Project>,
1047 cx: &mut Context<Self>,
1048 ) -> Option<Entity<Copilot>> {
1049 if DisableAiSettings::get(None, cx).disable_ai {
1050 return None;
1051 }
1052 let state = self.get_or_init_project(project, cx);
1053
1054 if state.copilot.is_some() {
1055 return state.copilot.clone();
1056 }
1057 let _project = project.clone();
1058 let project = project.read(cx);
1059
1060 let node = project.node_runtime().cloned();
1061 if let Some(node) = node {
1062 let next_id = project.languages().next_language_server_id();
1063 let fs = project.fs().clone();
1064
1065 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1066 state.copilot = Some(copilot.clone());
1067 Some(copilot)
1068 } else {
1069 None
1070 }
1071 }
1072
1073 pub fn context_for_project_with_buffers<'a>(
1074 &'a self,
1075 project: &Entity<Project>,
1076 cx: &'a mut App,
1077 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1078 self.projects
1079 .get(&project.entity_id())
1080 .map(|project| {
1081 project.context.update(cx, |context, cx| {
1082 context.related_files_with_buffers(cx).collect()
1083 })
1084 })
1085 .unwrap_or_default()
1086 }
1087
1088 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1089 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1090 self.user_store.read(cx).edit_prediction_usage()
1091 } else {
1092 None
1093 }
1094 }
1095
1096 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1097 self.get_or_init_project(project, cx);
1098 }
1099
1100 pub fn register_buffer(
1101 &mut self,
1102 buffer: &Entity<Buffer>,
1103 project: &Entity<Project>,
1104 cx: &mut Context<Self>,
1105 ) {
1106 let project_state = self.get_or_init_project(project, cx);
1107 Self::register_buffer_impl(project_state, buffer, project, cx);
1108 }
1109
    /// Returns the state tracked for `project`, creating it on first use.
    ///
    /// Creating the state also creates the project's [`RelatedExcerptStore`] and
    /// subscribes to it, to `project`'s events, and to `project`'s release. The release
    /// observer removes the state again.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Route the excerpt store's refresh events, tagged with this project's id.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1150
1151 pub fn remove_project(&mut self, project: &Entity<Project>) {
1152 self.projects.remove(&project.entity_id());
1153 }
1154
1155 fn handle_excerpt_store_event(
1156 &mut self,
1157 project_entity_id: EntityId,
1158 event: &RelatedExcerptStoreEvent,
1159 ) {
1160 if let Some(project_state) = self.projects.get(&project_entity_id) {
1161 if let Some(debug_tx) = project_state.debug_tx.clone() {
1162 match event {
1163 RelatedExcerptStoreEvent::StartedRefresh => {
1164 debug_tx
1165 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1166 ContextRetrievalStartedDebugEvent {
1167 project_entity_id: project_entity_id,
1168 timestamp: Instant::now(),
1169 search_prompt: String::new(),
1170 },
1171 ))
1172 .ok();
1173 }
1174 RelatedExcerptStoreEvent::FinishedRefresh {
1175 cache_hit_count,
1176 cache_miss_count,
1177 mean_definition_latency,
1178 max_definition_latency,
1179 } => {
1180 debug_tx
1181 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1182 ContextRetrievalFinishedDebugEvent {
1183 project_entity_id: project_entity_id,
1184 timestamp: Instant::now(),
1185 metadata: vec![
1186 (
1187 "Cache Hits",
1188 format!(
1189 "{}/{}",
1190 cache_hit_count,
1191 cache_hit_count + cache_miss_count
1192 )
1193 .into(),
1194 ),
1195 (
1196 "Max LSP Time",
1197 format!("{} ms", max_definition_latency.as_millis())
1198 .into(),
1199 ),
1200 (
1201 "Mean LSP Time",
1202 format!("{} ms", mean_definition_latency.as_millis())
1203 .into(),
1204 ),
1205 ],
1206 },
1207 ))
1208 .ok();
1209 }
1210 }
1211 }
1212 }
1213 }
1214
1215 pub fn debug_info(
1216 &mut self,
1217 project: &Entity<Project>,
1218 cx: &mut Context<Self>,
1219 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1220 let project_state = self.get_or_init_project(project, cx);
1221 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1222 project_state.debug_tx = Some(debug_watch_tx);
1223 debug_watch_rx
1224 }
1225
1226 fn handle_project_event(
1227 &mut self,
1228 project: Entity<Project>,
1229 event: &project::Event,
1230 cx: &mut Context<Self>,
1231 ) {
1232 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1233 return;
1234 }
1235 // TODO [zeta2] init with recent paths
1236 match event {
1237 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1238 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1239 return;
1240 };
1241 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1242 if let Some(path) = path {
1243 if let Some(ix) = project_state
1244 .recent_paths
1245 .iter()
1246 .position(|probe| probe == &path)
1247 {
1248 project_state.recent_paths.remove(ix);
1249 }
1250 project_state.recent_paths.push_front(path);
1251 }
1252 }
1253 project::Event::DiagnosticsUpdated { .. } => {
1254 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1255 self.refresh_prediction_from_diagnostics(
1256 project,
1257 DiagnosticSearchScope::Global,
1258 cx,
1259 );
1260 }
1261 }
1262 _ => (),
1263 }
1264 }
1265
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer's file lives in one of `project`'s worktrees, a
    /// `LicenseDetectionWatcher` is created for that worktree (once per worktree). The
    /// watcher entry is removed again when the worktree is released.
    ///
    /// A newly registered buffer starts from the buffer's current text snapshot and
    /// file. It subscribes to the buffer's edits, which are forwarded to
    /// `report_changes_for_buffer`, and to the buffer's release, which unregisters it.
    /// An already registered buffer is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Edits (local or remote) are reported while the project is alive;
                        // the project is held weakly.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1334
    /// Folds the edits made to `buffer` since it was last observed into the
    /// project's edit history.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction
    /// (`accept_current_prediction` passes `true`; the buffer edit subscription passes
    /// `false`). `is_local` is forwarded from `BufferEvent::Edited`. Non-local edits
    /// only enter the history when `collaborator_edit_overlaps_locality_region`
    /// returns true.
    ///
    /// Besides the history, this refreshes `last_edit_at` on pending settled
    /// predictions whose editable range overlaps the edits. For local edits it also
    /// records a `UserActionRecord`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // No new edits since the snapshot recorded last time.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the old snapshot: counts and sizes (for the user
        // action classification below), one anchor range spanning all edits, and the
        // end offset of the last edit.
        // NOTE(review): spanning via `acc.start..anchor_range.end` assumes edits are
        // yielded in buffer order — confirm.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending settled prediction's editable range restart the
        // quiet period the settled-predictions worker waits for.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // Heuristic classification: when the total inserted (or deleted) size
            // equals the number of edits, i.e. one unit per edit (such as typing with
            // several cursors), it counts as a single-character action; larger
            // changes count as selection-sized actions.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Wall-clock time, unlike the executor time (`now`) used for grouping.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let events = &mut project_state.events;

        // Extend the in-progress event instead of starting a new one when these edits
        // directly follow it in the same buffer, have the same predicted/manual source,
        // and are within CHANGE_GROUPING_LINE_SPAN lines of its latest edit.
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // state before this burst of edits as a pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Finalize the previous event into the bounded history before starting a new one.
        // NOTE(review): the `+ 1` keeps `events` at most EVENT_COUNT_MAX - 1 entries,
        // presumably leaving room for the still-open `last_event` — confirm.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1485
1486 fn prediction_at(
1487 &mut self,
1488 buffer: &Entity<Buffer>,
1489 position: Option<language::Anchor>,
1490 project: &Entity<Project>,
1491 cx: &App,
1492 ) -> Option<BufferEditPrediction<'_>> {
1493 let project_state = self.projects.get_mut(&project.entity_id())?;
1494 if let Some(position) = position
1495 && let Some(buffer) = project_state
1496 .registered_buffers
1497 .get_mut(&buffer.entity_id())
1498 {
1499 buffer.last_position = Some(position);
1500 }
1501
1502 let CurrentEditPrediction {
1503 requested_by,
1504 prediction,
1505 ..
1506 } = project_state.current_prediction.as_ref()?;
1507
1508 if prediction.targets_buffer(buffer.read(cx)) {
1509 Some(BufferEditPrediction::Local { prediction })
1510 } else {
1511 let show_jump = match requested_by {
1512 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1513 requested_by_buffer_id == &buffer.entity_id()
1514 }
1515 PredictionRequestedBy::DiagnosticsUpdate => true,
1516 };
1517
1518 if show_jump {
1519 Some(BufferEditPrediction::Jump { prediction })
1520 } else {
1521 None
1522 }
1523 }
1524 }
1525
1526 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1527 let Some(current_prediction) = self
1528 .projects
1529 .get_mut(&project.entity_id())
1530 .and_then(|project_state| project_state.current_prediction.take())
1531 else {
1532 return;
1533 };
1534
1535 self.report_changes_for_buffer(
1536 ¤t_prediction.prediction.buffer,
1537 project,
1538 true,
1539 true,
1540 cx,
1541 );
1542
1543 // can't hold &mut project_state ref across report_changes_for_buffer_call
1544 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1545 return;
1546 };
1547
1548 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1549 project_state.cancel_pending_prediction(pending_prediction, cx);
1550 }
1551
1552 match self.edit_prediction_model {
1553 EditPredictionModel::Sweep => {
1554 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1555 }
1556 EditPredictionModel::Mercury => {
1557 mercury::edit_prediction_accepted(
1558 current_prediction.prediction.id,
1559 self.client.http_client(),
1560 cx,
1561 );
1562 }
1563 EditPredictionModel::Zeta => {
1564 let is_cloud = !matches!(
1565 all_language_settings(None, cx).edit_predictions.provider,
1566 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1567 );
1568 if is_cloud {
1569 zeta::edit_prediction_accepted(self, current_prediction, cx)
1570 }
1571 }
1572 EditPredictionModel::Fim { .. } => {}
1573 }
1574 }
1575
    /// Background loop that uploads rejected predictions to the
    /// `/predict_edits/reject` endpoint in batches.
    ///
    /// While fewer than `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2` rejections
    /// are batched, the loop waits up to `REJECT_REQUEST_DEBOUNCE` for another one. A
    /// rejection that arrives in time is added to the batch and the wait starts over;
    /// otherwise, or when the channel closes, the batch is sent.
    ///
    /// Each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of
    /// the most recently batched rejections. They are removed from the batch only when
    /// the request succeeds. Unsent or failed rejections stay in `batched` until a
    /// later rejection triggers another request. The loop ends when `rx` is closed.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        // NOTE(review): every request is sent with the `organization_id` of the most
        // recently received payload, including older rejections still in `batched`.
        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: if another rejection is already queued (or arrives before
            // the debounce timer), go collect it first. A closed channel (peek yields
            // `None`) falls through to flush what we have.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop what was actually delivered. NOTE(review): if requests keep
            // failing, `batched` grows without bound.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1639
    /// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry once a
    /// pending settled prediction's editable region has stopped changing.
    ///
    /// `rx` carries the enqueue time of each newly pending prediction. The loop sleeps
    /// until a prediction could first have been quiet for
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`, then sweeps every registered buffer:
    ///
    /// - Predictions older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without an
    ///   event.
    /// - Quiet predictions report the current text of their editable range and are
    ///   dropped.
    /// - The remaining predictions schedule the next wake-up.
    ///
    /// The loop exits when `rx` is closed (checked only while nothing is scheduled) or
    /// when the store entity can no longer be upgraded.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: wait for the next enqueue, schedule a sweep after the quiet
                // period, and discard other notifications already buffered — the sweep
                // visits every pending prediction anyway.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that are still waiting.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    // Text of the editable range in the buffer's last
                                    // recorded snapshot.
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    // Test-only hook observing settled events.
                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Wake when the longest-waiting remaining prediction could become quiet;
            // `None` returns the loop to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1724
1725 pub(crate) fn enqueue_settled_prediction(
1726 &mut self,
1727 request_id: EditPredictionId,
1728 project: &Entity<Project>,
1729 edited_buffer: &Entity<Buffer>,
1730 edited_buffer_snapshot: &BufferSnapshot,
1731 editable_offset_range: Range<usize>,
1732 example: Option<ExampleSpec>,
1733 cx: &mut Context<Self>,
1734 ) {
1735 let this = &mut *self;
1736 let project_state = this.get_or_init_project(project, cx);
1737 if let Some(buffer) = project_state
1738 .registered_buffers
1739 .get_mut(&edited_buffer.entity_id())
1740 {
1741 let now = cx.background_executor().now();
1742 buffer.pending_predictions.push(PendingSettledPrediction {
1743 request_id: request_id,
1744 editable_anchor_range: edited_buffer_snapshot
1745 .anchor_range_around(editable_offset_range),
1746 example,
1747 enqueued_at: now,
1748 last_edit_at: now,
1749 });
1750 this.settled_predictions_tx.unbounded_send(now).ok();
1751 }
1752 }
1753
1754 fn reject_current_prediction(
1755 &mut self,
1756 reason: EditPredictionRejectReason,
1757 project: &Entity<Project>,
1758 cx: &App,
1759 ) {
1760 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1761 project_state.pending_predictions.clear();
1762 if let Some(prediction) = project_state.current_prediction.take() {
1763 let model_version = prediction.prediction.model_version.clone();
1764 self.reject_prediction(
1765 prediction.prediction.id,
1766 reason,
1767 prediction.was_shown,
1768 model_version,
1769 cx,
1770 );
1771 }
1772 };
1773 }
1774
1775 fn did_show_current_prediction(
1776 &mut self,
1777 project: &Entity<Project>,
1778 display_type: edit_prediction_types::SuggestionDisplayType,
1779 cx: &mut Context<Self>,
1780 ) {
1781 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1782 return;
1783 };
1784
1785 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1786 return;
1787 };
1788
1789 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1790 let previous_shown_with = current_prediction.shown_with;
1791
1792 if previous_shown_with.is_none() || !is_jump {
1793 current_prediction.shown_with = Some(display_type);
1794 }
1795
1796 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1797
1798 if is_first_non_jump_show {
1799 current_prediction.was_shown = true;
1800 }
1801
1802 let display_type_changed = previous_shown_with != Some(display_type);
1803
1804 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1805 sweep_ai::edit_prediction_shown(
1806 &self.sweep_ai,
1807 self.client.clone(),
1808 ¤t_prediction.prediction,
1809 display_type,
1810 cx,
1811 );
1812 }
1813
1814 if is_first_non_jump_show {
1815 self.shown_predictions
1816 .push_front(current_prediction.prediction.clone());
1817 if self.shown_predictions.len() > 50 {
1818 let completion = self.shown_predictions.pop_back().unwrap();
1819 self.rated_predictions.remove(&completion.id);
1820 }
1821 }
1822 }
1823
1824 fn reject_prediction(
1825 &mut self,
1826 prediction_id: EditPredictionId,
1827 reason: EditPredictionRejectReason,
1828 was_shown: bool,
1829 model_version: Option<String>,
1830 cx: &App,
1831 ) {
1832 match self.edit_prediction_model {
1833 EditPredictionModel::Zeta => {
1834 let is_cloud = !matches!(
1835 all_language_settings(None, cx).edit_predictions.provider,
1836 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1837 );
1838
1839 if is_cloud {
1840 let organization_id = self
1841 .user_store
1842 .read(cx)
1843 .current_organization()
1844 .map(|organization| organization.id.clone());
1845
1846 self.reject_predictions_tx
1847 .unbounded_send(EditPredictionRejectionPayload {
1848 rejection: EditPredictionRejection {
1849 request_id: prediction_id.to_string(),
1850 reason,
1851 was_shown,
1852 model_version,
1853 },
1854 organization_id,
1855 })
1856 .log_err();
1857 }
1858 }
1859 EditPredictionModel::Mercury => {
1860 mercury::edit_prediction_rejected(
1861 prediction_id,
1862 was_shown,
1863 reason,
1864 self.client.http_client(),
1865 cx,
1866 );
1867 }
1868 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1869 }
1870 }
1871
1872 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1873 self.projects
1874 .get(&project.entity_id())
1875 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1876 }
1877
    /// Queues a prediction request at `position` in `buffer`.
    ///
    /// The request goes through `queue_prediction_refresh`, throttled per buffer
    /// (`buffer.entity_id()` is the throttle key). A resulting prediction is
    /// attributed to `buffer` via `PredictionRequestedBy::Buffer`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                // If updating the store fails (logged), no request is made.
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                // Tag a successful result with the buffer that asked for it.
                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1919
    /// Queues a prediction request at the next diagnostic location relative to the
    /// cursor of the project's active buffer.
    ///
    /// Nothing is queued when:
    /// - the configured provider is not store-backed,
    /// - the project is not registered, or
    /// - a current prediction already exists.
    ///
    /// `scope` picks the search range passed to `next_diagnostic_location`: `Local`
    /// covers `DIAGNOSTIC_LINES_RANGE` rows on either side of the cursor. `Global`
    /// passes a default (empty) range, presumably meaning "unrestricted" to
    /// `next_diagnostic_location` — TODO confirm.
    ///
    /// Requests are throttled per project. If a current prediction appeared while the
    /// request was in flight, the result is downgraded to a `CurrentPreferred`
    /// rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        // NOTE(review): only read below; `get` would suffice.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Snapshot the active buffer and cursor, bailing out when predictions
                // are disabled there.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // Re-check: a prediction may have become current while this request
                    // was in flight; that one wins.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2031
2032 fn predictions_enabled_at(
2033 snapshot: &BufferSnapshot,
2034 position: Option<language::Anchor>,
2035 cx: &App,
2036 ) -> bool {
2037 let file = snapshot.file();
2038 let all_settings = all_language_settings(file, cx);
2039 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2040 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2041 {
2042 return false;
2043 }
2044
2045 if let Some(last_position) = position {
2046 let settings = snapshot.settings_at(last_position, cx);
2047
2048 if !settings.edit_predictions_disabled_in.is_empty()
2049 && let Some(scope) = snapshot.language_scope_at(last_position)
2050 && let Some(scope_name) = scope.override_name()
2051 && settings
2052 .edit_predictions_disabled_in
2053 .iter()
2054 .any(|s| s == scope_name)
2055 {
2056 return false;
2057 }
2058 }
2059
2060 true
2061 }
2062
    /// Minimum delay `queue_prediction_refresh` enforces between two refreshes that
    /// share the same throttle entity and trigger kind (edit vs. diagnostics jump).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2064}
2065
2066fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2067 match provider {
2068 EditPredictionProvider::Zed
2069 | EditPredictionProvider::Sweep
2070 | EditPredictionProvider::Mercury
2071 | EditPredictionProvider::Ollama
2072 | EditPredictionProvider::OpenAiCompatibleApi
2073 | EditPredictionProvider::Experimental(_) => true,
2074 EditPredictionProvider::None
2075 | EditPredictionProvider::Copilot
2076 | EditPredictionProvider::Codestral => false,
2077 }
2078}
2079
2080impl EditPredictionStore {
2081 fn queue_prediction_refresh(
2082 &mut self,
2083 project: Entity<Project>,
2084 request_trigger: PredictEditsRequestTrigger,
2085 throttle_entity: EntityId,
2086 cx: &mut Context<Self>,
2087 do_refresh: impl FnOnce(
2088 WeakEntity<Self>,
2089 &mut AsyncApp,
2090 )
2091 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2092 + 'static,
2093 ) {
2094 fn select_throttle(
2095 project_state: &mut ProjectState,
2096 request_trigger: PredictEditsRequestTrigger,
2097 ) -> &mut Option<(EntityId, Instant)> {
2098 match request_trigger {
2099 PredictEditsRequestTrigger::Diagnostics => {
2100 &mut project_state.last_jump_prediction_refresh
2101 }
2102 _ => &mut project_state.last_edit_prediction_refresh,
2103 }
2104 }
2105
2106 let (needs_acceptance_tracking, max_pending_predictions) =
2107 match all_language_settings(None, cx).edit_predictions.provider {
2108 EditPredictionProvider::Zed
2109 | EditPredictionProvider::Sweep
2110 | EditPredictionProvider::Mercury
2111 | EditPredictionProvider::Experimental(_) => (true, 2),
2112 EditPredictionProvider::Ollama => (false, 1),
2113 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2114 EditPredictionProvider::None
2115 | EditPredictionProvider::Copilot
2116 | EditPredictionProvider::Codestral => {
2117 log::error!("queue_prediction_refresh called with non-store provider");
2118 return;
2119 }
2120 };
2121
2122 let drop_on_cancel = !needs_acceptance_tracking;
2123 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2124 let project_state = self.get_or_init_project(&project, cx);
2125 let pending_prediction_id = project_state.next_pending_prediction_id;
2126 project_state.next_pending_prediction_id += 1;
2127 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2128
2129 let task = cx.spawn(async move |this, cx| {
2130 let throttle_wait = this
2131 .update(cx, |this, cx| {
2132 let project_state = this.get_or_init_project(&project, cx);
2133 let throttle = *select_throttle(project_state, request_trigger);
2134
2135 throttle.and_then(|(last_entity, last_timestamp)| {
2136 if throttle_entity != last_entity {
2137 return None;
2138 }
2139 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2140 })
2141 })
2142 .ok()
2143 .flatten();
2144
2145 if let Some(timeout) = throttle_wait {
2146 cx.background_executor().timer(timeout).await;
2147 }
2148
2149 // If this task was cancelled before the throttle timeout expired,
2150 // do not perform a request. Also skip if another task already
2151 // proceeded since we were enqueued (duplicate).
2152 let mut is_cancelled = true;
2153 this.update(cx, |this, cx| {
2154 let project_state = this.get_or_init_project(&project, cx);
2155 let was_cancelled = project_state
2156 .cancelled_predictions
2157 .remove(&pending_prediction_id);
2158 if was_cancelled {
2159 return;
2160 }
2161
2162 // Another request has been already sent since this was enqueued
2163 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2164 return;
2165 }
2166
2167 let new_refresh = (throttle_entity, Instant::now());
2168 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2169 is_cancelled = false;
2170 })
2171 .ok();
2172 if is_cancelled {
2173 return None;
2174 }
2175
2176 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2177 let new_prediction_id = new_prediction_result
2178 .as_ref()
2179 .map(|(prediction, _)| prediction.id.clone());
2180
2181 // When a prediction completes, remove it from the pending list, and cancel
2182 // any pending predictions that were enqueued before it.
2183 this.update(cx, |this, cx| {
2184 let project_state = this.get_or_init_project(&project, cx);
2185
2186 let is_cancelled = project_state
2187 .cancelled_predictions
2188 .remove(&pending_prediction_id);
2189
2190 let new_current_prediction = if !is_cancelled
2191 && let Some((prediction_result, requested_by)) = new_prediction_result
2192 {
2193 match prediction_result.prediction {
2194 Ok(prediction) => {
2195 let new_prediction = CurrentEditPrediction {
2196 requested_by,
2197 prediction,
2198 was_shown: false,
2199 shown_with: None,
2200 };
2201
2202 if let Some(current_prediction) =
2203 project_state.current_prediction.as_ref()
2204 {
2205 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2206 {
2207 this.reject_current_prediction(
2208 EditPredictionRejectReason::Replaced,
2209 &project,
2210 cx,
2211 );
2212
2213 Some(new_prediction)
2214 } else {
2215 this.reject_prediction(
2216 new_prediction.prediction.id,
2217 EditPredictionRejectReason::CurrentPreferred,
2218 false,
2219 new_prediction.prediction.model_version,
2220 cx,
2221 );
2222 None
2223 }
2224 } else {
2225 Some(new_prediction)
2226 }
2227 }
2228 Err(reject_reason) => {
2229 this.reject_prediction(
2230 prediction_result.id,
2231 reject_reason,
2232 false,
2233 None,
2234 cx,
2235 );
2236 None
2237 }
2238 }
2239 } else {
2240 None
2241 };
2242
2243 let project_state = this.get_or_init_project(&project, cx);
2244
2245 if let Some(new_prediction) = new_current_prediction {
2246 project_state.current_prediction = Some(new_prediction);
2247 }
2248
2249 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2250 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2251 if pending_prediction.id == pending_prediction_id {
2252 pending_predictions.remove(ix);
2253 for pending_prediction in pending_predictions.drain(0..ix) {
2254 project_state.cancel_pending_prediction(pending_prediction, cx)
2255 }
2256 break;
2257 }
2258 }
2259 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2260 cx.notify();
2261 })
2262 .ok();
2263
2264 new_prediction_id
2265 });
2266
2267 if project_state.pending_predictions.len() < max_pending_predictions {
2268 project_state.pending_predictions.push(PendingPrediction {
2269 id: pending_prediction_id,
2270 task,
2271 drop_on_cancel,
2272 });
2273 } else {
2274 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2275 project_state.pending_predictions.push(PendingPrediction {
2276 id: pending_prediction_id,
2277 task,
2278 drop_on_cancel,
2279 });
2280 project_state.cancel_pending_prediction(pending_prediction, cx);
2281 }
2282 }
2283
2284 pub fn request_prediction(
2285 &mut self,
2286 project: &Entity<Project>,
2287 active_buffer: &Entity<Buffer>,
2288 position: language::Anchor,
2289 trigger: PredictEditsRequestTrigger,
2290 cx: &mut Context<Self>,
2291 ) -> Task<Result<Option<EditPredictionResult>>> {
2292 self.request_prediction_internal(
2293 project.clone(),
2294 active_buffer.clone(),
2295 position,
2296 trigger,
2297 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2298 cx,
2299 )
2300 }
2301
2302 fn request_prediction_internal(
2303 &mut self,
2304 project: Entity<Project>,
2305 active_buffer: Entity<Buffer>,
2306 position: language::Anchor,
2307 trigger: PredictEditsRequestTrigger,
2308 allow_jump: bool,
2309 cx: &mut Context<Self>,
2310 ) -> Task<Result<Option<EditPredictionResult>>> {
2311 self.get_or_init_project(&project, cx);
2312 let project_state = self.projects.get(&project.entity_id()).unwrap();
2313 let stored_events = project_state.events(cx);
2314 let has_events = !stored_events.is_empty();
2315 let events: Vec<Arc<zeta_prompt::Event>> =
2316 stored_events.iter().map(|e| e.event.clone()).collect();
2317 let debug_tx = project_state.debug_tx.clone();
2318
2319 let snapshot = active_buffer.read(cx).snapshot();
2320 let cursor_point = position.to_point(&snapshot);
2321 let current_offset = position.to_offset(&snapshot);
2322
2323 let mut user_actions: Vec<UserActionRecord> =
2324 project_state.user_actions.iter().cloned().collect();
2325
2326 if let Some(last_action) = user_actions.last() {
2327 if last_action.buffer_id == active_buffer.entity_id()
2328 && current_offset != last_action.offset
2329 {
2330 let timestamp_epoch_ms = SystemTime::now()
2331 .duration_since(UNIX_EPOCH)
2332 .map(|d| d.as_millis() as u64)
2333 .unwrap_or(0);
2334 user_actions.push(UserActionRecord {
2335 action_type: UserActionType::CursorMovement,
2336 buffer_id: active_buffer.entity_id(),
2337 line_number: cursor_point.row,
2338 offset: current_offset,
2339 timestamp_epoch_ms,
2340 });
2341 }
2342 }
2343 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2344 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2345 let diagnostic_search_range =
2346 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2347
2348 let related_files = self.context_for_project(&project, cx);
2349
2350 let is_open_source = snapshot
2351 .file()
2352 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2353 && events.iter().all(|event| event.in_open_source_repo())
2354 && related_files.iter().all(|file| file.in_open_source_repo);
2355
2356 let can_collect_data = !cfg!(test)
2357 && is_open_source
2358 && self.is_data_collection_enabled(cx)
2359 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2360
2361 let recent_paths = project_state.recent_paths.clone();
2362
2363 let inputs = EditPredictionModelInput {
2364 project: project.clone(),
2365 buffer: active_buffer,
2366 snapshot,
2367 position,
2368 events,
2369 related_files,
2370 recent_paths,
2371 trigger,
2372 diagnostic_search_range: diagnostic_search_range,
2373 debug_tx,
2374 user_actions,
2375 can_collect_data,
2376 is_open_source,
2377 };
2378
2379 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2380
2381 let task = match self.edit_prediction_model {
2382 EditPredictionModel::Zeta => {
2383 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2384 }
2385 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2386 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2387 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2388 };
2389
2390 cx.spawn(async move |this, cx| {
2391 let prediction = task.await?;
2392
2393 // Only fall back to diagnostics-based prediction if we got a
2394 // the model had nothing to suggest for the buffer
2395 if prediction.is_none()
2396 && allow_jump
2397 && has_events
2398 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2399 {
2400 this.update(cx, |this, cx| {
2401 this.refresh_prediction_from_diagnostics(
2402 project,
2403 DiagnosticSearchScope::Local,
2404 cx,
2405 );
2406 })?;
2407 return anyhow::Ok(None);
2408 }
2409
2410 Ok(prediction)
2411 })
2412 }
2413
2414 pub(crate) async fn next_diagnostic_location(
2415 active_buffer: Entity<Buffer>,
2416 active_buffer_snapshot: &BufferSnapshot,
2417 active_buffer_diagnostic_search_range: Range<Point>,
2418 active_buffer_cursor_point: Point,
2419 project: &Entity<Project>,
2420 cx: &mut AsyncApp,
2421 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2422 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2423 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2424 .flat_map(|(_, _, _, selections)| {
2425 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2426 })
2427 .collect();
2428
2429 let mut jump_location = active_buffer_snapshot
2430 .diagnostic_groups(None)
2431 .into_iter()
2432 .filter_map(|(_, group)| {
2433 let range = &group.entries[group.primary_ix]
2434 .range
2435 .to_point(&active_buffer_snapshot);
2436 if range.overlaps(&active_buffer_diagnostic_search_range) {
2437 return None;
2438 }
2439 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2440 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2441 });
2442 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2443 <= DIAGNOSTIC_LINES_RANGE;
2444 if near_collaborator && !near_local {
2445 return None;
2446 }
2447 Some(range.start)
2448 })
2449 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2450 .map(|position| {
2451 (
2452 active_buffer.clone(),
2453 active_buffer_snapshot.anchor_before(position),
2454 )
2455 });
2456
2457 if jump_location.is_none() {
2458 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2459 let file = buffer.file()?;
2460
2461 Some(ProjectPath {
2462 worktree_id: file.worktree_id(cx),
2463 path: file.path().clone(),
2464 })
2465 });
2466
2467 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2468 project
2469 .diagnostic_summaries(false, cx)
2470 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2471 .map(|(path, _, _)| {
2472 let shared_prefix = path
2473 .path
2474 .components()
2475 .zip(
2476 active_buffer_path
2477 .as_ref()
2478 .map(|p| p.path.components())
2479 .unwrap_or_default(),
2480 )
2481 .take_while(|(a, b)| a == b)
2482 .count();
2483 (path, shared_prefix)
2484 })
2485 .collect()
2486 });
2487
2488 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2489
2490 for (path, _) in candidates {
2491 let candidate_buffer = project
2492 .update(cx, |project, cx| project.open_buffer(path, cx))
2493 .await?;
2494
2495 let (has_collaborators, diagnostic_position) =
2496 candidate_buffer.read_with(cx, |buffer, _cx| {
2497 let snapshot = buffer.snapshot();
2498 let has_collaborators = snapshot
2499 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2500 .next()
2501 .is_some();
2502 let position = buffer
2503 .buffer_diagnostics(None)
2504 .into_iter()
2505 .min_by_key(|entry| entry.diagnostic.severity)
2506 .map(|entry| entry.range.start);
2507 (has_collaborators, position)
2508 });
2509
2510 if has_collaborators {
2511 continue;
2512 }
2513
2514 if let Some(position) = diagnostic_position {
2515 jump_location = Some((candidate_buffer, position));
2516 break;
2517 }
2518 }
2519 }
2520
2521 anyhow::Ok(jump_location)
2522 }
2523
2524 async fn send_raw_llm_request(
2525 request: RawCompletionRequest,
2526 client: Arc<Client>,
2527 custom_url: Option<Arc<Url>>,
2528 llm_token: LlmApiToken,
2529 organization_id: Option<OrganizationId>,
2530 app_version: Version,
2531 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2532 let url = if let Some(custom_url) = custom_url {
2533 custom_url.as_ref().clone()
2534 } else {
2535 client
2536 .http_client()
2537 .build_zed_llm_url("/predict_edits/raw", &[])?
2538 };
2539
2540 Self::send_api_request(
2541 |builder| {
2542 let req = builder
2543 .uri(url.as_ref())
2544 .body(serde_json::to_string(&request)?.into());
2545 Ok(req?)
2546 },
2547 client,
2548 llm_token,
2549 organization_id,
2550 app_version,
2551 true,
2552 )
2553 .await
2554 }
2555
2556 pub(crate) async fn send_v3_request(
2557 input: ZetaPromptInput,
2558 client: Arc<Client>,
2559 llm_token: LlmApiToken,
2560 organization_id: Option<OrganizationId>,
2561 app_version: Version,
2562 trigger: PredictEditsRequestTrigger,
2563 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2564 let url = client
2565 .http_client()
2566 .build_zed_llm_url("/predict_edits/v3", &[])?;
2567
2568 let request = PredictEditsV3Request { input, trigger };
2569
2570 let json_bytes = serde_json::to_vec(&request)?;
2571 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2572
2573 Self::send_api_request(
2574 |builder| {
2575 let req = builder
2576 .uri(url.as_ref())
2577 .header("Content-Encoding", "zstd")
2578 .body(compressed.clone().into());
2579 Ok(req?)
2580 },
2581 client,
2582 llm_token,
2583 organization_id,
2584 app_version,
2585 true,
2586 )
2587 .await
2588 }
2589
2590 async fn send_api_request<Res>(
2591 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2592 client: Arc<Client>,
2593 llm_token: LlmApiToken,
2594 organization_id: Option<OrganizationId>,
2595 app_version: Version,
2596 require_auth: bool,
2597 ) -> Result<(Res, Option<EditPredictionUsage>)>
2598 where
2599 Res: DeserializeOwned,
2600 {
2601 let http_client = client.http_client();
2602
2603 let mut token = if require_auth {
2604 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2605 } else {
2606 llm_token
2607 .acquire(&client, organization_id.clone())
2608 .await
2609 .ok()
2610 };
2611 let mut did_retry = false;
2612
2613 loop {
2614 let request_builder = http_client::Request::builder().method(Method::POST);
2615
2616 let mut request_builder = request_builder
2617 .header("Content-Type", "application/json")
2618 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2619
2620 // Only add Authorization header if we have a token
2621 if let Some(ref token_value) = token {
2622 request_builder =
2623 request_builder.header("Authorization", format!("Bearer {}", token_value));
2624 }
2625
2626 let request = build(request_builder)?;
2627
2628 let mut response = http_client.send(request).await?;
2629
2630 if let Some(minimum_required_version) = response
2631 .headers()
2632 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2633 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2634 {
2635 anyhow::ensure!(
2636 app_version >= minimum_required_version,
2637 ZedUpdateRequiredError {
2638 minimum_version: minimum_required_version
2639 }
2640 );
2641 }
2642
2643 if response.status().is_success() {
2644 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2645
2646 let mut body = Vec::new();
2647 response.body_mut().read_to_end(&mut body).await?;
2648 return Ok((serde_json::from_slice(&body)?, usage));
2649 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2650 did_retry = true;
2651 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2652 } else {
2653 let mut body = String::new();
2654 response.body_mut().read_to_string(&mut body).await?;
2655 anyhow::bail!(
2656 "Request failed with status: {:?}\nBody: {}",
2657 response.status(),
2658 body
2659 );
2660 }
2661 }
2662 }
2663
2664 pub fn refresh_context(
2665 &mut self,
2666 project: &Entity<Project>,
2667 buffer: &Entity<language::Buffer>,
2668 cursor_position: language::Anchor,
2669 cx: &mut Context<Self>,
2670 ) {
2671 self.get_or_init_project(project, cx)
2672 .context
2673 .update(cx, |store, cx| {
2674 store.refresh(buffer.clone(), cursor_position, cx);
2675 });
2676 }
2677
2678 #[cfg(feature = "cli-support")]
2679 pub fn set_context_for_buffer(
2680 &mut self,
2681 project: &Entity<Project>,
2682 related_files: Vec<RelatedFile>,
2683 cx: &mut Context<Self>,
2684 ) {
2685 self.get_or_init_project(project, cx)
2686 .context
2687 .update(cx, |store, cx| {
2688 store.set_related_files(related_files, cx);
2689 });
2690 }
2691
2692 #[cfg(feature = "cli-support")]
2693 pub fn set_recent_paths_for_project(
2694 &mut self,
2695 project: &Entity<Project>,
2696 paths: impl IntoIterator<Item = project::ProjectPath>,
2697 cx: &mut Context<Self>,
2698 ) {
2699 let project_state = self.get_or_init_project(project, cx);
2700 project_state.recent_paths = paths.into_iter().collect();
2701 }
2702
2703 fn is_file_open_source(
2704 &self,
2705 project: &Entity<Project>,
2706 file: &Arc<dyn File>,
2707 cx: &App,
2708 ) -> bool {
2709 if !file.is_local() || file.is_private() {
2710 return false;
2711 }
2712 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2713 return false;
2714 };
2715 project_state
2716 .license_detection_watchers
2717 .get(&file.worktree_id(cx))
2718 .as_ref()
2719 .is_some_and(|watcher| watcher.is_project_open_source())
2720 }
2721
2722 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2723 self.data_collection_choice.is_enabled(cx)
2724 }
2725
2726 fn load_data_collection_choice() -> DataCollectionChoice {
2727 let choice = KEY_VALUE_STORE
2728 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2729 .log_err()
2730 .flatten();
2731
2732 match choice.as_deref() {
2733 Some("true") => DataCollectionChoice::Enabled,
2734 Some("false") => DataCollectionChoice::Disabled,
2735 Some(_) => {
2736 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2737 DataCollectionChoice::NotAnswered
2738 }
2739 None => DataCollectionChoice::NotAnswered,
2740 }
2741 }
2742
2743 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2744 self.data_collection_choice = self.data_collection_choice.toggle();
2745 let new_choice = self.data_collection_choice;
2746 let is_enabled = new_choice.is_enabled(cx);
2747 db::write_and_log(cx, move || {
2748 KEY_VALUE_STORE.write_kvp(
2749 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2750 is_enabled.to_string(),
2751 )
2752 });
2753 }
2754
2755 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2756 self.shown_predictions.iter()
2757 }
2758
2759 pub fn shown_completions_len(&self) -> usize {
2760 self.shown_predictions.len()
2761 }
2762
2763 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2764 self.rated_predictions.contains(id)
2765 }
2766
2767 pub fn rate_prediction(
2768 &mut self,
2769 prediction: &EditPrediction,
2770 rating: EditPredictionRating,
2771 feedback: String,
2772 cx: &mut Context<Self>,
2773 ) {
2774 let organization = self.user_store.read(cx).current_organization();
2775
2776 self.rated_predictions.insert(prediction.id.clone());
2777
2778 cx.background_spawn({
2779 let client = self.client.clone();
2780 let prediction_id = prediction.id.to_string();
2781 let inputs = serde_json::to_value(&prediction.inputs);
2782 let output = prediction
2783 .edit_preview
2784 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2785 async move {
2786 client
2787 .cloud_client()
2788 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2789 organization_id: organization.map(|organization| organization.id.clone()),
2790 request_id: prediction_id,
2791 rating: match rating {
2792 EditPredictionRating::Positive => "positive".to_string(),
2793 EditPredictionRating::Negative => "negative".to_string(),
2794 },
2795 inputs: inputs?,
2796 output,
2797 feedback,
2798 })
2799 .await?;
2800
2801 anyhow::Ok(())
2802 }
2803 })
2804 .detach_and_log_err(cx);
2805
2806 cx.notify();
2807 }
2808}
2809
2810fn collaborator_edit_overlaps_locality_region(
2811 project_state: &ProjectState,
2812 project: &Entity<Project>,
2813 buffer: &Entity<Buffer>,
2814 snapshot: &BufferSnapshot,
2815 edit_range: &Range<Anchor>,
2816 cx: &App,
2817) -> bool {
2818 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2819 return false;
2820 };
2821
2822 if active_buffer.entity_id() != buffer.entity_id() {
2823 return false;
2824 }
2825
2826 let locality_point_range = expand_context_syntactically_then_linewise(
2827 snapshot,
2828 (position..position).to_point(snapshot),
2829 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2830 );
2831 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2832
2833 edit_range.overlaps(&locality_anchor_range, snapshot)
2834}
2835
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Nothing happens unless both of these hold for the newest stored event:
/// - it belongs to the same buffer as `latest_snapshot` (same `remote_id`);
/// - `latest_snapshot` has observed that event's `new_snapshot_version`.
///
/// The run is found by walking backwards while each older event `can_merge` with
/// the event after it, judged against `latest_snapshot` and `latest_edit_range`.
/// With fewer than two events in the run, nothing changes.
///
/// Otherwise the events' edit ranges are unioned, and a diff is computed from the
/// oldest event's `old_snapshot` to `end_snapshot` over that range. If a diff is
/// produced, the run is replaced by one `BufferChange` event that:
/// - keeps the oldest event's paths and open-source flag;
/// - is marked `predicted` only if every merged event was.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the newest event is for another buffer, or if `latest_snapshot`
    // has not caught up with it.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the newest events that form a chain of pairwise-mergeable neighbours.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A lone event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union the edit ranges of the remaining events, compared in `latest_snapshot`.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    // If no diff is produced, the events are left as they were.
    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek()` did not advance the iterator, so this also checks the
                    // oldest event.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the whole trailing run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2918
2919fn merge_anchor_ranges(
2920 left: &Range<Anchor>,
2921 right: &Range<Anchor>,
2922 snapshot: &TextBufferSnapshot,
2923) -> Range<Anchor> {
2924 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2925 left.start
2926 } else {
2927 right.start
2928 };
2929 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2930 left.end
2931 } else {
2932 right.end
2933 };
2934 start..end
2935}
2936
/// Error returned by `send_api_request` when a response's minimum-required-version
/// header names a version higher than the running app's version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version named by the response header.
    minimum_version: Version,
}
2944
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted as `"true"` / `"false"` in the key-value store. `NotAnswered` is used
/// when nothing (or an unknown value) is stored.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2951
2952impl DataCollectionChoice {
2953 pub fn is_enabled(self, cx: &App) -> bool {
2954 if cx.is_staff() {
2955 return true;
2956 }
2957 match self {
2958 Self::Enabled => true,
2959 Self::NotAnswered | Self::Disabled => false,
2960 }
2961 }
2962
2963 #[must_use]
2964 pub fn toggle(&self) -> DataCollectionChoice {
2965 match self {
2966 Self::Enabled => Self::Disabled,
2967 Self::Disabled => Self::Enabled,
2968 Self::NotAnswered => Self::Enabled,
2969 }
2970 }
2971}
2972
2973impl From<bool> for DataCollectionChoice {
2974 fn from(value: bool) -> Self {
2975 match value {
2976 true => DataCollectionChoice::Enabled,
2977 false => DataCollectionChoice::Disabled,
2978 }
2979 }
2980}
2981
2982struct ZedPredictUpsell;
2983
2984impl Dismissable for ZedPredictUpsell {
2985 const KEY: &'static str = "dismissed-edit-predict-upsell";
2986
2987 fn dismissed() -> bool {
2988 // To make this backwards compatible with older versions of Zed, we
2989 // check if the user has seen the previous Edit Prediction Onboarding
2990 // before, by checking the data collection choice which was written to
2991 // the database once the user clicked on "Accept and Enable"
2992 if KEY_VALUE_STORE
2993 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2994 .log_err()
2995 .is_some_and(|s| s.is_some())
2996 {
2997 return true;
2998 }
2999
3000 KEY_VALUE_STORE
3001 .read_kvp(Self::KEY)
3002 .log_err()
3003 .is_some_and(|s| s.is_some())
3004 }
3005}
3006
3007pub fn should_show_upsell_modal() -> bool {
3008 !ZedPredictUpsell::dismissed()
3009}
3010
3011pub fn init(cx: &mut App) {
3012 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3013 workspace.register_action(
3014 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3015 ZedPredictModal::toggle(
3016 workspace,
3017 workspace.user_store().clone(),
3018 workspace.client().clone(),
3019 window,
3020 cx,
3021 )
3022 },
3023 );
3024
3025 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3026 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3027 settings
3028 .project
3029 .all_languages
3030 .edit_predictions
3031 .get_or_insert_default()
3032 .provider = Some(EditPredictionProvider::None)
3033 });
3034 });
3035 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3036 EditPredictionStore::try_global(cx).and_then(|store| {
3037 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3038 })
3039 }
3040
3041 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3042 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3043 copilot_ui::initiate_sign_in(copilot, window, cx);
3044 }
3045 });
3046 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3047 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3048 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3049 }
3050 });
3051 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3052 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3053 copilot_ui::initiate_sign_out(copilot, window, cx);
3054 }
3055 });
3056 })
3057 .detach();
3058}