1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::{AppState, Workspace};
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
79use crate::example_spec::ExampleSpec;
80use crate::license_detection::LicenseDetectionWatcher;
81use crate::mercury::Mercury;
82use crate::onboarding_modal::ZedPredictModal;
83pub use crate::prediction::EditPrediction;
84pub use crate::prediction::EditPredictionId;
85use crate::prediction::EditPredictionResult;
86pub use crate::sweep_ai::SweepAi;
87pub use capture_example::capture_example;
88pub use language_model::ApiKeyState;
89pub use telemetry_events::EditPredictionRating;
90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
91
// gpui actions declared in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row distance used by `StoredEvent::can_merge` to decide whether two events are
/// close to each other, or close to the latest edit.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): token budget for context around collaborator edits — usage is outside this view; confirm.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): presumably the idle time after which consecutive edits stop being grouped; confirm at use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key name for the persisted data-collection choice (presumably read by `load_data_collection_choice`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval for batching rejection reports; confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Name of the telemetry event emitted when a prediction settles (presumably; emitter not in view).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a pending settled prediction is tracked before being dropped; confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably the edit-free interval after which a prediction counts as settled; confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
112
/// Feature flag registered under the name `edit_prediction_jumps`.
///
/// NOTE(review): presumably gates "jump" predictions (`BufferEditPrediction::Jump`); confirm at use sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
118
/// Newtype wrapper that lets the store be registered as a gpui global
/// (see `EditPredictionStore::global` and `EditPredictionStore::try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` initializes this from the `ZED_ZETA_FORMAT`,
/// `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` environment variables.
///
/// NOTE(review): an earlier comment said "the version is also used as the Baseten
/// environment name (lowercased)", but there is no version field any more; the
/// `environment` field presumably carries that value now — confirm at the request site.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier to request, if overridden.
    pub model_id: Option<String>,
    /// Deployment environment name, if overridden.
    pub environment: Option<String>,
    /// Prompt format used to build the raw prompt.
    pub format: ZetaFormat,
}
133
/// Central edit-prediction state.
///
/// Holds per-project edit history and prediction state, the selected prediction
/// backend, and the channels feeding the background workers started in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Waits for a current user, then refreshes `available_experiments`; held so the task is not dropped.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    /// Backend used to produce predictions (`Zeta` after `new`).
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; initially read from `ZED_ZETA_*` environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly selected experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiment names fetched by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender side of the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender side of the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were displayed; `active_experiment` reads their model versions.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
155
/// Message sent over `EditPredictionStore::reject_predictions_tx`: a rejection plus the
/// organization it should be reported under, if any.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
160
/// Backend that produces edit predictions. Among other things it selects the icon set
/// (`EditPredictionStore::icons`) and whether usage is reported (`EditPredictionStore::usage`).
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's own Zeta model.
    Zeta,
    /// Fill-in-the-middle provider using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
168
/// Everything gathered for a single prediction request.
///
/// NOTE(review): presumably passed to the selected model backend; construction and
/// consumption are outside this view.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is requested for.
    position: Anchor,
    /// Recent edit events, oldest first (presumably, matching `ProjectState::events`).
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
185
/// Debug events describing context retrieval and prediction progress.
///
/// NOTE(review): presumably delivered through the `debug_tx` channels held by
/// `ProjectState` / `EditPredictionModelInput`; confirm at the sender sites.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval completes for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value diagnostics about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when one was built locally.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes for a buffer position.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
221
/// Maximum number of records kept in `ProjectState::user_actions`; once full, the
/// oldest record is evicted before a new one is appended.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user interaction in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer where the action happened.
    pub buffer_id: EntityId,
    // NOTE(review): 0- vs 1-based line numbering is not visible here; confirm at the producer.
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp in milliseconds (presumably since the Unix epoch, per the name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user interaction captured in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
241
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level description of the change (diff, paths, flags).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version right after the change.
    pub new_snapshot_version: clock::Global,
    /// Edited region, anchored in the post-change snapshot.
    pub total_edit_range: Range<Anchor>,
}
250
251impl StoredEvent {
252 fn can_merge(
253 &self,
254 next_old_event: &StoredEvent,
255 latest_snapshot: &TextBufferSnapshot,
256 latest_edit_range: &Range<Anchor>,
257 ) -> bool {
258 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
259 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
260 return false;
261 }
262 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
263 return false;
264 }
265 if self.new_snapshot_version != next_old_event.old_snapshot.version {
266 return false;
267 }
268 if !latest_snapshot
269 .version
270 .observed_all(&next_old_event.new_snapshot_version)
271 {
272 return false;
273 }
274
275 let a_is_predicted = matches!(
276 self.event.as_ref(),
277 zeta_prompt::Event::BufferChange {
278 predicted: true,
279 ..
280 }
281 );
282 let b_is_predicted = matches!(
283 next_old_event.event.as_ref(),
284 zeta_prompt::Event::BufferChange {
285 predicted: true,
286 ..
287 }
288 );
289
290 // If events come from the same source (both predicted or both manual) then
291 // we would have coalesced them already.
292 if a_is_predicted == b_is_predicted {
293 return false;
294 }
295
296 let left_range = self.total_edit_range.to_point(latest_snapshot);
297 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
298 let latest_range = latest_edit_range.to_point(latest_snapshot);
299
300 // Events near to the latest edit are not merged if their sources differ.
301 if lines_between_ranges(&left_range, &latest_range)
302 .min(lines_between_ranges(&right_range, &latest_range))
303 <= CHANGE_GROUPING_LINE_SPAN
304 {
305 return false;
306 }
307
308 // Events that are distant from each other are not merged.
309 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
310 return false;
311 }
312
313 true
314 }
315}
316
317fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
318 if left.start > right.end {
319 return left.start.row - right.end.row;
320 }
321 if right.start > left.end {
322 return right.start.row - left.end.row;
323 }
324 0
325}
326
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events, oldest first.
    /// NOTE(review): presumably capped at `EVENT_COUNT_MAX`; the trimming code is not in view.
    events: VecDeque<StoredEvent>,
    /// In-progress event; `events()` splits it at the last editing pause and finalizes it on demand.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers keyed by worktree; used to decide whether files are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recent user actions, bounded by `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance, created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
345
346impl ProjectState {
347 fn record_user_action(&mut self, action: UserActionRecord) {
348 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
349 self.user_actions.pop_front();
350 }
351 self.user_actions.push_back(action);
352 }
353
354 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
355 self.events
356 .iter()
357 .cloned()
358 .chain(self.last_event.as_ref().iter().flat_map(|event| {
359 let (one, two) = event.split_by_pause();
360 let one = one.finalize(&self.license_detection_watchers, cx);
361 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
362 one.into_iter().chain(two)
363 }))
364 .collect()
365 }
366
367 fn cancel_pending_prediction(
368 &mut self,
369 pending_prediction: PendingPrediction,
370 cx: &mut Context<EditPredictionStore>,
371 ) {
372 self.cancelled_predictions.insert(pending_prediction.id);
373
374 if pending_prediction.drop_on_cancel {
375 drop(pending_prediction.task);
376 } else {
377 cx.spawn(async move |this, cx| {
378 let Some(prediction_id) = pending_prediction.task.await else {
379 return;
380 };
381
382 this.update(cx, |this, cx| {
383 this.reject_prediction(
384 prediction_id,
385 EditPredictionRejectReason::Canceled,
386 false,
387 None,
388 cx,
389 );
390 })
391 .ok();
392 })
393 .detach()
394 }
395 }
396
397 fn active_buffer(
398 &self,
399 project: &Entity<Project>,
400 cx: &App,
401 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
402 let project = project.read(cx);
403 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
404 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
405 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
406 Some((active_buffer, registered_buffer.last_position))
407 }
408}
409
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed.
    pub was_shown: bool,
    /// How it was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
417
418impl CurrentEditPrediction {
419 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
420 let Some(new_edits) = self
421 .prediction
422 .interpolate(&self.prediction.buffer.read(cx))
423 else {
424 return false;
425 };
426
427 if self.prediction.buffer != old_prediction.prediction.buffer {
428 return true;
429 }
430
431 let Some(old_edits) = old_prediction
432 .prediction
433 .interpolate(&old_prediction.prediction.buffer.read(cx))
434 else {
435 return true;
436 };
437
438 let requested_by_buffer_id = self.requested_by.buffer_id();
439
440 // This reduces the occurrence of UI thrash from replacing edits
441 //
442 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
443 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
444 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
445 && old_edits.len() == 1
446 && new_edits.len() == 1
447 {
448 let (old_range, old_text) = &old_edits[0];
449 let (new_range, new_text) = &new_edits[0];
450 new_range == old_range && new_text.starts_with(old_text.as_ref())
451 } else {
452 true
453 }
454 }
455}
456
457#[derive(Debug, Clone)]
458enum PredictionRequestedBy {
459 DiagnosticsUpdate,
460 Buffer(EntityId),
461}
462
463impl PredictionRequestedBy {
464 pub fn buffer_id(&self) -> Option<EntityId> {
465 match self {
466 PredictionRequestedBy::DiagnosticsUpdate => None,
467 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
468 }
469 }
470}
471
// NOTE(review): presumably the number of rows around the cursor searched for diagnostics; confirm at use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Where to look for diagnostics.
/// NOTE(review): presumably `Local` = near the cursor / current buffer, `Global` = wider; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
479
/// An in-flight prediction request tracked by `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Id recorded in `ProjectState::cancelled_predictions` if this request is cancelled.
    id: usize,
    /// Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
488
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies to this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction lives elsewhere (presumably in another buffer or location to jump to).
    Jump { prediction: &'a EditPrediction },
}
495
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Exposes the wrapped prediction, whether it is local or a jump (test-only convenience).
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Local { prediction } | Self::Jump { prediction } => prediction,
        }
    }
}
507
/// A prediction waiting to be reported as settled (see `RegisteredBuffer::pending_predictions`).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region the prediction could edit, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When the prediction was queued.
    enqueued_at: Instant,
    /// When the buffer was last edited (presumably within the editable range; confirm).
    last_edit_at: Instant,
}
516
/// Per-buffer state kept while a buffer is registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Last snapshot observed for this buffer.
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer awaiting settlement.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known cursor position (returned by `ProjectState::active_buffer`).
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
524
/// An edit event still being accumulated; turned into a `StoredEvent` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Union of edited regions, anchored in `new_snapshot`.
    total_edit_range: Range<Anchor>,
    /// Value of `total_edit_range` when the last editing pause was recorded.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the change came from an accepted prediction.
    predicted: bool,
    /// Snapshot taken after the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
538
539impl LastEvent {
540 pub fn finalize(
541 &self,
542 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
543 cx: &App,
544 ) -> Option<StoredEvent> {
545 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
546 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
547
548 let in_open_source_repo =
549 [self.new_file.as_ref(), self.old_file.as_ref()]
550 .iter()
551 .all(|file| {
552 file.is_some_and(|file| {
553 license_detection_watchers
554 .get(&file.worktree_id(cx))
555 .is_some_and(|watcher| watcher.is_project_open_source())
556 })
557 });
558
559 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
560 &self.old_snapshot,
561 &self.new_snapshot,
562 &self.total_edit_range,
563 )?;
564
565 if path == old_path && diff.is_empty() {
566 None
567 } else {
568 Some(StoredEvent {
569 event: Arc::new(zeta_prompt::Event::BufferChange {
570 old_path,
571 path,
572 diff,
573 in_open_source_repo,
574 predicted: self.predicted,
575 }),
576 old_snapshot: self.old_snapshot.clone(),
577 new_snapshot_version: self.new_snapshot.version.clone(),
578 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
579 ..self.new_snapshot.anchor_before(edit_range.end),
580 })
581 }
582 }
583
584 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
585 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
586 return (self.clone(), None);
587 };
588
589 let total_edit_range_before_pause = self
590 .total_edit_range_at_last_pause_boundary
591 .clone()
592 .unwrap_or_else(|| self.total_edit_range.clone());
593
594 let Some(total_edit_range_after_pause) =
595 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
596 else {
597 return (self.clone(), None);
598 };
599
600 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
601 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
602
603 let before = LastEvent {
604 old_snapshot: self.old_snapshot.clone(),
605 new_snapshot: boundary_snapshot.clone(),
606 old_file: self.old_file.clone(),
607 new_file: self.new_file.clone(),
608 latest_edit_range: latest_edit_range_before_pause,
609 total_edit_range: total_edit_range_before_pause,
610 total_edit_range_at_last_pause_boundary: None,
611 predicted: self.predicted,
612 snapshot_after_last_editing_pause: None,
613 last_edit_time: self.last_edit_time,
614 };
615
616 let after = LastEvent {
617 old_snapshot: boundary_snapshot.clone(),
618 new_snapshot: self.new_snapshot.clone(),
619 old_file: self.old_file.clone(),
620 new_file: self.new_file.clone(),
621 latest_edit_range: latest_edit_range_after_pause,
622 total_edit_range: total_edit_range_after_pause,
623 total_edit_range_at_last_pause_boundary: None,
624 predicted: self.predicted,
625 snapshot_after_last_editing_pause: None,
626 last_edit_time: self.last_edit_time,
627 };
628
629 (before, Some(after))
630 }
631}
632
633fn compute_total_edit_range_between_snapshots(
634 old_snapshot: &TextBufferSnapshot,
635 new_snapshot: &TextBufferSnapshot,
636) -> Option<Range<Anchor>> {
637 let edits: Vec<Edit<usize>> = new_snapshot
638 .edits_since::<usize>(&old_snapshot.version)
639 .collect();
640
641 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
642 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
643 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
644
645 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
646}
647
/// Maps `total_edit_range`, whose anchors are resolved in `new_snapshot`, back to the
/// corresponding point range in `old_snapshot`, using the edits made between the two.
///
/// Each endpoint is mapped while walking the edits from `edits_since` (assumed to be in
/// ascending offset order — presumably guaranteed by `edits_since`; confirm):
/// - if it lies in unchanged text before an edit, it is shifted back by the length
///   delta accumulated from the preceding edits;
/// - if it lies inside an edit's new range (boundaries included), it snaps to that
///   edit's old start (for the range start) or old end (for the range end);
/// - if it lies after every edit, it is shifted back by the total delta (saturating at 0).
///
/// Returns `None` if shifting an endpoint that precedes an edit would underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new length - old length) of all edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // First edit whose new range ends at or after the start endpoint decides its mapping.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Unchanged text before this edit: undo the shift from earlier edits.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Inside the edit: snap to where the edit began in the old text.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                // Inside the edit: snap to where the edit ended in the old text.
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints past every edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
693
694fn compute_diff_between_snapshots_in_range(
695 old_snapshot: &TextBufferSnapshot,
696 new_snapshot: &TextBufferSnapshot,
697 total_edit_range: &Range<Anchor>,
698) -> Option<(String, Range<Point>)> {
699 let new_start_point = total_edit_range.start.to_point(new_snapshot);
700 let new_end_point = total_edit_range.end.to_point(new_snapshot);
701 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
702 let old_start_point = old_range.start;
703 let old_end_point = old_range.end;
704
705 const CONTEXT_LINES: u32 = 3;
706
707 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
708 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
709 let old_context_end_row =
710 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
711 let new_context_end_row =
712 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
713
714 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
715 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
716 let old_end_line_offset = old_snapshot
717 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
718 let new_end_line_offset = new_snapshot
719 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
720 let old_edit_range = old_start_line_offset..old_end_line_offset;
721 let new_edit_range = new_start_line_offset..new_end_line_offset;
722
723 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
724 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
725
726 let diff = language::unified_diff_with_offsets(
727 &old_region_text,
728 &new_region_text,
729 old_context_start_row,
730 new_context_start_row,
731 );
732
733 Some((diff, new_start_point..new_end_point))
734}
735
736fn buffer_path_with_id_fallback(
737 file: Option<&Arc<dyn File>>,
738 snapshot: &TextBufferSnapshot,
739 cx: &App,
740) -> Arc<Path> {
741 if let Some(file) = file {
742 file.full_path(cx).into()
743 } else {
744 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
745 }
746}
747
748impl EditPredictionStore {
749 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
750 cx.try_global::<EditPredictionStoreGlobal>()
751 .map(|global| global.0.clone())
752 }
753
754 pub fn global(
755 client: &Arc<Client>,
756 user_store: &Entity<UserStore>,
757 cx: &mut App,
758 ) -> Entity<Self> {
759 cx.try_global::<EditPredictionStoreGlobal>()
760 .map(|global| global.0.clone())
761 .unwrap_or_else(|| {
762 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
763 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
764 ep_store
765 })
766 }
767
768 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
769 let data_collection_choice = Self::load_data_collection_choice();
770
771 let llm_token = LlmApiToken::global(cx);
772
773 let (reject_tx, reject_rx) = mpsc::unbounded();
774 cx.background_spawn({
775 let client = client.clone();
776 let llm_token = llm_token.clone();
777 let app_version = AppVersion::global(cx);
778 let background_executor = cx.background_executor().clone();
779 async move {
780 Self::handle_rejected_predictions(
781 reject_rx,
782 client,
783 llm_token,
784 app_version,
785 background_executor,
786 )
787 .await
788 }
789 })
790 .detach();
791
792 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
793 cx.spawn(async move |this, cx| {
794 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
795 })
796 .detach();
797
798 let mut current_user = user_store.read(cx).watch_current_user();
799 let fetch_experiments_task = cx.spawn(async move |this, cx| {
800 while current_user.borrow().is_none() {
801 current_user.next().await;
802 }
803 this.update(cx, |this, cx| {
804 this.refresh_available_experiments(cx);
805 })
806 .log_err();
807 });
808
809 let this = Self {
810 projects: HashMap::default(),
811 client,
812 user_store,
813 llm_token,
814 _fetch_experiments_task: fetch_experiments_task,
815 update_required: false,
816 edit_prediction_model: EditPredictionModel::Zeta,
817 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
818 preferred_experiment: None,
819 available_experiments: Vec::new(),
820 sweep_ai: SweepAi::new(cx),
821 mercury: Mercury::new(cx),
822
823 data_collection_choice,
824 reject_predictions_tx: reject_tx,
825 settled_predictions_tx,
826 rated_predictions: Default::default(),
827 shown_predictions: Default::default(),
828 #[cfg(test)]
829 settled_event_callback: None,
830 };
831
832 this
833 }
834
835 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
836 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
837 let format = ZetaFormat::parse(&version_str).ok()?;
838 let model_id = env::var("ZED_ZETA_MODEL").ok();
839 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
840 Some(Zeta2RawConfig {
841 model_id,
842 environment,
843 format,
844 })
845 }
846
847 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
848 self.edit_prediction_model = model;
849 }
850
851 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
852 self.zeta2_raw_config = Some(config);
853 }
854
855 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
856 self.zeta2_raw_config.as_ref()
857 }
858
859 pub fn preferred_experiment(&self) -> Option<&str> {
860 self.preferred_experiment.as_deref()
861 }
862
863 pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
864 self.preferred_experiment = experiment;
865 }
866
867 pub fn available_experiments(&self) -> &[String] {
868 &self.available_experiments
869 }
870
871 pub fn active_experiment(&self) -> Option<&str> {
872 self.preferred_experiment.as_deref().or_else(|| {
873 self.shown_predictions
874 .iter()
875 .find_map(|p| p.model_version.as_ref())
876 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
877 })
878 }
879
    /// Fetches the list of edit-prediction experiments from the Zed LLM endpoint
    /// `/edit_prediction_experiments` and stores it in `available_experiments`,
    /// notifying observers.
    ///
    /// The request is authorized with the LLM token for the current organization (if
    /// any) and carries the app version header. A non-success status is turned into an
    /// error that includes the response body. Errors are only logged (`detach_and_log_err`).
    /// NOTE(review): unlike a retry-on-refresh flow, there is no token-refresh handling here.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            // Network I/O runs on the background executor.
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    let token = llm_token.acquire(&client, organization_id).await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        // Body is a JSON array of experiment names.
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
927
928 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
929 use ui::IconName;
930 match self.edit_prediction_model {
931 EditPredictionModel::Sweep => {
932 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
933 .with_disabled(IconName::SweepAiDisabled)
934 .with_up(IconName::SweepAiUp)
935 .with_down(IconName::SweepAiDown)
936 .with_error(IconName::SweepAiError)
937 }
938 EditPredictionModel::Mercury => {
939 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
940 }
941 EditPredictionModel::Zeta => {
942 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
943 .with_disabled(IconName::ZedPredictDisabled)
944 .with_up(IconName::ZedPredictUp)
945 .with_down(IconName::ZedPredictDown)
946 .with_error(IconName::ZedPredictError)
947 }
948 EditPredictionModel::Fim { .. } => {
949 let settings = &all_language_settings(None, cx).edit_predictions;
950 match settings.provider {
951 EditPredictionProvider::Ollama => {
952 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
953 }
954 _ => {
955 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
956 }
957 }
958 }
959 }
960 }
961
962 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
963 self.sweep_ai.api_token.read(cx).has_key()
964 }
965
966 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
967 self.mercury.api_token.read(cx).has_key()
968 }
969
970 pub fn mercury_has_payment_required_error(&self) -> bool {
971 self.mercury.has_payment_required_error()
972 }
973
974 pub fn clear_history(&mut self) {
975 for project_state in self.projects.values_mut() {
976 project_state.events.clear();
977 project_state.last_event.take();
978 }
979 }
980
981 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
982 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
983 project_state.events.clear();
984 project_state.last_event.take();
985 }
986 }
987
988 pub fn edit_history_for_project(
989 &self,
990 project: &Entity<Project>,
991 cx: &App,
992 ) -> Vec<StoredEvent> {
993 self.projects
994 .get(&project.entity_id())
995 .map(|project_state| project_state.events(cx))
996 .unwrap_or_default()
997 }
998
999 pub fn context_for_project<'a>(
1000 &'a self,
1001 project: &Entity<Project>,
1002 cx: &'a mut App,
1003 ) -> Vec<RelatedFile> {
1004 self.projects
1005 .get(&project.entity_id())
1006 .map(|project_state| {
1007 project_state.context.update(cx, |context, cx| {
1008 context
1009 .related_files_with_buffers(cx)
1010 .map(|(mut related_file, buffer)| {
1011 related_file.in_open_source_repo = buffer
1012 .read(cx)
1013 .file()
1014 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1015 related_file
1016 })
1017 .collect()
1018 })
1019 })
1020 .unwrap_or_default()
1021 }
1022
1023 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1024 self.projects
1025 .get(&project.entity_id())
1026 .and_then(|project| project.copilot.clone())
1027 }
1028
1029 pub fn start_copilot_for_project(
1030 &mut self,
1031 project: &Entity<Project>,
1032 cx: &mut Context<Self>,
1033 ) -> Option<Entity<Copilot>> {
1034 if DisableAiSettings::get(None, cx).disable_ai {
1035 return None;
1036 }
1037 let state = self.get_or_init_project(project, cx);
1038
1039 if state.copilot.is_some() {
1040 return state.copilot.clone();
1041 }
1042 let _project = project.clone();
1043 let project = project.read(cx);
1044
1045 let node = project.node_runtime().cloned();
1046 if let Some(node) = node {
1047 let next_id = project.languages().next_language_server_id();
1048 let fs = project.fs().clone();
1049
1050 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1051 state.copilot = Some(copilot.clone());
1052 Some(copilot)
1053 } else {
1054 None
1055 }
1056 }
1057
1058 pub fn context_for_project_with_buffers<'a>(
1059 &'a self,
1060 project: &Entity<Project>,
1061 cx: &'a mut App,
1062 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1063 self.projects
1064 .get(&project.entity_id())
1065 .map(|project| {
1066 project.context.update(cx, |context, cx| {
1067 context.related_files_with_buffers(cx).collect()
1068 })
1069 })
1070 .unwrap_or_default()
1071 }
1072
1073 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1074 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1075 self.user_store.read(cx).edit_prediction_usage()
1076 } else {
1077 None
1078 }
1079 }
1080
1081 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1082 self.get_or_init_project(project, cx);
1083 }
1084
1085 pub fn register_buffer(
1086 &mut self,
1087 buffer: &Entity<Buffer>,
1088 project: &Entity<Project>,
1089 cx: &mut Context<Self>,
1090 ) {
1091 let project_state = self.get_or_init_project(project, cx);
1092 Self::register_buffer_impl(project_state, buffer, project, cx);
1093 }
1094
1095 fn get_or_init_project(
1096 &mut self,
1097 project: &Entity<Project>,
1098 cx: &mut Context<Self>,
1099 ) -> &mut ProjectState {
1100 let entity_id = project.entity_id();
1101 self.projects
1102 .entry(entity_id)
1103 .or_insert_with(|| ProjectState {
1104 context: {
1105 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1106 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1107 this.handle_excerpt_store_event(entity_id, event);
1108 })
1109 .detach();
1110 related_excerpt_store
1111 },
1112 events: VecDeque::new(),
1113 last_event: None,
1114 recent_paths: VecDeque::new(),
1115 debug_tx: None,
1116 registered_buffers: HashMap::default(),
1117 current_prediction: None,
1118 cancelled_predictions: HashSet::default(),
1119 pending_predictions: ArrayVec::new(),
1120 next_pending_prediction_id: 0,
1121 last_edit_prediction_refresh: None,
1122 last_jump_prediction_refresh: None,
1123 license_detection_watchers: HashMap::default(),
1124 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
1125 _subscriptions: [
1126 cx.subscribe(&project, Self::handle_project_event),
1127 cx.observe_release(&project, move |this, _, cx| {
1128 this.projects.remove(&entity_id);
1129 cx.notify();
1130 }),
1131 ],
1132 copilot: None,
1133 })
1134 }
1135
1136 pub fn remove_project(&mut self, project: &Entity<Project>) {
1137 self.projects.remove(&project.entity_id());
1138 }
1139
1140 fn handle_excerpt_store_event(
1141 &mut self,
1142 project_entity_id: EntityId,
1143 event: &RelatedExcerptStoreEvent,
1144 ) {
1145 if let Some(project_state) = self.projects.get(&project_entity_id) {
1146 if let Some(debug_tx) = project_state.debug_tx.clone() {
1147 match event {
1148 RelatedExcerptStoreEvent::StartedRefresh => {
1149 debug_tx
1150 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1151 ContextRetrievalStartedDebugEvent {
1152 project_entity_id: project_entity_id,
1153 timestamp: Instant::now(),
1154 search_prompt: String::new(),
1155 },
1156 ))
1157 .ok();
1158 }
1159 RelatedExcerptStoreEvent::FinishedRefresh {
1160 cache_hit_count,
1161 cache_miss_count,
1162 mean_definition_latency,
1163 max_definition_latency,
1164 } => {
1165 debug_tx
1166 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1167 ContextRetrievalFinishedDebugEvent {
1168 project_entity_id: project_entity_id,
1169 timestamp: Instant::now(),
1170 metadata: vec![
1171 (
1172 "Cache Hits",
1173 format!(
1174 "{}/{}",
1175 cache_hit_count,
1176 cache_hit_count + cache_miss_count
1177 )
1178 .into(),
1179 ),
1180 (
1181 "Max LSP Time",
1182 format!("{} ms", max_definition_latency.as_millis())
1183 .into(),
1184 ),
1185 (
1186 "Mean LSP Time",
1187 format!("{} ms", mean_definition_latency.as_millis())
1188 .into(),
1189 ),
1190 ],
1191 },
1192 ))
1193 .ok();
1194 }
1195 }
1196 }
1197 }
1198 }
1199
1200 pub fn debug_info(
1201 &mut self,
1202 project: &Entity<Project>,
1203 cx: &mut Context<Self>,
1204 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1205 let project_state = self.get_or_init_project(project, cx);
1206 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1207 project_state.debug_tx = Some(debug_watch_tx);
1208 debug_watch_rx
1209 }
1210
1211 fn handle_project_event(
1212 &mut self,
1213 project: Entity<Project>,
1214 event: &project::Event,
1215 cx: &mut Context<Self>,
1216 ) {
1217 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1218 return;
1219 }
1220 // TODO [zeta2] init with recent paths
1221 match event {
1222 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1223 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1224 return;
1225 };
1226 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1227 if let Some(path) = path {
1228 if let Some(ix) = project_state
1229 .recent_paths
1230 .iter()
1231 .position(|probe| probe == &path)
1232 {
1233 project_state.recent_paths.remove(ix);
1234 }
1235 project_state.recent_paths.push_front(path);
1236 }
1237 }
1238 project::Event::DiagnosticsUpdated { .. } => {
1239 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1240 self.refresh_prediction_from_diagnostics(
1241 project,
1242 DiagnosticSearchScope::Global,
1243 cx,
1244 );
1245 }
1246 }
1247 _ => (),
1248 }
1249 }
1250
    /// Ensures `buffer` is registered in `project_state` and returns its `RegisteredBuffer`.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time the worktree is
    /// seen. The watcher entry is removed again when the worktree entity is released.
    ///
    /// When the buffer is not registered yet, a new entry is inserted. It holds the buffer's
    /// current text snapshot and file, plus two subscriptions:
    /// - one that forwards `BufferEvent::Edited` to `report_changes_for_buffer` as a
    ///   non-predicted edit, passing along the event's `is_local` flag;
    /// - one that removes the registration when the buffer entity is released.
    ///
    /// An already-registered buffer is returned as-is; its stored snapshot is not refreshed here.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree: `or_insert_with` only runs for unseen worktrees.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Held weakly: edits are ignored once the project has been dropped.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1319
1320 fn report_changes_for_buffer(
1321 &mut self,
1322 buffer: &Entity<Buffer>,
1323 project: &Entity<Project>,
1324 is_predicted: bool,
1325 is_local: bool,
1326 cx: &mut Context<Self>,
1327 ) {
1328 let project_state = self.get_or_init_project(project, cx);
1329 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1330
1331 let buf = buffer.read(cx);
1332 let new_file = buf.file().cloned();
1333 let new_snapshot = buf.text_snapshot();
1334 if new_snapshot.version == registered_buffer.snapshot.version {
1335 return;
1336 }
1337 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1338 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1339 let mut num_edits = 0usize;
1340 let mut total_deleted = 0usize;
1341 let mut total_inserted = 0usize;
1342 let mut edit_range: Option<Range<Anchor>> = None;
1343 let mut last_offset: Option<usize> = None;
1344 let now = cx.background_executor().now();
1345
1346 for (edit, anchor_range) in
1347 new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1348 {
1349 num_edits += 1;
1350 total_deleted += edit.old.len();
1351 total_inserted += edit.new.len();
1352 edit_range = Some(match edit_range {
1353 None => anchor_range,
1354 Some(acc) => acc.start..anchor_range.end,
1355 });
1356 last_offset = Some(edit.new.end);
1357 }
1358
1359 let Some(edit_range) = edit_range else {
1360 return;
1361 };
1362
1363 for pending_prediction in &mut registered_buffer.pending_predictions {
1364 if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
1365 pending_prediction.last_edit_at = now;
1366 }
1367 }
1368
1369 let include_in_history = is_local
1370 || collaborator_edit_overlaps_locality_region(
1371 project_state,
1372 project,
1373 buffer,
1374 &buf.snapshot(),
1375 &edit_range,
1376 cx,
1377 );
1378
1379 if is_local {
1380 let action_type = match (total_deleted, total_inserted, num_edits) {
1381 (0, ins, n) if ins == n => UserActionType::InsertChar,
1382 (0, _, _) => UserActionType::InsertSelection,
1383 (del, 0, n) if del == n => UserActionType::DeleteChar,
1384 (_, 0, _) => UserActionType::DeleteSelection,
1385 (_, ins, n) if ins == n => UserActionType::InsertChar,
1386 (_, _, _) => UserActionType::InsertSelection,
1387 };
1388
1389 if let Some(offset) = last_offset {
1390 let point = new_snapshot.offset_to_point(offset);
1391 let timestamp_epoch_ms = SystemTime::now()
1392 .duration_since(UNIX_EPOCH)
1393 .map(|d| d.as_millis() as u64)
1394 .unwrap_or(0);
1395 project_state.record_user_action(UserActionRecord {
1396 action_type,
1397 buffer_id: buffer.entity_id(),
1398 line_number: point.row,
1399 offset,
1400 timestamp_epoch_ms,
1401 });
1402 }
1403 }
1404
1405 if !include_in_history {
1406 return;
1407 }
1408
1409 let events = &mut project_state.events;
1410
1411 if let Some(last_event) = project_state.last_event.as_mut() {
1412 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1413 == last_event.new_snapshot.remote_id()
1414 && old_snapshot.version == last_event.new_snapshot.version;
1415
1416 let prediction_source_changed = is_predicted != last_event.predicted;
1417
1418 let should_coalesce = is_next_snapshot_of_same_buffer
1419 && !prediction_source_changed
1420 && lines_between_ranges(
1421 &edit_range.to_point(&new_snapshot),
1422 &last_event.latest_edit_range.to_point(&new_snapshot),
1423 ) <= CHANGE_GROUPING_LINE_SPAN;
1424
1425 if should_coalesce {
1426 let pause_elapsed = last_event
1427 .last_edit_time
1428 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1429 .unwrap_or(false);
1430 if pause_elapsed {
1431 last_event.snapshot_after_last_editing_pause =
1432 Some(last_event.new_snapshot.clone());
1433 last_event.total_edit_range_at_last_pause_boundary =
1434 Some(last_event.total_edit_range.clone());
1435 }
1436
1437 last_event.latest_edit_range = edit_range.clone();
1438 last_event.total_edit_range =
1439 merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
1440 last_event.new_snapshot = new_snapshot;
1441 last_event.last_edit_time = Some(now);
1442 return;
1443 }
1444 }
1445
1446 if let Some(event) = project_state.last_event.take() {
1447 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1448 if events.len() + 1 >= EVENT_COUNT_MAX {
1449 events.pop_front();
1450 }
1451 events.push_back(event);
1452 }
1453 }
1454
1455 merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1456
1457 project_state.last_event = Some(LastEvent {
1458 old_file,
1459 new_file,
1460 old_snapshot,
1461 new_snapshot,
1462 latest_edit_range: edit_range.clone(),
1463 total_edit_range: edit_range,
1464 total_edit_range_at_last_pause_boundary: None,
1465 predicted: is_predicted,
1466 snapshot_after_last_editing_pause: None,
1467 last_edit_time: Some(now),
1468 });
1469 }
1470
1471 fn prediction_at(
1472 &mut self,
1473 buffer: &Entity<Buffer>,
1474 position: Option<language::Anchor>,
1475 project: &Entity<Project>,
1476 cx: &App,
1477 ) -> Option<BufferEditPrediction<'_>> {
1478 let project_state = self.projects.get_mut(&project.entity_id())?;
1479 if let Some(position) = position
1480 && let Some(buffer) = project_state
1481 .registered_buffers
1482 .get_mut(&buffer.entity_id())
1483 {
1484 buffer.last_position = Some(position);
1485 }
1486
1487 let CurrentEditPrediction {
1488 requested_by,
1489 prediction,
1490 ..
1491 } = project_state.current_prediction.as_ref()?;
1492
1493 if prediction.targets_buffer(buffer.read(cx)) {
1494 Some(BufferEditPrediction::Local { prediction })
1495 } else {
1496 let show_jump = match requested_by {
1497 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1498 requested_by_buffer_id == &buffer.entity_id()
1499 }
1500 PredictionRequestedBy::DiagnosticsUpdate => true,
1501 };
1502
1503 if show_jump {
1504 Some(BufferEditPrediction::Jump { prediction })
1505 } else {
1506 None
1507 }
1508 }
1509 }
1510
1511 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1512 let Some(current_prediction) = self
1513 .projects
1514 .get_mut(&project.entity_id())
1515 .and_then(|project_state| project_state.current_prediction.take())
1516 else {
1517 return;
1518 };
1519
1520 self.report_changes_for_buffer(
1521 ¤t_prediction.prediction.buffer,
1522 project,
1523 true,
1524 true,
1525 cx,
1526 );
1527
1528 // can't hold &mut project_state ref across report_changes_for_buffer_call
1529 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1530 return;
1531 };
1532
1533 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1534 project_state.cancel_pending_prediction(pending_prediction, cx);
1535 }
1536
1537 match self.edit_prediction_model {
1538 EditPredictionModel::Sweep => {
1539 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1540 }
1541 EditPredictionModel::Mercury => {
1542 mercury::edit_prediction_accepted(
1543 current_prediction.prediction.id,
1544 self.client.http_client(),
1545 cx,
1546 );
1547 }
1548 EditPredictionModel::Zeta => {
1549 let is_cloud = !matches!(
1550 all_language_settings(None, cx).edit_predictions.provider,
1551 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1552 );
1553 if is_cloud {
1554 zeta::edit_prediction_accepted(self, current_prediction, cx)
1555 }
1556 }
1557 EditPredictionModel::Fim { .. } => {}
1558 }
1559 }
1560
    /// Background loop that batches rejected predictions and uploads them to
    /// `/predict_edits/reject`.
    ///
    /// Batching:
    /// - Rejections received on `rx` accumulate in `batched`.
    /// - While the batch holds fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    ///   entries, the loop waits up to `REJECT_REQUEST_DEBOUNCE` for another rejection before
    ///   sending.
    ///
    /// Sending:
    /// - Each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    ///   recent entries.
    /// - Entries are removed from `batched` only when the request succeeds, so failed entries
    ///   are retried with later batches.
    ///
    /// The loop ends once `rx` yields `None`.
    ///
    /// NOTE(review): the whole batch is sent with the `organization_id` of the most recently
    /// received payload. Also, while requests keep failing, `batched` keeps growing without a
    /// cap. Confirm both are intended.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: prefer an already-queued rejection (biased), otherwise wait for the
            // debounce timer. A closed stream (`peek` yields `None`) falls through to sending.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` entries (the tail of `batched`).
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the sent entries on success; failures are logged and kept for retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1624
    /// Background loop that reports pending predictions once their editable region "settles".
    ///
    /// Each message on `rx` is an enqueue timestamp. The loop sleeps until
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` after the relevant time, then scans the pending
    /// predictions of every registered buffer in every project:
    /// - entries older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without a report;
    /// - entries whose region has not been edited for the quiescence period emit
    ///   `EDIT_PREDICTION_SETTLED_EVENT` and are dropped. The event carries the region's text
    ///   from the buffer's last recorded snapshot;
    /// - all other entries are kept.
    ///
    /// The next wake-up is scheduled from the oldest `last_edit_at` among the kept entries.
    /// When nothing is pending, the loop blocks on `rx`. It exits when `rx` closes or the store
    /// entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Drain any other queued wake-ups; one scan covers all of them.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled region and drop.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    // Test-only hook for observing settled events.
                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                // Still being edited: keep it and track the earliest edit time.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1709
1710 pub(crate) fn enqueue_settled_prediction(
1711 &mut self,
1712 request_id: EditPredictionId,
1713 project: &Entity<Project>,
1714 edited_buffer: &Entity<Buffer>,
1715 edited_buffer_snapshot: &BufferSnapshot,
1716 editable_offset_range: Range<usize>,
1717 example: Option<ExampleSpec>,
1718 cx: &mut Context<Self>,
1719 ) {
1720 let this = &mut *self;
1721 let project_state = this.get_or_init_project(project, cx);
1722 if let Some(buffer) = project_state
1723 .registered_buffers
1724 .get_mut(&edited_buffer.entity_id())
1725 {
1726 let now = cx.background_executor().now();
1727 buffer.pending_predictions.push(PendingSettledPrediction {
1728 request_id: request_id,
1729 editable_anchor_range: edited_buffer_snapshot
1730 .anchor_range_around(editable_offset_range),
1731 example,
1732 enqueued_at: now,
1733 last_edit_at: now,
1734 });
1735 this.settled_predictions_tx.unbounded_send(now).ok();
1736 }
1737 }
1738
1739 fn reject_current_prediction(
1740 &mut self,
1741 reason: EditPredictionRejectReason,
1742 project: &Entity<Project>,
1743 cx: &App,
1744 ) {
1745 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1746 project_state.pending_predictions.clear();
1747 if let Some(prediction) = project_state.current_prediction.take() {
1748 let model_version = prediction.prediction.model_version.clone();
1749 self.reject_prediction(
1750 prediction.prediction.id,
1751 reason,
1752 prediction.was_shown,
1753 model_version,
1754 cx,
1755 );
1756 }
1757 };
1758 }
1759
1760 fn did_show_current_prediction(
1761 &mut self,
1762 project: &Entity<Project>,
1763 display_type: edit_prediction_types::SuggestionDisplayType,
1764 cx: &mut Context<Self>,
1765 ) {
1766 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1767 return;
1768 };
1769
1770 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1771 return;
1772 };
1773
1774 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1775 let previous_shown_with = current_prediction.shown_with;
1776
1777 if previous_shown_with.is_none() || !is_jump {
1778 current_prediction.shown_with = Some(display_type);
1779 }
1780
1781 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1782
1783 if is_first_non_jump_show {
1784 current_prediction.was_shown = true;
1785 }
1786
1787 let display_type_changed = previous_shown_with != Some(display_type);
1788
1789 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1790 sweep_ai::edit_prediction_shown(
1791 &self.sweep_ai,
1792 self.client.clone(),
1793 ¤t_prediction.prediction,
1794 display_type,
1795 cx,
1796 );
1797 }
1798
1799 if is_first_non_jump_show {
1800 self.shown_predictions
1801 .push_front(current_prediction.prediction.clone());
1802 if self.shown_predictions.len() > 50 {
1803 let completion = self.shown_predictions.pop_back().unwrap();
1804 self.rated_predictions.remove(&completion.id);
1805 }
1806 }
1807 }
1808
1809 fn reject_prediction(
1810 &mut self,
1811 prediction_id: EditPredictionId,
1812 reason: EditPredictionRejectReason,
1813 was_shown: bool,
1814 model_version: Option<String>,
1815 cx: &App,
1816 ) {
1817 match self.edit_prediction_model {
1818 EditPredictionModel::Zeta => {
1819 let is_cloud = !matches!(
1820 all_language_settings(None, cx).edit_predictions.provider,
1821 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1822 );
1823
1824 if is_cloud {
1825 let organization_id = self
1826 .user_store
1827 .read(cx)
1828 .current_organization()
1829 .map(|organization| organization.id.clone());
1830
1831 self.reject_predictions_tx
1832 .unbounded_send(EditPredictionRejectionPayload {
1833 rejection: EditPredictionRejection {
1834 request_id: prediction_id.to_string(),
1835 reason,
1836 was_shown,
1837 model_version,
1838 },
1839 organization_id,
1840 })
1841 .log_err();
1842 }
1843 }
1844 EditPredictionModel::Mercury => {
1845 mercury::edit_prediction_rejected(
1846 prediction_id,
1847 was_shown,
1848 reason,
1849 self.client.http_client(),
1850 cx,
1851 );
1852 }
1853 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1854 }
1855 }
1856
1857 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1858 self.projects
1859 .get(&project.entity_id())
1860 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1861 }
1862
1863 pub fn refresh_prediction_from_buffer(
1864 &mut self,
1865 project: Entity<Project>,
1866 buffer: Entity<Buffer>,
1867 position: language::Anchor,
1868 cx: &mut Context<Self>,
1869 ) {
1870 self.queue_prediction_refresh(
1871 project.clone(),
1872 PredictEditsRequestTrigger::Other,
1873 buffer.entity_id(),
1874 cx,
1875 move |this, cx| {
1876 let Some(request_task) = this
1877 .update(cx, |this, cx| {
1878 this.request_prediction(
1879 &project,
1880 &buffer,
1881 position,
1882 PredictEditsRequestTrigger::Other,
1883 cx,
1884 )
1885 })
1886 .log_err()
1887 else {
1888 return Task::ready(anyhow::Ok(None));
1889 };
1890
1891 cx.spawn(async move |_cx| {
1892 request_task.await.map(|prediction_result| {
1893 prediction_result.map(|prediction_result| {
1894 (
1895 prediction_result,
1896 PredictionRequestedBy::Buffer(buffer.entity_id()),
1897 )
1898 })
1899 })
1900 })
1901 },
1902 )
1903 }
1904
    /// Queues a prediction request at the next diagnostic location (a "jump" prediction).
    ///
    /// Nothing is queued in any of these cases:
    /// - the configured provider is not store-managed (`is_ep_store_provider`);
    /// - a pane of a workspace for this project is following a leader (`currently_following`);
    /// - the project is untracked;
    /// - a current prediction already exists.
    ///
    /// Throttling is keyed on the project's entity id with the `Diagnostics` trigger.
    ///
    /// The queued task then:
    /// 1. Resolves the project's active buffer and cursor, bailing out if predictions are
    ///    disabled there.
    /// 2. Looks up the next diagnostic via `next_diagnostic_location`. For
    ///    `DiagnosticSearchScope::Local` the search covers `DIAGNOSTIC_LINES_RANGE` rows around
    ///    the cursor. For `Global` it passes a default (`0..0`) range, presumably meaning "no
    ///    locality restriction" (confirm in `next_diagnostic_location`).
    /// 3. Requests a prediction at that location.
    /// 4. If a current prediction appeared in the meantime, turns the result into a
    ///    `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Snapshot the active buffer and cursor (origin if there is no cursor anchor).
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction that arrived while we were waiting wins over this one.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2020
2021 fn predictions_enabled_at(
2022 snapshot: &BufferSnapshot,
2023 position: Option<language::Anchor>,
2024 cx: &App,
2025 ) -> bool {
2026 let file = snapshot.file();
2027 let all_settings = all_language_settings(file, cx);
2028 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2029 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2030 {
2031 return false;
2032 }
2033
2034 if let Some(last_position) = position {
2035 let settings = snapshot.settings_at(last_position, cx);
2036
2037 if !settings.edit_predictions_disabled_in.is_empty()
2038 && let Some(scope) = snapshot.language_scope_at(last_position)
2039 && let Some(scope_name) = scope.override_name()
2040 && settings
2041 .edit_predictions_disabled_in
2042 .iter()
2043 .any(|s| s == scope_name)
2044 {
2045 return false;
2046 }
2047 }
2048
2049 true
2050 }
2051
    /// Minimum spacing between prediction requests that share a throttle entity and trigger
    /// kind. Diagnostics-triggered and other requests are throttled separately in
    /// `queue_prediction_refresh`.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2053}
2054
2055fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2056 let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2057 return false;
2058 };
2059
2060 app_state
2061 .workspace_store
2062 .read(cx)
2063 .workspaces()
2064 .filter_map(|workspace| workspace.upgrade())
2065 .any(|workspace| {
2066 workspace.read(cx).project().entity_id() == project.entity_id()
2067 && workspace
2068 .read(cx)
2069 .leader_for_pane(workspace.read(cx).active_pane())
2070 .is_some()
2071 })
2072}
2073
2074fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2075 match provider {
2076 EditPredictionProvider::Zed
2077 | EditPredictionProvider::Sweep
2078 | EditPredictionProvider::Mercury
2079 | EditPredictionProvider::Ollama
2080 | EditPredictionProvider::OpenAiCompatibleApi
2081 | EditPredictionProvider::Experimental(_) => true,
2082 EditPredictionProvider::None
2083 | EditPredictionProvider::Copilot
2084 | EditPredictionProvider::Codestral => false,
2085 }
2086}
2087
2088impl EditPredictionStore {
2089 fn queue_prediction_refresh(
2090 &mut self,
2091 project: Entity<Project>,
2092 request_trigger: PredictEditsRequestTrigger,
2093 throttle_entity: EntityId,
2094 cx: &mut Context<Self>,
2095 do_refresh: impl FnOnce(
2096 WeakEntity<Self>,
2097 &mut AsyncApp,
2098 )
2099 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2100 + 'static,
2101 ) {
2102 fn select_throttle(
2103 project_state: &mut ProjectState,
2104 request_trigger: PredictEditsRequestTrigger,
2105 ) -> &mut Option<(EntityId, Instant)> {
2106 match request_trigger {
2107 PredictEditsRequestTrigger::Diagnostics => {
2108 &mut project_state.last_jump_prediction_refresh
2109 }
2110 _ => &mut project_state.last_edit_prediction_refresh,
2111 }
2112 }
2113
2114 let (needs_acceptance_tracking, max_pending_predictions) =
2115 match all_language_settings(None, cx).edit_predictions.provider {
2116 EditPredictionProvider::Zed
2117 | EditPredictionProvider::Sweep
2118 | EditPredictionProvider::Mercury
2119 | EditPredictionProvider::Experimental(_) => (true, 2),
2120 EditPredictionProvider::Ollama => (false, 1),
2121 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2122 EditPredictionProvider::None
2123 | EditPredictionProvider::Copilot
2124 | EditPredictionProvider::Codestral => {
2125 log::error!("queue_prediction_refresh called with non-store provider");
2126 return;
2127 }
2128 };
2129
2130 let drop_on_cancel = !needs_acceptance_tracking;
2131 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2132 let project_state = self.get_or_init_project(&project, cx);
2133 let pending_prediction_id = project_state.next_pending_prediction_id;
2134 project_state.next_pending_prediction_id += 1;
2135 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2136
2137 let task = cx.spawn(async move |this, cx| {
2138 let throttle_wait = this
2139 .update(cx, |this, cx| {
2140 let project_state = this.get_or_init_project(&project, cx);
2141 let throttle = *select_throttle(project_state, request_trigger);
2142
2143 throttle.and_then(|(last_entity, last_timestamp)| {
2144 if throttle_entity != last_entity {
2145 return None;
2146 }
2147 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2148 })
2149 })
2150 .ok()
2151 .flatten();
2152
2153 if let Some(timeout) = throttle_wait {
2154 cx.background_executor().timer(timeout).await;
2155 }
2156
2157 // If this task was cancelled before the throttle timeout expired,
2158 // do not perform a request. Also skip if another task already
2159 // proceeded since we were enqueued (duplicate).
2160 let mut is_cancelled = true;
2161 this.update(cx, |this, cx| {
2162 let project_state = this.get_or_init_project(&project, cx);
2163 let was_cancelled = project_state
2164 .cancelled_predictions
2165 .remove(&pending_prediction_id);
2166 if was_cancelled {
2167 return;
2168 }
2169
2170 // Another request has been already sent since this was enqueued
2171 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2172 return;
2173 }
2174
2175 let new_refresh = (throttle_entity, Instant::now());
2176 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2177 is_cancelled = false;
2178 })
2179 .ok();
2180 if is_cancelled {
2181 return None;
2182 }
2183
2184 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2185 let new_prediction_id = new_prediction_result
2186 .as_ref()
2187 .map(|(prediction, _)| prediction.id.clone());
2188
2189 // When a prediction completes, remove it from the pending list, and cancel
2190 // any pending predictions that were enqueued before it.
2191 this.update(cx, |this, cx| {
2192 let project_state = this.get_or_init_project(&project, cx);
2193
2194 let is_cancelled = project_state
2195 .cancelled_predictions
2196 .remove(&pending_prediction_id);
2197
2198 let new_current_prediction = if !is_cancelled
2199 && let Some((prediction_result, requested_by)) = new_prediction_result
2200 {
2201 match prediction_result.prediction {
2202 Ok(prediction) => {
2203 let new_prediction = CurrentEditPrediction {
2204 requested_by,
2205 prediction,
2206 was_shown: false,
2207 shown_with: None,
2208 };
2209
2210 if let Some(current_prediction) =
2211 project_state.current_prediction.as_ref()
2212 {
2213 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2214 {
2215 this.reject_current_prediction(
2216 EditPredictionRejectReason::Replaced,
2217 &project,
2218 cx,
2219 );
2220
2221 Some(new_prediction)
2222 } else {
2223 this.reject_prediction(
2224 new_prediction.prediction.id,
2225 EditPredictionRejectReason::CurrentPreferred,
2226 false,
2227 new_prediction.prediction.model_version,
2228 cx,
2229 );
2230 None
2231 }
2232 } else {
2233 Some(new_prediction)
2234 }
2235 }
2236 Err(reject_reason) => {
2237 this.reject_prediction(
2238 prediction_result.id,
2239 reject_reason,
2240 false,
2241 None,
2242 cx,
2243 );
2244 None
2245 }
2246 }
2247 } else {
2248 None
2249 };
2250
2251 let project_state = this.get_or_init_project(&project, cx);
2252
2253 if let Some(new_prediction) = new_current_prediction {
2254 project_state.current_prediction = Some(new_prediction);
2255 }
2256
2257 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2258 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2259 if pending_prediction.id == pending_prediction_id {
2260 pending_predictions.remove(ix);
2261 for pending_prediction in pending_predictions.drain(0..ix) {
2262 project_state.cancel_pending_prediction(pending_prediction, cx)
2263 }
2264 break;
2265 }
2266 }
2267 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2268 cx.notify();
2269 })
2270 .ok();
2271
2272 new_prediction_id
2273 });
2274
2275 if project_state.pending_predictions.len() < max_pending_predictions {
2276 project_state.pending_predictions.push(PendingPrediction {
2277 id: pending_prediction_id,
2278 task,
2279 drop_on_cancel,
2280 });
2281 } else {
2282 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2283 project_state.pending_predictions.push(PendingPrediction {
2284 id: pending_prediction_id,
2285 task,
2286 drop_on_cancel,
2287 });
2288 project_state.cancel_pending_prediction(pending_prediction, cx);
2289 }
2290 }
2291
2292 pub fn request_prediction(
2293 &mut self,
2294 project: &Entity<Project>,
2295 active_buffer: &Entity<Buffer>,
2296 position: language::Anchor,
2297 trigger: PredictEditsRequestTrigger,
2298 cx: &mut Context<Self>,
2299 ) -> Task<Result<Option<EditPredictionResult>>> {
2300 self.request_prediction_internal(
2301 project.clone(),
2302 active_buffer.clone(),
2303 position,
2304 trigger,
2305 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2306 cx,
2307 )
2308 }
2309
2310 fn request_prediction_internal(
2311 &mut self,
2312 project: Entity<Project>,
2313 active_buffer: Entity<Buffer>,
2314 position: language::Anchor,
2315 trigger: PredictEditsRequestTrigger,
2316 allow_jump: bool,
2317 cx: &mut Context<Self>,
2318 ) -> Task<Result<Option<EditPredictionResult>>> {
2319 self.get_or_init_project(&project, cx);
2320 let project_state = self.projects.get(&project.entity_id()).unwrap();
2321 let stored_events = project_state.events(cx);
2322 let has_events = !stored_events.is_empty();
2323 let events: Vec<Arc<zeta_prompt::Event>> =
2324 stored_events.iter().map(|e| e.event.clone()).collect();
2325 let debug_tx = project_state.debug_tx.clone();
2326
2327 let snapshot = active_buffer.read(cx).snapshot();
2328 let cursor_point = position.to_point(&snapshot);
2329 let current_offset = position.to_offset(&snapshot);
2330
2331 let mut user_actions: Vec<UserActionRecord> =
2332 project_state.user_actions.iter().cloned().collect();
2333
2334 if let Some(last_action) = user_actions.last() {
2335 if last_action.buffer_id == active_buffer.entity_id()
2336 && current_offset != last_action.offset
2337 {
2338 let timestamp_epoch_ms = SystemTime::now()
2339 .duration_since(UNIX_EPOCH)
2340 .map(|d| d.as_millis() as u64)
2341 .unwrap_or(0);
2342 user_actions.push(UserActionRecord {
2343 action_type: UserActionType::CursorMovement,
2344 buffer_id: active_buffer.entity_id(),
2345 line_number: cursor_point.row,
2346 offset: current_offset,
2347 timestamp_epoch_ms,
2348 });
2349 }
2350 }
2351 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2352 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2353 let diagnostic_search_range =
2354 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2355
2356 let related_files = self.context_for_project(&project, cx);
2357
2358 let is_open_source = snapshot
2359 .file()
2360 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2361 && events.iter().all(|event| event.in_open_source_repo())
2362 && related_files.iter().all(|file| file.in_open_source_repo);
2363
2364 let can_collect_data = !cfg!(test)
2365 && is_open_source
2366 && self.is_data_collection_enabled(cx)
2367 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2368
2369 let recent_paths = project_state.recent_paths.clone();
2370
2371 let inputs = EditPredictionModelInput {
2372 project: project.clone(),
2373 buffer: active_buffer,
2374 snapshot,
2375 position,
2376 events,
2377 related_files,
2378 recent_paths,
2379 trigger,
2380 diagnostic_search_range: diagnostic_search_range,
2381 debug_tx,
2382 user_actions,
2383 can_collect_data,
2384 is_open_source,
2385 };
2386
2387 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2388
2389 let task = match self.edit_prediction_model {
2390 EditPredictionModel::Zeta => {
2391 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2392 }
2393 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2394 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2395 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2396 };
2397
2398 cx.spawn(async move |this, cx| {
2399 let prediction = task.await?;
2400
2401 // Only fall back to diagnostics-based prediction if we got a
2402 // the model had nothing to suggest for the buffer
2403 if prediction.is_none()
2404 && allow_jump
2405 && has_events
2406 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2407 {
2408 this.update(cx, |this, cx| {
2409 this.refresh_prediction_from_diagnostics(
2410 project,
2411 DiagnosticSearchScope::Local,
2412 cx,
2413 );
2414 })?;
2415 return anyhow::Ok(None);
2416 }
2417
2418 Ok(prediction)
2419 })
2420 }
2421
2422 pub(crate) async fn next_diagnostic_location(
2423 active_buffer: Entity<Buffer>,
2424 active_buffer_snapshot: &BufferSnapshot,
2425 active_buffer_diagnostic_search_range: Range<Point>,
2426 active_buffer_cursor_point: Point,
2427 project: &Entity<Project>,
2428 cx: &mut AsyncApp,
2429 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2430 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2431 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2432 .flat_map(|(_, _, _, selections)| {
2433 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2434 })
2435 .collect();
2436
2437 let mut jump_location = active_buffer_snapshot
2438 .diagnostic_groups(None)
2439 .into_iter()
2440 .filter_map(|(_, group)| {
2441 let range = &group.entries[group.primary_ix]
2442 .range
2443 .to_point(&active_buffer_snapshot);
2444 if range.overlaps(&active_buffer_diagnostic_search_range) {
2445 return None;
2446 }
2447 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2448 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2449 });
2450 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2451 <= DIAGNOSTIC_LINES_RANGE;
2452 if near_collaborator && !near_local {
2453 return None;
2454 }
2455 Some(range.start)
2456 })
2457 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2458 .map(|position| {
2459 (
2460 active_buffer.clone(),
2461 active_buffer_snapshot.anchor_before(position),
2462 )
2463 });
2464
2465 if jump_location.is_none() {
2466 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2467 let file = buffer.file()?;
2468
2469 Some(ProjectPath {
2470 worktree_id: file.worktree_id(cx),
2471 path: file.path().clone(),
2472 })
2473 });
2474
2475 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2476 project
2477 .diagnostic_summaries(false, cx)
2478 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2479 .map(|(path, _, _)| {
2480 let shared_prefix = path
2481 .path
2482 .components()
2483 .zip(
2484 active_buffer_path
2485 .as_ref()
2486 .map(|p| p.path.components())
2487 .unwrap_or_default(),
2488 )
2489 .take_while(|(a, b)| a == b)
2490 .count();
2491 (path, shared_prefix)
2492 })
2493 .collect()
2494 });
2495
2496 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2497
2498 for (path, _) in candidates {
2499 let candidate_buffer = project
2500 .update(cx, |project, cx| project.open_buffer(path, cx))
2501 .await?;
2502
2503 let (has_collaborators, diagnostic_position) =
2504 candidate_buffer.read_with(cx, |buffer, _cx| {
2505 let snapshot = buffer.snapshot();
2506 let has_collaborators = snapshot
2507 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2508 .next()
2509 .is_some();
2510 let position = buffer
2511 .buffer_diagnostics(None)
2512 .into_iter()
2513 .min_by_key(|entry| entry.diagnostic.severity)
2514 .map(|entry| entry.range.start);
2515 (has_collaborators, position)
2516 });
2517
2518 if has_collaborators {
2519 continue;
2520 }
2521
2522 if let Some(position) = diagnostic_position {
2523 jump_location = Some((candidate_buffer, position));
2524 break;
2525 }
2526 }
2527 }
2528
2529 anyhow::Ok(jump_location)
2530 }
2531
2532 async fn send_raw_llm_request(
2533 request: RawCompletionRequest,
2534 client: Arc<Client>,
2535 custom_url: Option<Arc<Url>>,
2536 llm_token: LlmApiToken,
2537 organization_id: Option<OrganizationId>,
2538 app_version: Version,
2539 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2540 let url = if let Some(custom_url) = custom_url {
2541 custom_url.as_ref().clone()
2542 } else {
2543 client
2544 .http_client()
2545 .build_zed_llm_url("/predict_edits/raw", &[])?
2546 };
2547
2548 Self::send_api_request(
2549 |builder| {
2550 let req = builder
2551 .uri(url.as_ref())
2552 .body(serde_json::to_string(&request)?.into());
2553 Ok(req?)
2554 },
2555 client,
2556 llm_token,
2557 organization_id,
2558 app_version,
2559 true,
2560 )
2561 .await
2562 }
2563
2564 pub(crate) async fn send_v3_request(
2565 input: ZetaPromptInput,
2566 client: Arc<Client>,
2567 llm_token: LlmApiToken,
2568 organization_id: Option<OrganizationId>,
2569 app_version: Version,
2570 trigger: PredictEditsRequestTrigger,
2571 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2572 let url = client
2573 .http_client()
2574 .build_zed_llm_url("/predict_edits/v3", &[])?;
2575
2576 let request = PredictEditsV3Request { input, trigger };
2577
2578 let json_bytes = serde_json::to_vec(&request)?;
2579 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2580
2581 Self::send_api_request(
2582 |builder| {
2583 let req = builder
2584 .uri(url.as_ref())
2585 .header("Content-Encoding", "zstd")
2586 .body(compressed.clone().into());
2587 Ok(req?)
2588 },
2589 client,
2590 llm_token,
2591 organization_id,
2592 app_version,
2593 true,
2594 )
2595 .await
2596 }
2597
2598 async fn send_api_request<Res>(
2599 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2600 client: Arc<Client>,
2601 llm_token: LlmApiToken,
2602 organization_id: Option<OrganizationId>,
2603 app_version: Version,
2604 require_auth: bool,
2605 ) -> Result<(Res, Option<EditPredictionUsage>)>
2606 where
2607 Res: DeserializeOwned,
2608 {
2609 let http_client = client.http_client();
2610
2611 let mut token = if require_auth {
2612 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2613 } else {
2614 llm_token
2615 .acquire(&client, organization_id.clone())
2616 .await
2617 .ok()
2618 };
2619 let mut did_retry = false;
2620
2621 loop {
2622 let request_builder = http_client::Request::builder().method(Method::POST);
2623
2624 let mut request_builder = request_builder
2625 .header("Content-Type", "application/json")
2626 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2627
2628 // Only add Authorization header if we have a token
2629 if let Some(ref token_value) = token {
2630 request_builder =
2631 request_builder.header("Authorization", format!("Bearer {}", token_value));
2632 }
2633
2634 let request = build(request_builder)?;
2635
2636 let mut response = http_client.send(request).await?;
2637
2638 if let Some(minimum_required_version) = response
2639 .headers()
2640 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2641 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2642 {
2643 anyhow::ensure!(
2644 app_version >= minimum_required_version,
2645 ZedUpdateRequiredError {
2646 minimum_version: minimum_required_version
2647 }
2648 );
2649 }
2650
2651 if response.status().is_success() {
2652 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2653
2654 let mut body = Vec::new();
2655 response.body_mut().read_to_end(&mut body).await?;
2656 return Ok((serde_json::from_slice(&body)?, usage));
2657 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2658 did_retry = true;
2659 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2660 } else {
2661 let mut body = String::new();
2662 response.body_mut().read_to_string(&mut body).await?;
2663 anyhow::bail!(
2664 "Request failed with status: {:?}\nBody: {}",
2665 response.status(),
2666 body
2667 );
2668 }
2669 }
2670 }
2671
2672 pub fn refresh_context(
2673 &mut self,
2674 project: &Entity<Project>,
2675 buffer: &Entity<language::Buffer>,
2676 cursor_position: language::Anchor,
2677 cx: &mut Context<Self>,
2678 ) {
2679 self.get_or_init_project(project, cx)
2680 .context
2681 .update(cx, |store, cx| {
2682 store.refresh(buffer.clone(), cursor_position, cx);
2683 });
2684 }
2685
2686 #[cfg(feature = "cli-support")]
2687 pub fn set_context_for_buffer(
2688 &mut self,
2689 project: &Entity<Project>,
2690 related_files: Vec<RelatedFile>,
2691 cx: &mut Context<Self>,
2692 ) {
2693 self.get_or_init_project(project, cx)
2694 .context
2695 .update(cx, |store, cx| {
2696 store.set_related_files(related_files, cx);
2697 });
2698 }
2699
2700 #[cfg(feature = "cli-support")]
2701 pub fn set_recent_paths_for_project(
2702 &mut self,
2703 project: &Entity<Project>,
2704 paths: impl IntoIterator<Item = project::ProjectPath>,
2705 cx: &mut Context<Self>,
2706 ) {
2707 let project_state = self.get_or_init_project(project, cx);
2708 project_state.recent_paths = paths.into_iter().collect();
2709 }
2710
2711 fn is_file_open_source(
2712 &self,
2713 project: &Entity<Project>,
2714 file: &Arc<dyn File>,
2715 cx: &App,
2716 ) -> bool {
2717 if !file.is_local() || file.is_private() {
2718 return false;
2719 }
2720 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2721 return false;
2722 };
2723 project_state
2724 .license_detection_watchers
2725 .get(&file.worktree_id(cx))
2726 .as_ref()
2727 .is_some_and(|watcher| watcher.is_project_open_source())
2728 }
2729
2730 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2731 self.data_collection_choice.is_enabled(cx)
2732 }
2733
2734 fn load_data_collection_choice() -> DataCollectionChoice {
2735 let choice = KEY_VALUE_STORE
2736 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2737 .log_err()
2738 .flatten();
2739
2740 match choice.as_deref() {
2741 Some("true") => DataCollectionChoice::Enabled,
2742 Some("false") => DataCollectionChoice::Disabled,
2743 Some(_) => {
2744 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2745 DataCollectionChoice::NotAnswered
2746 }
2747 None => DataCollectionChoice::NotAnswered,
2748 }
2749 }
2750
2751 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2752 self.data_collection_choice = self.data_collection_choice.toggle();
2753 let new_choice = self.data_collection_choice;
2754 let is_enabled = new_choice.is_enabled(cx);
2755 db::write_and_log(cx, move || {
2756 KEY_VALUE_STORE.write_kvp(
2757 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2758 is_enabled.to_string(),
2759 )
2760 });
2761 }
2762
2763 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2764 self.shown_predictions.iter()
2765 }
2766
2767 pub fn shown_completions_len(&self) -> usize {
2768 self.shown_predictions.len()
2769 }
2770
2771 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2772 self.rated_predictions.contains(id)
2773 }
2774
2775 pub fn rate_prediction(
2776 &mut self,
2777 prediction: &EditPrediction,
2778 rating: EditPredictionRating,
2779 feedback: String,
2780 cx: &mut Context<Self>,
2781 ) {
2782 let organization = self.user_store.read(cx).current_organization();
2783
2784 self.rated_predictions.insert(prediction.id.clone());
2785
2786 cx.background_spawn({
2787 let client = self.client.clone();
2788 let prediction_id = prediction.id.to_string();
2789 let inputs = serde_json::to_value(&prediction.inputs);
2790 let output = prediction
2791 .edit_preview
2792 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2793 async move {
2794 client
2795 .cloud_client()
2796 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2797 organization_id: organization.map(|organization| organization.id.clone()),
2798 request_id: prediction_id,
2799 rating: match rating {
2800 EditPredictionRating::Positive => "positive".to_string(),
2801 EditPredictionRating::Negative => "negative".to_string(),
2802 },
2803 inputs: inputs?,
2804 output,
2805 feedback,
2806 })
2807 .await?;
2808
2809 anyhow::Ok(())
2810 }
2811 })
2812 .detach_and_log_err(cx);
2813
2814 cx.notify();
2815 }
2816}
2817
2818fn collaborator_edit_overlaps_locality_region(
2819 project_state: &ProjectState,
2820 project: &Entity<Project>,
2821 buffer: &Entity<Buffer>,
2822 snapshot: &BufferSnapshot,
2823 edit_range: &Range<Anchor>,
2824 cx: &App,
2825) -> bool {
2826 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2827 return false;
2828 };
2829
2830 if active_buffer.entity_id() != buffer.entity_id() {
2831 return false;
2832 }
2833
2834 let locality_point_range = expand_context_syntactically_then_linewise(
2835 snapshot,
2836 (position..position).to_point(snapshot),
2837 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2838 );
2839 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2840
2841 edit_range.overlaps(&locality_anchor_range, snapshot)
2842}
2843
2844fn merge_trailing_events_if_needed(
2845 events: &mut VecDeque<StoredEvent>,
2846 end_snapshot: &TextBufferSnapshot,
2847 latest_snapshot: &TextBufferSnapshot,
2848 latest_edit_range: &Range<Anchor>,
2849) {
2850 if let Some(last_event) = events.back() {
2851 if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2852 return;
2853 }
2854 if !latest_snapshot
2855 .version
2856 .observed_all(&last_event.new_snapshot_version)
2857 {
2858 return;
2859 }
2860 }
2861
2862 let mut next_old_event = None;
2863 let mut mergeable_count = 0;
2864 for old_event in events.iter().rev() {
2865 if let Some(next_old_event) = next_old_event
2866 && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
2867 {
2868 break;
2869 }
2870 mergeable_count += 1;
2871 next_old_event = Some(old_event);
2872 }
2873
2874 if mergeable_count <= 1 {
2875 return;
2876 }
2877
2878 let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2879 let oldest_event = events_to_merge.peek().unwrap();
2880 let oldest_snapshot = oldest_event.old_snapshot.clone();
2881 let newest_snapshot = end_snapshot;
2882 let mut merged_edit_range = oldest_event.total_edit_range.clone();
2883
2884 for event in events.range(events.len() - mergeable_count + 1..) {
2885 merged_edit_range =
2886 merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
2887 }
2888
2889 if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
2890 &oldest_snapshot,
2891 newest_snapshot,
2892 &merged_edit_range,
2893 ) {
2894 let merged_event = match oldest_event.event.as_ref() {
2895 zeta_prompt::Event::BufferChange {
2896 old_path,
2897 path,
2898 in_open_source_repo,
2899 ..
2900 } => StoredEvent {
2901 event: Arc::new(zeta_prompt::Event::BufferChange {
2902 old_path: old_path.clone(),
2903 path: path.clone(),
2904 diff,
2905 in_open_source_repo: *in_open_source_repo,
2906 predicted: events_to_merge.all(|e| {
2907 matches!(
2908 e.event.as_ref(),
2909 zeta_prompt::Event::BufferChange {
2910 predicted: true,
2911 ..
2912 }
2913 )
2914 }),
2915 }),
2916 old_snapshot: oldest_snapshot.clone(),
2917 new_snapshot_version: newest_snapshot.version.clone(),
2918 total_edit_range: newest_snapshot.anchor_before(edit_range.start)
2919 ..newest_snapshot.anchor_before(edit_range.end),
2920 },
2921 };
2922 events.truncate(events.len() - mergeable_count);
2923 events.push_back(merged_event);
2924 }
2925}
2926
2927fn merge_anchor_ranges(
2928 left: &Range<Anchor>,
2929 right: &Range<Anchor>,
2930 snapshot: &TextBufferSnapshot,
2931) -> Range<Anchor> {
2932 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2933 left.start
2934 } else {
2935 right.start
2936 };
2937 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2938 left.end
2939 } else {
2940 right.end
2941 };
2942 start..end
2943}
2944
2945#[derive(Error, Debug)]
2946#[error(
2947 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2948)]
2949pub struct ZedUpdateRequiredError {
2950 minimum_version: Version,
2951}
2952
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2959
2960impl DataCollectionChoice {
2961 pub fn is_enabled(self, cx: &App) -> bool {
2962 if cx.is_staff() {
2963 return true;
2964 }
2965 match self {
2966 Self::Enabled => true,
2967 Self::NotAnswered | Self::Disabled => false,
2968 }
2969 }
2970
2971 #[must_use]
2972 pub fn toggle(&self) -> DataCollectionChoice {
2973 match self {
2974 Self::Enabled => Self::Disabled,
2975 Self::Disabled => Self::Enabled,
2976 Self::NotAnswered => Self::Enabled,
2977 }
2978 }
2979}
2980
2981impl From<bool> for DataCollectionChoice {
2982 fn from(value: bool) -> Self {
2983 match value {
2984 true => DataCollectionChoice::Enabled,
2985 false => DataCollectionChoice::Disabled,
2986 }
2987 }
2988}
2989
2990struct ZedPredictUpsell;
2991
2992impl Dismissable for ZedPredictUpsell {
2993 const KEY: &'static str = "dismissed-edit-predict-upsell";
2994
2995 fn dismissed() -> bool {
2996 // To make this backwards compatible with older versions of Zed, we
2997 // check if the user has seen the previous Edit Prediction Onboarding
2998 // before, by checking the data collection choice which was written to
2999 // the database once the user clicked on "Accept and Enable"
3000 if KEY_VALUE_STORE
3001 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3002 .log_err()
3003 .is_some_and(|s| s.is_some())
3004 {
3005 return true;
3006 }
3007
3008 KEY_VALUE_STORE
3009 .read_kvp(Self::KEY)
3010 .log_err()
3011 .is_some_and(|s| s.is_some())
3012 }
3013}
3014
3015pub fn should_show_upsell_modal() -> bool {
3016 !ZedPredictUpsell::dismissed()
3017}
3018
3019pub fn init(cx: &mut App) {
3020 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3021 workspace.register_action(
3022 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3023 ZedPredictModal::toggle(
3024 workspace,
3025 workspace.user_store().clone(),
3026 workspace.client().clone(),
3027 window,
3028 cx,
3029 )
3030 },
3031 );
3032
3033 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3034 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3035 settings
3036 .project
3037 .all_languages
3038 .edit_predictions
3039 .get_or_insert_default()
3040 .provider = Some(EditPredictionProvider::None)
3041 });
3042 });
3043 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3044 EditPredictionStore::try_global(cx).and_then(|store| {
3045 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3046 })
3047 }
3048
3049 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3050 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3051 copilot_ui::initiate_sign_in(copilot, window, cx);
3052 }
3053 });
3054 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3055 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3056 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3057 }
3058 });
3059 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3060 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3061 copilot_ui::initiate_sign_out(copilot, window, cx);
3062 }
3063 });
3064 })
3065 .detach();
3066}