1use anyhow::Result;
2use client::{Client, EditPredictionUsage, UserStore};
3use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KeyValueStore};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use heapless::Vec as ArrayVec;
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::{AppState, Workspace};
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
79use crate::example_spec::ExampleSpec;
80use crate::license_detection::LicenseDetectionWatcher;
81use crate::mercury::Mercury;
82use crate::onboarding_modal::ZedPredictModal;
83pub use crate::prediction::EditPrediction;
84pub use crate::prediction::EditPredictionId;
85use crate::prediction::EditPredictionResult;
86pub use crate::sweep_ai::SweepAi;
87pub use capture_example::capture_example;
88pub use language_model::ApiKeyState;
89pub use telemetry_events::EditPredictionRating;
90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
91
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row distance (as measured by `lines_between_ranges`) used by `StoredEvent::can_merge`:
/// events this close to the latest edit are not merged, and events farther apart than this
/// from each other are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length (in buffer offsets, presumably bytes) of the old and new context
/// regions diffed for one history event; `compute_diff_between_snapshots_in_range` returns
/// `None` when either region is larger.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not referenced in the visible part of this file; presumably a token budget for
// context around collaborator edits — confirm.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): not referenced in the visible part of this file; presumably the time window
// within which consecutive edits are grouped into one event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key under which the data-collection choice is persisted (presumably in the key-value store —
// confirm against `load_data_collection_choice`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied before sending batched rejections — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Name of the "settled" telemetry event (presumably reported by the settled-predictions worker).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a pending settled prediction is tracked — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably how long a region must go without edits to count as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
113
/// Feature flag (`edit_prediction_jumps`) gating jump-style edit predictions.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
119
/// gpui global holding the shared `EditPredictionStore` entity
/// (set by `EditPredictionStore::global`, read by `EditPredictionStore::try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
124
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds this from environment variables:
/// `ZED_ZETA_FORMAT` (required; parsed into `format`), `ZED_ZETA_MODEL` (`model_id`) and
/// `ZED_ZETA_ENVIRONMENT` (`environment`).
// NOTE(review): an earlier comment said "the version" doubles as the Baseten environment name
// (lowercased); the struct now has a separate `environment` field — confirm how it is consumed.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier to request, if configured.
    pub model_id: Option<String>,
    /// Environment name, if configured.
    pub environment: Option<String>,
    /// Prompt format used when constructing the raw prompt.
    pub format: ZetaFormat,
}
134
/// Holds per-project edit-prediction state, the selected prediction model, and the channels to
/// the background workers started in `EditPredictionStore::new`.
///
/// `EditPredictionStore::global` creates one instance on demand and registers it as a gpui global.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the experiment-fetching task from `new` alive as long as the store.
    _fetch_experiments_task: Task<()>,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    // Initialized from the environment in `new`; see `Zeta2RawConfig`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    // Experiment explicitly chosen by the user; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    // Experiment names last fetched by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sender side of the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    // Sender side of the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
156
/// A rejection queued on `reject_predictions_tx` for the background
/// `handle_rejected_predictions` worker, tagged with the organization (if any) it belongs to.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
161
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle style prediction using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
169
/// Everything a prediction backend receives for a single request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Cursor position the prediction is requested for.
    position: Anchor,
    // Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for `DebugEvent`s describing this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
186
/// Events sent over the optional `debug_tx` channels (on `ProjectState` and
/// `EditPredictionModelInput`), describing the start and end of context retrieval and of
/// prediction requests.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Free-form key/value details about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
222
/// Maximum number of user actions kept per project; `ProjectState::record_user_action` evicts
/// the oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    // Entity id of the buffer the action occurred in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // Milliseconds since the Unix epoch, as the field name indicates.
    pub timestamp_epoch_ms: u64,
}

/// Kind of action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (e.g. a buffer change with its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Range covered by the event's edits, anchored in the post-edit buffer.
    pub total_edit_range: Range<Anchor>,
}
251
252impl StoredEvent {
253 fn can_merge(
254 &self,
255 next_old_event: &StoredEvent,
256 latest_snapshot: &TextBufferSnapshot,
257 latest_edit_range: &Range<Anchor>,
258 ) -> bool {
259 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
260 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
261 return false;
262 }
263 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
264 return false;
265 }
266 if self.new_snapshot_version != next_old_event.old_snapshot.version {
267 return false;
268 }
269 if !latest_snapshot
270 .version
271 .observed_all(&next_old_event.new_snapshot_version)
272 {
273 return false;
274 }
275
276 let a_is_predicted = matches!(
277 self.event.as_ref(),
278 zeta_prompt::Event::BufferChange {
279 predicted: true,
280 ..
281 }
282 );
283 let b_is_predicted = matches!(
284 next_old_event.event.as_ref(),
285 zeta_prompt::Event::BufferChange {
286 predicted: true,
287 ..
288 }
289 );
290
291 // If events come from the same source (both predicted or both manual) then
292 // we would have coalesced them already.
293 if a_is_predicted == b_is_predicted {
294 return false;
295 }
296
297 let left_range = self.total_edit_range.to_point(latest_snapshot);
298 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
299 let latest_range = latest_edit_range.to_point(latest_snapshot);
300
301 // Events near to the latest edit are not merged if their sources differ.
302 if lines_between_ranges(&left_range, &latest_range)
303 .min(lines_between_ranges(&right_range, &latest_range))
304 <= CHANGE_GROUPING_LINE_SPAN
305 {
306 return false;
307 }
308
309 // Events that are distant from each other are not merged.
310 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
311 return false;
312 }
313
314 true
315 }
316}
317
318fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
319 if left.start > right.end {
320 return left.start.row - right.end.row;
321 }
322 if right.start > left.end {
323 return right.start.row - left.end.row;
324 }
325 0
326}
327
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    // Finalized edit history events.
    events: VecDeque<StoredEvent>,
    // In-progress event; `ProjectState::events` finalizes it on demand without storing it.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered with the store, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // In-flight prediction requests; the array type holds at most 2.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // Per-worktree license watchers consulted when deciding `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded to `USER_ACTION_HISTORY_SIZE` entries by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    // Created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
346
347impl ProjectState {
348 fn record_user_action(&mut self, action: UserActionRecord) {
349 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
350 self.user_actions.pop_front();
351 }
352 self.user_actions.push_back(action);
353 }
354
355 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
356 self.events
357 .iter()
358 .cloned()
359 .chain(self.last_event.as_ref().iter().flat_map(|event| {
360 let (one, two) = event.split_by_pause();
361 let one = one.finalize(&self.license_detection_watchers, cx);
362 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
363 one.into_iter().chain(two)
364 }))
365 .collect()
366 }
367
368 fn cancel_pending_prediction(
369 &mut self,
370 pending_prediction: PendingPrediction,
371 cx: &mut Context<EditPredictionStore>,
372 ) {
373 self.cancelled_predictions.insert(pending_prediction.id);
374
375 if pending_prediction.drop_on_cancel {
376 drop(pending_prediction.task);
377 } else {
378 cx.spawn(async move |this, cx| {
379 let Some(prediction_id) = pending_prediction.task.await else {
380 return;
381 };
382
383 this.update(cx, |this, cx| {
384 this.reject_prediction(
385 prediction_id,
386 EditPredictionRejectReason::Canceled,
387 false,
388 None,
389 None,
390 cx,
391 );
392 })
393 .ok();
394 })
395 .detach()
396 }
397 }
398
399 fn active_buffer(
400 &self,
401 project: &Entity<Project>,
402 cx: &App,
403 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
404 let project = project.read(cx);
405 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
406 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
407 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
408 Some((active_buffer, registered_buffer.last_position))
409 }
410}
411
/// The prediction currently offered for a project, plus how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    // End-to-end latency recorded for this prediction.
    pub e2e_latency: std::time::Duration,
}
420
421impl CurrentEditPrediction {
422 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
423 let Some(new_edits) = self
424 .prediction
425 .interpolate(&self.prediction.buffer.read(cx))
426 else {
427 return false;
428 };
429
430 if self.prediction.buffer != old_prediction.prediction.buffer {
431 return true;
432 }
433
434 let Some(old_edits) = old_prediction
435 .prediction
436 .interpolate(&old_prediction.prediction.buffer.read(cx))
437 else {
438 return true;
439 };
440
441 let requested_by_buffer_id = self.requested_by.buffer_id();
442
443 // This reduces the occurrence of UI thrash from replacing edits
444 //
445 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
446 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
447 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
448 && old_edits.len() == 1
449 && new_edits.len() == 1
450 {
451 let (old_range, old_text) = &old_edits[0];
452 let (new_range, new_text) = &new_edits[0];
453 new_range == old_range && new_text.starts_with(old_text.as_ref())
454 } else {
455 true
456 }
457 }
458}
459
460#[derive(Debug, Clone)]
461enum PredictionRequestedBy {
462 DiagnosticsUpdate,
463 Buffer(EntityId),
464}
465
466impl PredictionRequestedBy {
467 pub fn buffer_id(&self) -> Option<EntityId> {
468 match self {
469 PredictionRequestedBy::DiagnosticsUpdate => None,
470 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
471 }
472 }
473}
474
// NOTE(review): not used in the visible part of this file; presumably the number of lines
// around the cursor searched for diagnostics — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search (local vs. global; exact meaning defined by its users).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
482
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    // Resolves to the prediction id, if the request produced a prediction.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
491
492/// A prediction from the perspective of a buffer.
493#[derive(Debug)]
494enum BufferEditPrediction<'a> {
495 Local { prediction: &'a EditPrediction },
496 Jump { prediction: &'a EditPrediction },
497}
498
499#[cfg(test)]
500impl std::ops::Deref for BufferEditPrediction<'_> {
501 type Target = EditPrediction;
502
503 fn deref(&self) -> &Self::Target {
504 match self {
505 BufferEditPrediction::Local { prediction } => prediction,
506 BufferEditPrediction::Jump { prediction } => prediction,
507 }
508 }
509}
510
/// A shown prediction awaiting a "settled" report for its editable range
/// (see the `EDIT_PREDICTION_SETTLED_*` constants; exact policy lives in the settled worker).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    // Region of the buffer the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    // NOTE(review): presumably when this entry was queued and when its range was last edited — confirm.
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
521
/// Per-buffer state for a buffer registered with a project's `ProjectState`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    // Last known position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
529
/// An in-progress edit event spanning `old_snapshot` → `new_snapshot`, not yet turned into a
/// `StoredEvent` (see `LastEvent::finalize`).
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    // Range covering all edits in this event; used to compute the diff on finalize.
    total_edit_range: Range<Anchor>,
    // Total edit range as of the last editing pause; becomes the "before" range in `split_by_pause`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    predicted: bool,
    // Snapshot taken at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
543
544impl LastEvent {
545 pub fn finalize(
546 &self,
547 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
548 cx: &App,
549 ) -> Option<StoredEvent> {
550 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
551 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
552
553 let in_open_source_repo =
554 [self.new_file.as_ref(), self.old_file.as_ref()]
555 .iter()
556 .all(|file| {
557 file.is_some_and(|file| {
558 license_detection_watchers
559 .get(&file.worktree_id(cx))
560 .is_some_and(|watcher| watcher.is_project_open_source())
561 })
562 });
563
564 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
565 &self.old_snapshot,
566 &self.new_snapshot,
567 &self.total_edit_range,
568 )?;
569
570 if path == old_path && diff.is_empty() {
571 None
572 } else {
573 Some(StoredEvent {
574 event: Arc::new(zeta_prompt::Event::BufferChange {
575 old_path,
576 path,
577 diff,
578 in_open_source_repo,
579 predicted: self.predicted,
580 }),
581 old_snapshot: self.old_snapshot.clone(),
582 new_snapshot_version: self.new_snapshot.version.clone(),
583 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
584 ..self.new_snapshot.anchor_before(edit_range.end),
585 })
586 }
587 }
588
589 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
590 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
591 return (self.clone(), None);
592 };
593
594 let total_edit_range_before_pause = self
595 .total_edit_range_at_last_pause_boundary
596 .clone()
597 .unwrap_or_else(|| self.total_edit_range.clone());
598
599 let Some(total_edit_range_after_pause) =
600 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
601 else {
602 return (self.clone(), None);
603 };
604
605 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
606 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
607
608 let before = LastEvent {
609 old_snapshot: self.old_snapshot.clone(),
610 new_snapshot: boundary_snapshot.clone(),
611 old_file: self.old_file.clone(),
612 new_file: self.new_file.clone(),
613 latest_edit_range: latest_edit_range_before_pause,
614 total_edit_range: total_edit_range_before_pause,
615 total_edit_range_at_last_pause_boundary: None,
616 predicted: self.predicted,
617 snapshot_after_last_editing_pause: None,
618 last_edit_time: self.last_edit_time,
619 };
620
621 let after = LastEvent {
622 old_snapshot: boundary_snapshot.clone(),
623 new_snapshot: self.new_snapshot.clone(),
624 old_file: self.old_file.clone(),
625 new_file: self.new_file.clone(),
626 latest_edit_range: latest_edit_range_after_pause,
627 total_edit_range: total_edit_range_after_pause,
628 total_edit_range_at_last_pause_boundary: None,
629 predicted: self.predicted,
630 snapshot_after_last_editing_pause: None,
631 last_edit_time: self.last_edit_time,
632 };
633
634 (before, Some(after))
635 }
636}
637
638fn compute_total_edit_range_between_snapshots(
639 old_snapshot: &TextBufferSnapshot,
640 new_snapshot: &TextBufferSnapshot,
641) -> Option<Range<Anchor>> {
642 let edits: Vec<Edit<usize>> = new_snapshot
643 .edits_since::<usize>(&old_snapshot.version)
644 .collect();
645
646 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
647 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
648 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
649
650 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
651}
652
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to a point range in `old_snapshot`,
/// using the edits made between the two snapshot versions.
///
/// An endpoint inside an edit's new range (including its end boundary) is widened to that
/// edit's old start/end. An endpoint in unedited text is shifted back by the net length change
/// of the edits before it. Returns `None` if such a shift underflows inside the loop; endpoints
/// past the last edit use a saturating shift.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new length - old length) of the edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // First edit whose new range ends at or after the start offset determines the mapping.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Offset lies in unedited text before this edit: undo earlier size changes.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Offset lies within this edit: widen to the edit's old start.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Offsets after every edit are shifted by the total size change.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
698
/// Produces a unified diff, between `old_snapshot` and `new_snapshot`, of the rows touched by
/// `total_edit_range` (given in `new_snapshot`), padded with context rows.
///
/// Returns the diff text and `total_edit_range` resolved to points in `new_snapshot`. Returns
/// `None` when the range cannot be mapped back to `old_snapshot`, or when either padded region
/// is longer than `EDIT_HISTORY_DIFF_SIZE_LIMIT`.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // NOTE(review): with the extra `+ 1` here and again below, the region ends at the start of
    // row `end_row + CONTEXT_LINES + 2`, i.e. CONTEXT_LINES + 1 rows after the edit versus
    // CONTEXT_LINES rows before it — confirm this asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    // Clamp to the buffer end when the row after the context end does not exist.
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Oversized regions produce no diff at all (the caller then drops the event).
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Row offsets let the diff hunks report buffer-relative line numbers.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
746
747fn buffer_path_with_id_fallback(
748 file: Option<&Arc<dyn File>>,
749 snapshot: &TextBufferSnapshot,
750 cx: &App,
751) -> Arc<Path> {
752 if let Some(file) = file {
753 file.full_path(cx).into()
754 } else {
755 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
756 }
757}
758
759impl EditPredictionStore {
760 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
761 cx.try_global::<EditPredictionStoreGlobal>()
762 .map(|global| global.0.clone())
763 }
764
765 pub fn global(
766 client: &Arc<Client>,
767 user_store: &Entity<UserStore>,
768 cx: &mut App,
769 ) -> Entity<Self> {
770 cx.try_global::<EditPredictionStoreGlobal>()
771 .map(|global| global.0.clone())
772 .unwrap_or_else(|| {
773 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
774 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
775 ep_store
776 })
777 }
778
779 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
780 let data_collection_choice = Self::load_data_collection_choice(cx);
781
782 let llm_token = LlmApiToken::global(cx);
783
784 let (reject_tx, reject_rx) = mpsc::unbounded();
785 cx.background_spawn({
786 let client = client.clone();
787 let llm_token = llm_token.clone();
788 let app_version = AppVersion::global(cx);
789 let background_executor = cx.background_executor().clone();
790 async move {
791 Self::handle_rejected_predictions(
792 reject_rx,
793 client,
794 llm_token,
795 app_version,
796 background_executor,
797 )
798 .await
799 }
800 })
801 .detach();
802
803 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
804 cx.spawn(async move |this, cx| {
805 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
806 })
807 .detach();
808
809 let mut current_user = user_store.read(cx).watch_current_user();
810 let fetch_experiments_task = cx.spawn(async move |this, cx| {
811 while current_user.borrow().is_none() {
812 current_user.next().await;
813 }
814 this.update(cx, |this, cx| {
815 this.refresh_available_experiments(cx);
816 })
817 .log_err();
818 });
819
820 let this = Self {
821 projects: HashMap::default(),
822 client,
823 user_store,
824 llm_token,
825 _fetch_experiments_task: fetch_experiments_task,
826 update_required: false,
827 edit_prediction_model: EditPredictionModel::Zeta,
828 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
829 preferred_experiment: None,
830 available_experiments: Vec::new(),
831 sweep_ai: SweepAi::new(cx),
832 mercury: Mercury::new(cx),
833
834 data_collection_choice,
835 reject_predictions_tx: reject_tx,
836 settled_predictions_tx,
837 rated_predictions: Default::default(),
838 shown_predictions: Default::default(),
839 #[cfg(test)]
840 settled_event_callback: None,
841 };
842
843 this
844 }
845
846 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
847 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
848 let format = ZetaFormat::parse(&version_str).ok()?;
849 let model_id = env::var("ZED_ZETA_MODEL").ok();
850 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
851 Some(Zeta2RawConfig {
852 model_id,
853 environment,
854 format,
855 })
856 }
857
    /// Sets the model used for edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Sets the raw Zeta2 configuration, replacing any environment-derived one.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// Returns the experiment explicitly chosen by the user, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    /// Sets (or clears, with `None`) the user's preferred experiment.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Returns the experiment names last fetched by `refresh_available_experiments`.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
881
882 pub fn active_experiment(&self) -> Option<&str> {
883 self.preferred_experiment.as_deref().or_else(|| {
884 self.shown_predictions
885 .iter()
886 .find_map(|p| p.model_version.as_ref())
887 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
888 })
889 }
890
891 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
892 let client = self.client.clone();
893 let llm_token = self.llm_token.clone();
894 let app_version = AppVersion::global(cx);
895 let organization_id = self
896 .user_store
897 .read(cx)
898 .current_organization()
899 .map(|organization| organization.id.clone());
900
901 cx.spawn(async move |this, cx| {
902 let experiments = cx
903 .background_spawn(async move {
904 let http_client = client.http_client();
905 let token = llm_token.acquire(&client, organization_id).await?;
906 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
907 let request = http_client::Request::builder()
908 .method(Method::GET)
909 .uri(url.as_ref())
910 .header("Authorization", format!("Bearer {}", token))
911 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
912 .body(Default::default())?;
913 let mut response = http_client.send(request).await?;
914 if response.status().is_success() {
915 let mut body = Vec::new();
916 response.body_mut().read_to_end(&mut body).await?;
917 let experiments: Vec<String> = serde_json::from_slice(&body)?;
918 Ok(experiments)
919 } else {
920 let mut body = String::new();
921 response.body_mut().read_to_string(&mut body).await?;
922 anyhow::bail!(
923 "Failed to fetch experiments: {:?}\nBody: {}",
924 response.status(),
925 body
926 );
927 }
928 })
929 .await?;
930 this.update(cx, |this, cx| {
931 this.available_experiments = experiments;
932 cx.notify();
933 })?;
934 anyhow::Ok(())
935 })
936 .detach_and_log_err(cx);
937 }
938
939 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
940 use ui::IconName;
941 match self.edit_prediction_model {
942 EditPredictionModel::Sweep => {
943 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
944 .with_disabled(IconName::SweepAiDisabled)
945 .with_up(IconName::SweepAiUp)
946 .with_down(IconName::SweepAiDown)
947 .with_error(IconName::SweepAiError)
948 }
949 EditPredictionModel::Mercury => {
950 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
951 }
952 EditPredictionModel::Zeta => {
953 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
954 .with_disabled(IconName::ZedPredictDisabled)
955 .with_up(IconName::ZedPredictUp)
956 .with_down(IconName::ZedPredictDown)
957 .with_error(IconName::ZedPredictError)
958 }
959 EditPredictionModel::Fim { .. } => {
960 let settings = &all_language_settings(None, cx).edit_predictions;
961 match settings.provider {
962 EditPredictionProvider::Ollama => {
963 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
964 }
965 _ => {
966 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
967 }
968 }
969 }
970 }
971 }
972
    /// Whether the Sweep API key state reports a key.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Whether the Mercury API key state reports a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Delegates to `Mercury::has_payment_required_error`.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
984
985 pub fn clear_history(&mut self) {
986 for project_state in self.projects.values_mut() {
987 project_state.events.clear();
988 project_state.last_event.take();
989 }
990 }
991
992 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
993 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
994 project_state.events.clear();
995 project_state.last_event.take();
996 }
997 }
998
999 pub fn edit_history_for_project(
1000 &self,
1001 project: &Entity<Project>,
1002 cx: &App,
1003 ) -> Vec<StoredEvent> {
1004 self.projects
1005 .get(&project.entity_id())
1006 .map(|project_state| project_state.events(cx))
1007 .unwrap_or_default()
1008 }
1009
1010 pub fn context_for_project<'a>(
1011 &'a self,
1012 project: &Entity<Project>,
1013 cx: &'a mut App,
1014 ) -> Vec<RelatedFile> {
1015 self.projects
1016 .get(&project.entity_id())
1017 .map(|project_state| {
1018 project_state.context.update(cx, |context, cx| {
1019 context
1020 .related_files_with_buffers(cx)
1021 .map(|(mut related_file, buffer)| {
1022 related_file.in_open_source_repo = buffer
1023 .read(cx)
1024 .file()
1025 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1026 related_file
1027 })
1028 .collect()
1029 })
1030 })
1031 .unwrap_or_default()
1032 }
1033
1034 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1035 self.projects
1036 .get(&project.entity_id())
1037 .and_then(|project| project.copilot.clone())
1038 }
1039
1040 pub fn start_copilot_for_project(
1041 &mut self,
1042 project: &Entity<Project>,
1043 cx: &mut Context<Self>,
1044 ) -> Option<Entity<Copilot>> {
1045 if DisableAiSettings::get(None, cx).disable_ai {
1046 return None;
1047 }
1048 let state = self.get_or_init_project(project, cx);
1049
1050 if state.copilot.is_some() {
1051 return state.copilot.clone();
1052 }
1053 let _project = project.clone();
1054 let project = project.read(cx);
1055
1056 let node = project.node_runtime().cloned();
1057 if let Some(node) = node {
1058 let next_id = project.languages().next_language_server_id();
1059 let fs = project.fs().clone();
1060
1061 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1062 state.copilot = Some(copilot.clone());
1063 Some(copilot)
1064 } else {
1065 None
1066 }
1067 }
1068
1069 pub fn context_for_project_with_buffers<'a>(
1070 &'a self,
1071 project: &Entity<Project>,
1072 cx: &'a mut App,
1073 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1074 self.projects
1075 .get(&project.entity_id())
1076 .map(|project| {
1077 project.context.update(cx, |context, cx| {
1078 context.related_files_with_buffers(cx).collect()
1079 })
1080 })
1081 .unwrap_or_default()
1082 }
1083
1084 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1085 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1086 self.user_store.read(cx).edit_prediction_usage()
1087 } else {
1088 None
1089 }
1090 }
1091
1092 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1093 self.get_or_init_project(project, cx);
1094 }
1095
1096 pub fn register_buffer(
1097 &mut self,
1098 buffer: &Entity<Buffer>,
1099 project: &Entity<Project>,
1100 cx: &mut Context<Self>,
1101 ) {
1102 let project_state = self.get_or_init_project(project, cx);
1103 Self::register_buffer_impl(project_state, buffer, project, cx);
1104 }
1105
    /// Returns the per-project state for `project`, creating it on first use.
    ///
    /// When the state is created, this also:
    /// - creates a [`RelatedExcerptStore`] for the project and forwards its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to project events via `handle_project_event`;
    /// - removes the state again (and notifies observers) when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's entry from the store once the project is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1146
1147 pub fn remove_project(&mut self, project: &Entity<Project>) {
1148 self.projects.remove(&project.entity_id());
1149 }
1150
1151 fn handle_excerpt_store_event(
1152 &mut self,
1153 project_entity_id: EntityId,
1154 event: &RelatedExcerptStoreEvent,
1155 ) {
1156 if let Some(project_state) = self.projects.get(&project_entity_id) {
1157 if let Some(debug_tx) = project_state.debug_tx.clone() {
1158 match event {
1159 RelatedExcerptStoreEvent::StartedRefresh => {
1160 debug_tx
1161 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1162 ContextRetrievalStartedDebugEvent {
1163 project_entity_id: project_entity_id,
1164 timestamp: Instant::now(),
1165 search_prompt: String::new(),
1166 },
1167 ))
1168 .ok();
1169 }
1170 RelatedExcerptStoreEvent::FinishedRefresh {
1171 cache_hit_count,
1172 cache_miss_count,
1173 mean_definition_latency,
1174 max_definition_latency,
1175 } => {
1176 debug_tx
1177 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1178 ContextRetrievalFinishedDebugEvent {
1179 project_entity_id: project_entity_id,
1180 timestamp: Instant::now(),
1181 metadata: vec![
1182 (
1183 "Cache Hits",
1184 format!(
1185 "{}/{}",
1186 cache_hit_count,
1187 cache_hit_count + cache_miss_count
1188 )
1189 .into(),
1190 ),
1191 (
1192 "Max LSP Time",
1193 format!("{} ms", max_definition_latency.as_millis())
1194 .into(),
1195 ),
1196 (
1197 "Mean LSP Time",
1198 format!("{} ms", mean_definition_latency.as_millis())
1199 .into(),
1200 ),
1201 ],
1202 },
1203 ))
1204 .ok();
1205 }
1206 }
1207 }
1208 }
1209 }
1210
1211 pub fn debug_info(
1212 &mut self,
1213 project: &Entity<Project>,
1214 cx: &mut Context<Self>,
1215 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1216 let project_state = self.get_or_init_project(project, cx);
1217 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1218 project_state.debug_tx = Some(debug_watch_tx);
1219 debug_watch_rx
1220 }
1221
1222 fn handle_project_event(
1223 &mut self,
1224 project: Entity<Project>,
1225 event: &project::Event,
1226 cx: &mut Context<Self>,
1227 ) {
1228 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1229 return;
1230 }
1231 // TODO [zeta2] init with recent paths
1232 match event {
1233 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1234 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1235 return;
1236 };
1237 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1238 if let Some(path) = path {
1239 if let Some(ix) = project_state
1240 .recent_paths
1241 .iter()
1242 .position(|probe| probe == &path)
1243 {
1244 project_state.recent_paths.remove(ix);
1245 }
1246 project_state.recent_paths.push_front(path);
1247 }
1248 }
1249 project::Event::DiagnosticsUpdated { .. } => {
1250 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1251 self.refresh_prediction_from_diagnostics(
1252 project,
1253 DiagnosticSearchScope::Global,
1254 cx,
1255 );
1256 }
1257 }
1258 _ => (),
1259 }
1260 }
1261
    /// Registers `buffer` in `project_state` if needed and returns its [`RegisteredBuffer`].
    ///
    /// License-detection watchers:
    /// - If the buffer has a file in one of the project's worktrees, a watcher is created for
    ///   that worktree the first time the worktree is seen.
    /// - The watcher is removed when the worktree is released.
    ///
    /// A newly registered buffer:
    /// - stores its current text snapshot and file;
    /// - subscribes to `BufferEvent::Edited`, forwarded to `report_changes_for_buffer`;
    /// - removes itself from `registered_buffers` when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget this worktree's watcher once the worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture a weak handle: edits are only reported while the
                            // project can still be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1330
    /// Records the latest change to `buffer` into the project's edit history and user-action log.
    ///
    /// The buffer's current text snapshot is diffed against the one stored in its
    /// [`RegisteredBuffer`]; the buffer is registered first if needed. Then:
    /// - Pending settled predictions whose editable range overlaps the edits get their
    ///   `last_edit_at` bumped.
    /// - Local edits append a [`UserActionRecord`] classifying the edit.
    /// - Edits that belong in history (local edits, or collaborator edits near the user's
    ///   locality region) are coalesced into `last_event`. If they cannot be coalesced, the
    ///   previous event is finalized into `events` and a new `last_event` is started.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction. A change in this flag
    /// relative to `last_event` always starts a new event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last snapshot stored for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Tally edit sizes and build one range spanning from the first edit's start to the
        // last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Editing inside a pending prediction's editable region restarts its quiet period.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // "Char" when the total inserted (or deleted) length equals the number of edits,
            // otherwise "Selection".
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // Unrecordable edit: close out any in-progress event but don't start a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit, remember
                // the state as of that pause before folding this edit in.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (popping the oldest so `events` stays
        // below EVENT_COUNT_MAX) and start a new event for this edit.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1497
1498 fn prediction_at(
1499 &mut self,
1500 buffer: &Entity<Buffer>,
1501 position: Option<language::Anchor>,
1502 project: &Entity<Project>,
1503 cx: &App,
1504 ) -> Option<BufferEditPrediction<'_>> {
1505 let project_state = self.projects.get_mut(&project.entity_id())?;
1506 if let Some(position) = position
1507 && let Some(buffer) = project_state
1508 .registered_buffers
1509 .get_mut(&buffer.entity_id())
1510 {
1511 buffer.last_position = Some(position);
1512 }
1513
1514 let CurrentEditPrediction {
1515 requested_by,
1516 prediction,
1517 ..
1518 } = project_state.current_prediction.as_ref()?;
1519
1520 if prediction.targets_buffer(buffer.read(cx)) {
1521 Some(BufferEditPrediction::Local { prediction })
1522 } else {
1523 let show_jump = match requested_by {
1524 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1525 requested_by_buffer_id == &buffer.entity_id()
1526 }
1527 PredictionRequestedBy::DiagnosticsUpdate => true,
1528 };
1529
1530 if show_jump {
1531 Some(BufferEditPrediction::Jump { prediction })
1532 } else {
1533 None
1534 }
1535 }
1536 }
1537
1538 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1539 let Some(current_prediction) = self
1540 .projects
1541 .get_mut(&project.entity_id())
1542 .and_then(|project_state| project_state.current_prediction.take())
1543 else {
1544 return;
1545 };
1546
1547 self.report_changes_for_buffer(
1548 ¤t_prediction.prediction.buffer,
1549 project,
1550 true,
1551 true,
1552 cx,
1553 );
1554
1555 // can't hold &mut project_state ref across report_changes_for_buffer_call
1556 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1557 return;
1558 };
1559
1560 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1561 project_state.cancel_pending_prediction(pending_prediction, cx);
1562 }
1563
1564 match self.edit_prediction_model {
1565 EditPredictionModel::Sweep => {
1566 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1567 }
1568 EditPredictionModel::Mercury => {
1569 mercury::edit_prediction_accepted(
1570 current_prediction.prediction.id,
1571 self.client.http_client(),
1572 cx,
1573 );
1574 }
1575 EditPredictionModel::Zeta => {
1576 let is_cloud = !matches!(
1577 all_language_settings(None, cx).edit_predictions.provider,
1578 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1579 );
1580 if is_cloud {
1581 zeta::edit_prediction_accepted(self, current_prediction, cx)
1582 }
1583 }
1584 EditPredictionModel::Fim { .. } => {}
1585 }
1586 }
1587
    /// Background loop that batches rejected predictions and uploads them to
    /// `/predict_edits/reject`. It runs until `rx` closes.
    ///
    /// Batching:
    /// - While fewer than `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2` rejections are
    ///   batched, it waits up to `REJECT_REQUEST_DEBOUNCE` for another rejection; any that
    ///   arrives is added to the batch first.
    /// - Each request sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    ///   recent rejections.
    /// - Sent rejections leave the batch only on success, so failed ones are retried on a
    ///   later flush.
    /// - The whole request uses the `organization_id` of the rejection that triggered the
    ///   flush.
    ///
    /// NOTE(review): the batch is never trimmed on repeated failures, so it can grow for as
    /// long as the endpoint keeps failing. Also, the URL `unwrap` panics this task if
    /// `build_zed_llm_url` fails. Confirm both are acceptable.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: give more rejections a chance to arrive before sending.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only remove the sent tail on success; on failure it stays queued for retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1651
    /// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry once pending
    /// predictions have settled.
    ///
    /// `rx` carries the enqueue time of each newly tracked prediction. When idle, the worker
    /// waits for one, drains any other queued times, and schedules a wake-up
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` later.
    ///
    /// On wake-up it scans every registered buffer's `pending_predictions`:
    /// - entries at least `EDIT_PREDICTION_SETTLED_TTL` old (since enqueue) are dropped
    ///   without reporting;
    /// - entries not edited for at least the quiescence period are reported, together with the
    ///   text of their editable range in the buffer's stored snapshot, and then dropped;
    /// - the rest are kept, and the next wake-up is derived from the oldest `last_edit_at`.
    ///
    /// Exits when `rx` is closed while idle, or when the store entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: block for the next enqueue, then discard any other queued
                // notifications since the scan below covers all pending entries anyway.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled text and drop.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
                                    );

                                    return false;
                                }

                                // Still active: track the earliest edit time to schedule the
                                // next wake-up.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1737
1738 pub(crate) fn enqueue_settled_prediction(
1739 &mut self,
1740 request_id: EditPredictionId,
1741 project: &Entity<Project>,
1742 edited_buffer: &Entity<Buffer>,
1743 edited_buffer_snapshot: &BufferSnapshot,
1744 editable_offset_range: Range<usize>,
1745 example: Option<ExampleSpec>,
1746 e2e_latency: std::time::Duration,
1747 cx: &mut Context<Self>,
1748 ) {
1749 let this = &mut *self;
1750 let project_state = this.get_or_init_project(project, cx);
1751 if let Some(buffer) = project_state
1752 .registered_buffers
1753 .get_mut(&edited_buffer.entity_id())
1754 {
1755 let now = cx.background_executor().now();
1756 buffer.pending_predictions.push(PendingSettledPrediction {
1757 request_id: request_id,
1758 editable_anchor_range: edited_buffer_snapshot
1759 .anchor_range_around(editable_offset_range),
1760 example,
1761 e2e_latency,
1762 enqueued_at: now,
1763 last_edit_at: now,
1764 });
1765 this.settled_predictions_tx.unbounded_send(now).ok();
1766 }
1767 }
1768
1769 fn reject_current_prediction(
1770 &mut self,
1771 reason: EditPredictionRejectReason,
1772 project: &Entity<Project>,
1773 cx: &App,
1774 ) {
1775 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1776 project_state.pending_predictions.clear();
1777 if let Some(prediction) = project_state.current_prediction.take() {
1778 let model_version = prediction.prediction.model_version.clone();
1779 self.reject_prediction(
1780 prediction.prediction.id,
1781 reason,
1782 prediction.was_shown,
1783 model_version,
1784 Some(prediction.e2e_latency),
1785 cx,
1786 );
1787 }
1788 };
1789 }
1790
1791 fn did_show_current_prediction(
1792 &mut self,
1793 project: &Entity<Project>,
1794 display_type: edit_prediction_types::SuggestionDisplayType,
1795 cx: &mut Context<Self>,
1796 ) {
1797 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1798 return;
1799 };
1800
1801 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1802 return;
1803 };
1804
1805 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1806 let previous_shown_with = current_prediction.shown_with;
1807
1808 if previous_shown_with.is_none() || !is_jump {
1809 current_prediction.shown_with = Some(display_type);
1810 }
1811
1812 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1813
1814 if is_first_non_jump_show {
1815 current_prediction.was_shown = true;
1816 }
1817
1818 let display_type_changed = previous_shown_with != Some(display_type);
1819
1820 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1821 sweep_ai::edit_prediction_shown(
1822 &self.sweep_ai,
1823 self.client.clone(),
1824 ¤t_prediction.prediction,
1825 display_type,
1826 cx,
1827 );
1828 }
1829
1830 if is_first_non_jump_show {
1831 self.shown_predictions
1832 .push_front(current_prediction.prediction.clone());
1833 if self.shown_predictions.len() > 50 {
1834 let completion = self.shown_predictions.pop_back().unwrap();
1835 self.rated_predictions.remove(&completion.id);
1836 }
1837 }
1838 }
1839
1840 fn reject_prediction(
1841 &mut self,
1842 prediction_id: EditPredictionId,
1843 reason: EditPredictionRejectReason,
1844 was_shown: bool,
1845 model_version: Option<String>,
1846 e2e_latency: Option<std::time::Duration>,
1847 cx: &App,
1848 ) {
1849 match self.edit_prediction_model {
1850 EditPredictionModel::Zeta => {
1851 let is_cloud = !matches!(
1852 all_language_settings(None, cx).edit_predictions.provider,
1853 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1854 );
1855
1856 if is_cloud {
1857 let organization_id = self
1858 .user_store
1859 .read(cx)
1860 .current_organization()
1861 .map(|organization| organization.id.clone());
1862
1863 self.reject_predictions_tx
1864 .unbounded_send(EditPredictionRejectionPayload {
1865 rejection: EditPredictionRejection {
1866 request_id: prediction_id.to_string(),
1867 reason,
1868 was_shown,
1869 model_version,
1870 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1871 },
1872 organization_id,
1873 })
1874 .log_err();
1875 }
1876 }
1877 EditPredictionModel::Mercury => {
1878 mercury::edit_prediction_rejected(
1879 prediction_id,
1880 was_shown,
1881 reason,
1882 self.client.http_client(),
1883 cx,
1884 );
1885 }
1886 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1887 }
1888 }
1889
1890 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1891 self.projects
1892 .get(&project.entity_id())
1893 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1894 }
1895
    /// Queues a prediction refresh for `buffer` at `position`.
    ///
    /// The refresh is throttled by the buffer's entity id and uses the `Other` request trigger.
    /// A resulting prediction is tagged `PredictionRequestedBy::Buffer` with this buffer's id;
    /// `prediction_at` uses that tag to decide whether to offer a jump. If the store entity is
    /// gone when the refresh runs, the error is logged and no prediction is produced.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1937
    /// Queues a prediction refresh at the next diagnostic location.
    ///
    /// Nothing is queued when:
    /// - the configured provider isn't served by this store;
    /// - `currently_following` reports that the user is following someone in this project;
    /// - the project isn't tracked;
    /// - a current prediction already exists.
    ///
    /// Otherwise the refresh (throttled by the project's entity id):
    /// 1. reads the active buffer and cursor, bailing if predictions are disabled there;
    /// 2. searches for the next diagnostic — within `DIAGNOSTIC_LINES_RANGE` rows of the cursor
    ///    for `Local`; for `Global` an empty default range is passed (presumably meaning
    ///    unrestricted — confirm in `next_diagnostic_location`);
    /// 3. requests a prediction at that location.
    ///
    /// If a current prediction appeared in the meantime, the new result is converted into a
    /// `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Capture the active buffer, its snapshot and the cursor row/column.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // If a current prediction appeared while this request was in flight, this
                    // result becomes a `CurrentPreferred` rejection instead of a prediction.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2054
2055 fn predictions_enabled_at(
2056 snapshot: &BufferSnapshot,
2057 position: Option<language::Anchor>,
2058 cx: &App,
2059 ) -> bool {
2060 let file = snapshot.file();
2061 let all_settings = all_language_settings(file, cx);
2062 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2063 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2064 {
2065 return false;
2066 }
2067
2068 if let Some(last_position) = position {
2069 let settings = snapshot.settings_at(last_position, cx);
2070
2071 if !settings.edit_predictions_disabled_in.is_empty()
2072 && let Some(scope) = snapshot.language_scope_at(last_position)
2073 && let Some(scope_name) = scope.override_name()
2074 && settings
2075 .edit_predictions_disabled_in
2076 .iter()
2077 .any(|s| s == scope_name)
2078 {
2079 return false;
2080 }
2081 }
2082
2083 true
2084 }
2085
    /// Throttle window for prediction refreshes.
    ///
    /// A refresh queued for the same throttle entity as the previous refresh waits until this
    /// long after it. The throttle entity is the buffer for buffer-triggered refreshes and the
    /// project for diagnostics-triggered ones; see `queue_prediction_refresh`.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2087}
2088
2089fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2090 let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2091 return false;
2092 };
2093
2094 app_state
2095 .workspace_store
2096 .read(cx)
2097 .workspaces()
2098 .filter_map(|workspace| workspace.upgrade())
2099 .any(|workspace| {
2100 workspace.read(cx).project().entity_id() == project.entity_id()
2101 && workspace
2102 .read(cx)
2103 .leader_for_pane(workspace.read(cx).active_pane())
2104 .is_some()
2105 })
2106}
2107
2108fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2109 match provider {
2110 EditPredictionProvider::Zed
2111 | EditPredictionProvider::Sweep
2112 | EditPredictionProvider::Mercury
2113 | EditPredictionProvider::Ollama
2114 | EditPredictionProvider::OpenAiCompatibleApi
2115 | EditPredictionProvider::Experimental(_) => true,
2116 EditPredictionProvider::None
2117 | EditPredictionProvider::Copilot
2118 | EditPredictionProvider::Codestral => false,
2119 }
2120}
2121
2122impl EditPredictionStore {
2123 fn queue_prediction_refresh(
2124 &mut self,
2125 project: Entity<Project>,
2126 request_trigger: PredictEditsRequestTrigger,
2127 throttle_entity: EntityId,
2128 cx: &mut Context<Self>,
2129 do_refresh: impl FnOnce(
2130 WeakEntity<Self>,
2131 &mut AsyncApp,
2132 )
2133 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2134 + 'static,
2135 ) {
2136 fn select_throttle(
2137 project_state: &mut ProjectState,
2138 request_trigger: PredictEditsRequestTrigger,
2139 ) -> &mut Option<(EntityId, Instant)> {
2140 match request_trigger {
2141 PredictEditsRequestTrigger::Diagnostics => {
2142 &mut project_state.last_jump_prediction_refresh
2143 }
2144 _ => &mut project_state.last_edit_prediction_refresh,
2145 }
2146 }
2147
2148 let (needs_acceptance_tracking, max_pending_predictions) =
2149 match all_language_settings(None, cx).edit_predictions.provider {
2150 EditPredictionProvider::Zed
2151 | EditPredictionProvider::Sweep
2152 | EditPredictionProvider::Mercury
2153 | EditPredictionProvider::Experimental(_) => (true, 2),
2154 EditPredictionProvider::Ollama => (false, 1),
2155 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2156 EditPredictionProvider::None
2157 | EditPredictionProvider::Copilot
2158 | EditPredictionProvider::Codestral => {
2159 log::error!("queue_prediction_refresh called with non-store provider");
2160 return;
2161 }
2162 };
2163
2164 let drop_on_cancel = !needs_acceptance_tracking;
2165 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2166 let project_state = self.get_or_init_project(&project, cx);
2167 let pending_prediction_id = project_state.next_pending_prediction_id;
2168 project_state.next_pending_prediction_id += 1;
2169 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2170
2171 let task = cx.spawn(async move |this, cx| {
2172 let throttle_wait = this
2173 .update(cx, |this, cx| {
2174 let project_state = this.get_or_init_project(&project, cx);
2175 let throttle = *select_throttle(project_state, request_trigger);
2176
2177 let now = cx.background_executor().now();
2178 throttle.and_then(|(last_entity, last_timestamp)| {
2179 if throttle_entity != last_entity {
2180 return None;
2181 }
2182 (last_timestamp + throttle_timeout).checked_duration_since(now)
2183 })
2184 })
2185 .ok()
2186 .flatten();
2187
2188 if let Some(timeout) = throttle_wait {
2189 cx.background_executor().timer(timeout).await;
2190 }
2191
2192 // If this task was cancelled before the throttle timeout expired,
2193 // do not perform a request. Also skip if another task already
2194 // proceeded since we were enqueued (duplicate).
2195 let mut is_cancelled = true;
2196 this.update(cx, |this, cx| {
2197 let project_state = this.get_or_init_project(&project, cx);
2198 let was_cancelled = project_state
2199 .cancelled_predictions
2200 .remove(&pending_prediction_id);
2201 if was_cancelled {
2202 return;
2203 }
2204
2205 // Another request has been already sent since this was enqueued
2206 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2207 return;
2208 }
2209
2210 let new_refresh = (throttle_entity, cx.background_executor().now());
2211 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2212 is_cancelled = false;
2213 })
2214 .ok();
2215 if is_cancelled {
2216 return None;
2217 }
2218
2219 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2220 let new_prediction_id = new_prediction_result
2221 .as_ref()
2222 .map(|(prediction, _)| prediction.id.clone());
2223
2224 // When a prediction completes, remove it from the pending list, and cancel
2225 // any pending predictions that were enqueued before it.
2226 this.update(cx, |this, cx| {
2227 let project_state = this.get_or_init_project(&project, cx);
2228
2229 let is_cancelled = project_state
2230 .cancelled_predictions
2231 .remove(&pending_prediction_id);
2232
2233 let new_current_prediction = if !is_cancelled
2234 && let Some((prediction_result, requested_by)) = new_prediction_result
2235 {
2236 match prediction_result.prediction {
2237 Ok(prediction) => {
2238 let new_prediction = CurrentEditPrediction {
2239 requested_by,
2240 prediction,
2241 was_shown: false,
2242 shown_with: None,
2243 e2e_latency: prediction_result.e2e_latency,
2244 };
2245
2246 if let Some(current_prediction) =
2247 project_state.current_prediction.as_ref()
2248 {
2249 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2250 {
2251 this.reject_current_prediction(
2252 EditPredictionRejectReason::Replaced,
2253 &project,
2254 cx,
2255 );
2256
2257 Some(new_prediction)
2258 } else {
2259 this.reject_prediction(
2260 new_prediction.prediction.id,
2261 EditPredictionRejectReason::CurrentPreferred,
2262 false,
2263 new_prediction.prediction.model_version,
2264 Some(new_prediction.e2e_latency),
2265 cx,
2266 );
2267 None
2268 }
2269 } else {
2270 Some(new_prediction)
2271 }
2272 }
2273 Err(reject_reason) => {
2274 this.reject_prediction(
2275 prediction_result.id,
2276 reject_reason,
2277 false,
2278 None,
2279 Some(prediction_result.e2e_latency),
2280 cx,
2281 );
2282 None
2283 }
2284 }
2285 } else {
2286 None
2287 };
2288
2289 let project_state = this.get_or_init_project(&project, cx);
2290
2291 if let Some(new_prediction) = new_current_prediction {
2292 project_state.current_prediction = Some(new_prediction);
2293 }
2294
2295 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2296 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2297 if pending_prediction.id == pending_prediction_id {
2298 pending_predictions.remove(ix);
2299 for pending_prediction in pending_predictions.drain(0..ix) {
2300 project_state.cancel_pending_prediction(pending_prediction, cx)
2301 }
2302 break;
2303 }
2304 }
2305 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2306 cx.notify();
2307 })
2308 .ok();
2309
2310 new_prediction_id
2311 });
2312
2313 if project_state.pending_predictions.len() < max_pending_predictions {
2314 project_state
2315 .pending_predictions
2316 .push(PendingPrediction {
2317 id: pending_prediction_id,
2318 task,
2319 drop_on_cancel,
2320 })
2321 .unwrap();
2322 } else {
2323 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2324 project_state
2325 .pending_predictions
2326 .push(PendingPrediction {
2327 id: pending_prediction_id,
2328 task,
2329 drop_on_cancel,
2330 })
2331 .unwrap();
2332 project_state.cancel_pending_prediction(pending_prediction, cx);
2333 }
2334 }
2335
2336 pub fn request_prediction(
2337 &mut self,
2338 project: &Entity<Project>,
2339 active_buffer: &Entity<Buffer>,
2340 position: language::Anchor,
2341 trigger: PredictEditsRequestTrigger,
2342 cx: &mut Context<Self>,
2343 ) -> Task<Result<Option<EditPredictionResult>>> {
2344 self.request_prediction_internal(
2345 project.clone(),
2346 active_buffer.clone(),
2347 position,
2348 trigger,
2349 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2350 cx,
2351 )
2352 }
2353
2354 fn request_prediction_internal(
2355 &mut self,
2356 project: Entity<Project>,
2357 active_buffer: Entity<Buffer>,
2358 position: language::Anchor,
2359 trigger: PredictEditsRequestTrigger,
2360 allow_jump: bool,
2361 cx: &mut Context<Self>,
2362 ) -> Task<Result<Option<EditPredictionResult>>> {
2363 self.get_or_init_project(&project, cx);
2364 let project_state = self.projects.get(&project.entity_id()).unwrap();
2365 let stored_events = project_state.events(cx);
2366 let has_events = !stored_events.is_empty();
2367 let events: Vec<Arc<zeta_prompt::Event>> =
2368 stored_events.iter().map(|e| e.event.clone()).collect();
2369 let debug_tx = project_state.debug_tx.clone();
2370
2371 let snapshot = active_buffer.read(cx).snapshot();
2372 let cursor_point = position.to_point(&snapshot);
2373 let current_offset = position.to_offset(&snapshot);
2374
2375 let mut user_actions: Vec<UserActionRecord> =
2376 project_state.user_actions.iter().cloned().collect();
2377
2378 if let Some(last_action) = user_actions.last() {
2379 if last_action.buffer_id == active_buffer.entity_id()
2380 && current_offset != last_action.offset
2381 {
2382 let timestamp_epoch_ms = SystemTime::now()
2383 .duration_since(UNIX_EPOCH)
2384 .map(|d| d.as_millis() as u64)
2385 .unwrap_or(0);
2386 user_actions.push(UserActionRecord {
2387 action_type: UserActionType::CursorMovement,
2388 buffer_id: active_buffer.entity_id(),
2389 line_number: cursor_point.row,
2390 offset: current_offset,
2391 timestamp_epoch_ms,
2392 });
2393 }
2394 }
2395 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2396 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2397 let diagnostic_search_range =
2398 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2399
2400 let related_files = self.context_for_project(&project, cx);
2401
2402 let is_open_source = snapshot
2403 .file()
2404 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2405 && events.iter().all(|event| event.in_open_source_repo())
2406 && related_files.iter().all(|file| file.in_open_source_repo);
2407
2408 let can_collect_data = !cfg!(test)
2409 && is_open_source
2410 && self.is_data_collection_enabled(cx)
2411 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2412
2413 let recent_paths = project_state.recent_paths.clone();
2414
2415 let inputs = EditPredictionModelInput {
2416 project: project.clone(),
2417 buffer: active_buffer,
2418 snapshot,
2419 position,
2420 events,
2421 related_files,
2422 recent_paths,
2423 trigger,
2424 diagnostic_search_range: diagnostic_search_range,
2425 debug_tx,
2426 user_actions,
2427 can_collect_data,
2428 is_open_source,
2429 };
2430
2431 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2432
2433 let task = match self.edit_prediction_model {
2434 EditPredictionModel::Zeta => {
2435 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2436 }
2437 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2438 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2439 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2440 };
2441
2442 cx.spawn(async move |this, cx| {
2443 let prediction = task.await?;
2444
2445 // Only fall back to diagnostics-based prediction if we got a
2446 // the model had nothing to suggest for the buffer
2447 if prediction.is_none()
2448 && allow_jump
2449 && has_events
2450 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2451 {
2452 this.update(cx, |this, cx| {
2453 this.refresh_prediction_from_diagnostics(
2454 project,
2455 DiagnosticSearchScope::Local,
2456 cx,
2457 );
2458 })?;
2459 return anyhow::Ok(None);
2460 }
2461
2462 Ok(prediction)
2463 })
2464 }
2465
    /// Finds a diagnostic location to jump to when no edit prediction is
    /// available near the cursor.
    ///
    /// First, it looks in `active_buffer` for the primary diagnostic of each
    /// group that lies outside `active_buffer_diagnostic_search_range`. Groups
    /// that start near a collaborator's cursor are skipped unless they are also
    /// near the local cursor. Of the rest, it picks the one whose start row is
    /// closest to `active_buffer_cursor_point`.
    ///
    /// If nothing qualifies, it falls back to the other paths reported by the
    /// project's diagnostic summaries. These are ordered by how many leading
    /// path components they share with the active buffer's path, longest match
    /// first. Each candidate is opened in turn. Buffers with any selection
    /// reported by `selections_in_range` are skipped. The first remaining buffer
    /// that has a diagnostic yields the start of the diagnostic with the
    /// minimum `diagnostic.severity` (presumably the most severe one; confirm
    /// against the severity ordering).
    ///
    /// Returns `Ok(None)` when no location is found.
    ///
    /// # Errors
    ///
    /// Fails, ending the search, if opening a candidate buffer fails.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every collaborator cursor in the active buffer.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search window are already visible to
                // the regular prediction request, so they are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Leave diagnostics a collaborator is working on alone, unless
                // they are also close to the local cursor.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path that has diagnostics with the number of
            // leading path components it shares with the active buffer.
            // NOTE(review): if `diagnostic_summaries` yields one entry per
            // language server, the same path may be visited more than once.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first (stable sort keeps ties in order).
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip files a collaborator has selections in.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2575
2576 async fn send_raw_llm_request(
2577 request: RawCompletionRequest,
2578 client: Arc<Client>,
2579 custom_url: Option<Arc<Url>>,
2580 llm_token: LlmApiToken,
2581 organization_id: Option<OrganizationId>,
2582 app_version: Version,
2583 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2584 let url = if let Some(custom_url) = custom_url {
2585 custom_url.as_ref().clone()
2586 } else {
2587 client
2588 .http_client()
2589 .build_zed_llm_url("/predict_edits/raw", &[])?
2590 };
2591
2592 Self::send_api_request(
2593 |builder| {
2594 let req = builder
2595 .uri(url.as_ref())
2596 .body(serde_json::to_string(&request)?.into());
2597 Ok(req?)
2598 },
2599 client,
2600 llm_token,
2601 organization_id,
2602 app_version,
2603 true,
2604 )
2605 .await
2606 }
2607
2608 pub(crate) async fn send_v3_request(
2609 input: ZetaPromptInput,
2610 client: Arc<Client>,
2611 llm_token: LlmApiToken,
2612 organization_id: Option<OrganizationId>,
2613 app_version: Version,
2614 trigger: PredictEditsRequestTrigger,
2615 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2616 let url = client
2617 .http_client()
2618 .build_zed_llm_url("/predict_edits/v3", &[])?;
2619
2620 let request = PredictEditsV3Request { input, trigger };
2621
2622 let json_bytes = serde_json::to_vec(&request)?;
2623 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2624
2625 Self::send_api_request(
2626 |builder| {
2627 let req = builder
2628 .uri(url.as_ref())
2629 .header("Content-Encoding", "zstd")
2630 .body(compressed.clone().into());
2631 Ok(req?)
2632 },
2633 client,
2634 llm_token,
2635 organization_id,
2636 app_version,
2637 true,
2638 )
2639 .await
2640 }
2641
    /// Sends a POST request produced by `build` and deserializes a successful
    /// JSON response body into `Res`, together with any usage information
    /// parsed from the response headers.
    ///
    /// Every attempt carries `Content-Type: application/json` and the Zed
    /// version header. An `Authorization: Bearer` header is added whenever an
    /// LLM token is available. With `require_auth`, failing to acquire a token
    /// is an error; otherwise the request is sent without one.
    ///
    /// `build` may be invoked twice. If a non-success response signals that
    /// the LLM token needs refreshing, the token is refreshed once and the
    /// request is rebuilt and resent.
    ///
    /// # Errors
    ///
    /// - [`ZedUpdateRequiredError`] when the response advertises a minimum
    ///   required version newer than `app_version`. This is checked before the
    ///   status code.
    /// - An error containing the status and body for any other non-success
    ///   response.
    /// - Errors from token acquisition/refresh, request building, transport,
    ///   body reading, or JSON decoding.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client, organization_id.clone()).await?)
        } else {
            llm_token
                .acquire(&client, organization_id.clone())
                .await
                .ok()
        };
        // Ensures the token-refresh retry happens at most once.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server can demand a newer client on any response, whatever
            // its status, so this is checked first.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Refresh an expired token and go around the loop once more.
                did_retry = true;
                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2715
2716 pub fn refresh_context(
2717 &mut self,
2718 project: &Entity<Project>,
2719 buffer: &Entity<language::Buffer>,
2720 cursor_position: language::Anchor,
2721 cx: &mut Context<Self>,
2722 ) {
2723 self.get_or_init_project(project, cx)
2724 .context
2725 .update(cx, |store, cx| {
2726 store.refresh(buffer.clone(), cursor_position, cx);
2727 });
2728 }
2729
2730 #[cfg(feature = "cli-support")]
2731 pub fn set_context_for_buffer(
2732 &mut self,
2733 project: &Entity<Project>,
2734 related_files: Vec<RelatedFile>,
2735 cx: &mut Context<Self>,
2736 ) {
2737 self.get_or_init_project(project, cx)
2738 .context
2739 .update(cx, |store, cx| {
2740 store.set_related_files(related_files, cx);
2741 });
2742 }
2743
2744 #[cfg(feature = "cli-support")]
2745 pub fn set_recent_paths_for_project(
2746 &mut self,
2747 project: &Entity<Project>,
2748 paths: impl IntoIterator<Item = project::ProjectPath>,
2749 cx: &mut Context<Self>,
2750 ) {
2751 let project_state = self.get_or_init_project(project, cx);
2752 project_state.recent_paths = paths.into_iter().collect();
2753 }
2754
2755 fn is_file_open_source(
2756 &self,
2757 project: &Entity<Project>,
2758 file: &Arc<dyn File>,
2759 cx: &App,
2760 ) -> bool {
2761 if !file.is_local() || file.is_private() {
2762 return false;
2763 }
2764 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2765 return false;
2766 };
2767 project_state
2768 .license_detection_watchers
2769 .get(&file.worktree_id(cx))
2770 .as_ref()
2771 .is_some_and(|watcher| watcher.is_project_open_source())
2772 }
2773
2774 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2775 self.data_collection_choice.is_enabled(cx)
2776 }
2777
2778 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2779 let choice = KeyValueStore::global(cx)
2780 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2781 .log_err()
2782 .flatten();
2783
2784 match choice.as_deref() {
2785 Some("true") => DataCollectionChoice::Enabled,
2786 Some("false") => DataCollectionChoice::Disabled,
2787 Some(_) => {
2788 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2789 DataCollectionChoice::NotAnswered
2790 }
2791 None => DataCollectionChoice::NotAnswered,
2792 }
2793 }
2794
2795 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2796 self.data_collection_choice = self.data_collection_choice.toggle();
2797 let new_choice = self.data_collection_choice;
2798 let is_enabled = new_choice.is_enabled(cx);
2799 let kvp = KeyValueStore::global(cx);
2800 db::write_and_log(cx, move || async move {
2801 kvp.write_kvp(
2802 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2803 is_enabled.to_string(),
2804 )
2805 .await
2806 });
2807 }
2808
2809 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2810 self.shown_predictions.iter()
2811 }
2812
2813 pub fn shown_completions_len(&self) -> usize {
2814 self.shown_predictions.len()
2815 }
2816
2817 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2818 self.rated_predictions.contains(id)
2819 }
2820
2821 pub fn rate_prediction(
2822 &mut self,
2823 prediction: &EditPrediction,
2824 rating: EditPredictionRating,
2825 feedback: String,
2826 cx: &mut Context<Self>,
2827 ) {
2828 let organization = self.user_store.read(cx).current_organization();
2829
2830 self.rated_predictions.insert(prediction.id.clone());
2831
2832 cx.background_spawn({
2833 let client = self.client.clone();
2834 let prediction_id = prediction.id.to_string();
2835 let inputs = serde_json::to_value(&prediction.inputs);
2836 let output = prediction
2837 .edit_preview
2838 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2839 async move {
2840 client
2841 .cloud_client()
2842 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2843 organization_id: organization.map(|organization| organization.id.clone()),
2844 request_id: prediction_id,
2845 rating: match rating {
2846 EditPredictionRating::Positive => "positive".to_string(),
2847 EditPredictionRating::Negative => "negative".to_string(),
2848 },
2849 inputs: inputs?,
2850 output,
2851 feedback,
2852 })
2853 .await?;
2854
2855 anyhow::Ok(())
2856 }
2857 })
2858 .detach_and_log_err(cx);
2859
2860 cx.notify();
2861 }
2862}
2863
2864fn collaborator_edit_overlaps_locality_region(
2865 project_state: &ProjectState,
2866 project: &Entity<Project>,
2867 buffer: &Entity<Buffer>,
2868 snapshot: &BufferSnapshot,
2869 edit_range: &Range<Anchor>,
2870 cx: &App,
2871) -> bool {
2872 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2873 return false;
2874 };
2875
2876 if active_buffer.entity_id() != buffer.entity_id() {
2877 return false;
2878 }
2879
2880 let locality_point_range = expand_context_syntactically_then_linewise(
2881 snapshot,
2882 (position..position).to_point(snapshot),
2883 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2884 );
2885 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2886
2887 edit_range.overlaps(&locality_anchor_range, snapshot)
2888}
2889
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// Does nothing unless the last stored event belongs to the same buffer as
/// `latest_snapshot` and `latest_snapshot` has observed that event's resulting
/// version. Events are then walked from newest to oldest for as long as each
/// older event `can_merge` with the event after it, judged against
/// `latest_snapshot` and `latest_edit_range`. If fewer than two events qualify,
/// `events` is left unchanged.
///
/// The merged event diffs the oldest event's `old_snapshot` against
/// `end_snapshot` within the union of the merged events' edit ranges. Its
/// `predicted` flag is set only if every merged event was predicted. If no diff
/// can be computed for that range, `events` is left untouched.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest stored event is for this buffer and is already
    // reflected in `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the trailing run of events that can be merged, newest first;
    // `next_old_event` is the (newer) neighbor of the event being examined.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // The first element of the trailing run is the oldest event being merged.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Grow the edit range to cover every newer event in the run.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek` did not consume anything, so this checks every
                    // merged event, including the oldest.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the whole trailing run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2972
2973fn merge_anchor_ranges(
2974 left: &Range<Anchor>,
2975 right: &Range<Anchor>,
2976 snapshot: &TextBufferSnapshot,
2977) -> Range<Anchor> {
2978 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2979 left.start
2980 } else {
2981 right.start
2982 };
2983 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2984 left.end
2985 } else {
2986 right.end
2987 };
2988 start..end
2989}
2990
2991#[derive(Error, Debug)]
2992#[error(
2993 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2994)]
2995pub struct ZedUpdateRequiredError {
2996 minimum_version: Version,
2997}
2998
/// The user's persisted decision about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No decision stored yet (or the stored value was unrecognized).
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
3005
3006impl DataCollectionChoice {
3007 pub fn is_enabled(self, cx: &App) -> bool {
3008 if cx.is_staff() {
3009 return true;
3010 }
3011 match self {
3012 Self::Enabled => true,
3013 Self::NotAnswered | Self::Disabled => false,
3014 }
3015 }
3016
3017 #[must_use]
3018 pub fn toggle(&self) -> DataCollectionChoice {
3019 match self {
3020 Self::Enabled => Self::Disabled,
3021 Self::Disabled => Self::Enabled,
3022 Self::NotAnswered => Self::Enabled,
3023 }
3024 }
3025}
3026
3027impl From<bool> for DataCollectionChoice {
3028 fn from(value: bool) -> Self {
3029 match value {
3030 true => DataCollectionChoice::Enabled,
3031 false => DataCollectionChoice::Disabled,
3032 }
3033 }
3034}
3035
3036struct ZedPredictUpsell;
3037
3038impl Dismissable for ZedPredictUpsell {
3039 const KEY: &'static str = "dismissed-edit-predict-upsell";
3040
3041 fn dismissed(cx: &App) -> bool {
3042 // To make this backwards compatible with older versions of Zed, we
3043 // check if the user has seen the previous Edit Prediction Onboarding
3044 // before, by checking the data collection choice which was written to
3045 // the database once the user clicked on "Accept and Enable"
3046 let kvp = KeyValueStore::global(cx);
3047 if kvp
3048 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3049 .log_err()
3050 .is_some_and(|s| s.is_some())
3051 {
3052 return true;
3053 }
3054
3055 kvp.read_kvp(Self::KEY)
3056 .log_err()
3057 .is_some_and(|s| s.is_some())
3058 }
3059}
3060
3061pub fn should_show_upsell_modal(cx: &App) -> bool {
3062 !ZedPredictUpsell::dismissed(cx)
3063}
3064
3065pub fn init(cx: &mut App) {
3066 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3067 workspace.register_action(
3068 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3069 ZedPredictModal::toggle(
3070 workspace,
3071 workspace.user_store().clone(),
3072 workspace.client().clone(),
3073 window,
3074 cx,
3075 )
3076 },
3077 );
3078
3079 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3080 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3081 settings
3082 .project
3083 .all_languages
3084 .edit_predictions
3085 .get_or_insert_default()
3086 .provider = Some(EditPredictionProvider::None)
3087 });
3088 });
3089 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3090 EditPredictionStore::try_global(cx).and_then(|store| {
3091 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3092 })
3093 }
3094
3095 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3096 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3097 copilot_ui::initiate_sign_in(copilot, window, cx);
3098 }
3099 });
3100 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3101 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3102 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3103 }
3104 });
3105 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3106 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3107 copilot_ui::initiate_sign_out(copilot, window, cx);
3108 }
3109 });
3110 })
3111 .detach();
3112}