1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KeyValueStore};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::{AnchorRangeExt, Edit};
44use workspace::{AppState, Workspace};
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56
57pub mod cursor_excerpt;
58pub mod example_spec;
59pub mod fim;
60mod license_detection;
61pub mod mercury;
62pub mod ollama;
63mod onboarding_modal;
64pub mod open_ai_response;
65mod prediction;
66pub mod sweep_ai;
67
68pub mod udiff;
69
70mod capture_example;
71pub mod open_ai_compatible;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
79use crate::example_spec::ExampleSpec;
80use crate::license_detection::LicenseDetectionWatcher;
81use crate::mercury::Mercury;
82use crate::onboarding_modal::ZedPredictModal;
83pub use crate::prediction::EditPrediction;
84pub use crate::prediction::EditPredictionId;
85use crate::prediction::EditPredictionResult;
86pub use crate::sweep_ai::SweepAi;
87pub use capture_example::capture_example;
88pub use language_model::ApiKeyState;
89pub use telemetry_events::EditPredictionRating;
90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
91
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold used by `StoredEvent::can_merge` (via `lines_between_ranges`) to
/// decide whether edit ranges are close enough to be grouped.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not used in this chunk; presumably a token budget for context gathered around
// collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): presumably the idle interval after which consecutive edits stop being grouped
// into the same event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the data-collection choice is persisted (presumably in the key-value store).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the batching delay applied before rejection reports are sent.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the "prediction settled" event (presumably a telemetry event name; confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes; presumably how long a pending settled prediction is tracked before expiring.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten seconds; presumably the edit-free interval after which a prediction counts as settled.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
112
/// Feature flag registered under the name `edit_prediction_jumps`
/// (presumably gating jump-style edit predictions).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
118
/// gpui global holding the shared [`EditPredictionStore`] entity.
///
/// Installed by `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// Can be populated from the `ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT`
/// environment variables (see `EditPredictionStore::zeta2_raw_config_from_env`) or set
/// explicitly via `EditPredictionStore::set_zeta2_raw_config`.
// NOTE(review): earlier docs said "the version is also used as the Baseten environment name
// (lowercased)", but there is no `version` field; presumably this refers to `format` being used
// when `environment` is unset — confirm at the use site.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (from `ZED_ZETA_MODEL` when read from the environment).
    pub model_id: Option<String>,
    /// Optional environment name (from `ZED_ZETA_ENVIRONMENT` when read from the environment).
    pub environment: Option<String>,
    /// Prompt format used to build requests (from `ZED_ZETA_FORMAT` when read from the environment).
    pub format: ZetaFormat,
}
133
/// Central store for edit predictions.
///
/// Tracks per-project edit history and registered buffers, the selected prediction model,
/// experiment configuration, and the channels feeding the rejection and settled-prediction
/// workers started in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authorize requests to the Zed LLM service.
    llm_token: LlmApiToken,
    /// Task that waits for a current user and then refreshes `available_experiments`;
    /// stored here so it stays alive with the store.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): not read in this chunk; presumably set when the server requires a newer
    // client version — confirm.
    update_required: bool,
    /// Model used to produce predictions; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; `new` initializes it from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly preferred experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiments fetched from the server by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Data-collection choice loaded in `new` via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel consumed by the background rejection worker.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sending half of the channel consumed by the settled-predictions worker.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` falls back to their model version.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook receiving the settled prediction id and a string payload.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
155
/// Message sent over `reject_predictions_tx`: a rejection together with the organization it
/// should be attributed to, if any.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
160
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's own Zeta model (the default chosen in `EditPredictionStore::new`).
    Zeta,
    /// Fill-in-the-middle model using the given prompt format; `icons` picks the icon from the
    /// configured `EditPredictionProvider`.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury.
    Mercury,
}
168
/// Everything a model needs to produce a prediction for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is computed against.
    snapshot: BufferSnapshot,
    /// Position (presumably the cursor) the prediction is requested at.
    position: Anchor,
    /// Recent edit events included as history.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplied as additional context.
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    /// What caused this request.
    trigger: PredictionRequestedTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional sink for debug events about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    /// Recent user actions (see `UserActionRecord`).
    pub user_actions: Vec<UserActionRecord>,
}
185
/// Diagnostic events emitted (via a `debug_tx` channel) while context is retrieved and
/// predictions are computed.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval completes, with arbitrary key/value metadata.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts; `prompt` is present when one was built.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes; `model_output` is present when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
221
/// Maximum number of records kept in `ProjectState::user_actions`; `record_user_action` evicts
/// the oldest record once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single user action in a buffer, recorded for prediction context.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp of the action (presumably milliseconds since the Unix epoch, per the name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
241
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event as sent to the model (built by `LastEvent::finalize`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot before the edits described by `event`.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the edits described by `event`.
    pub new_snapshot_version: clock::Global,
    /// Range covering all edits of `event`, anchored in the post-edit buffer.
    pub total_edit_range: Range<Anchor>,
}
250
251impl StoredEvent {
252 fn can_merge(
253 &self,
254 next_old_event: &StoredEvent,
255 latest_snapshot: &TextBufferSnapshot,
256 latest_edit_range: &Range<Anchor>,
257 ) -> bool {
258 // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
259 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
260 return false;
261 }
262 if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
263 return false;
264 }
265 if self.new_snapshot_version != next_old_event.old_snapshot.version {
266 return false;
267 }
268 if !latest_snapshot
269 .version
270 .observed_all(&next_old_event.new_snapshot_version)
271 {
272 return false;
273 }
274
275 let a_is_predicted = matches!(
276 self.event.as_ref(),
277 zeta_prompt::Event::BufferChange {
278 predicted: true,
279 ..
280 }
281 );
282 let b_is_predicted = matches!(
283 next_old_event.event.as_ref(),
284 zeta_prompt::Event::BufferChange {
285 predicted: true,
286 ..
287 }
288 );
289
290 // If events come from the same source (both predicted or both manual) then
291 // we would have coalesced them already.
292 if a_is_predicted == b_is_predicted {
293 return false;
294 }
295
296 let left_range = self.total_edit_range.to_point(latest_snapshot);
297 let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
298 let latest_range = latest_edit_range.to_point(latest_snapshot);
299
300 // Events near to the latest edit are not merged if their sources differ.
301 if lines_between_ranges(&left_range, &latest_range)
302 .min(lines_between_ranges(&right_range, &latest_range))
303 <= CHANGE_GROUPING_LINE_SPAN
304 {
305 return false;
306 }
307
308 // Events that are distant from each other are not merged.
309 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
310 return false;
311 }
312
313 true
314 }
315}
316
317fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
318 if left.start > right.end {
319 return left.start.row - right.end.row;
320 }
321 if right.start > left.end {
322 return right.start.row - left.end.row;
323 }
324 0
325}
326
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit events (presumably oldest first; `events()` appends the in-progress event
    /// after them).
    events: VecDeque<StoredEvent>,
    /// In-progress event not yet moved into `events`; finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for predictions, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; the `ArrayVec` caps this at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Store providing related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, consulted to decide whether files are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (see `record_user_action`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started for this project, if any.
    copilot: Option<Entity<Copilot>>,
}
345
346impl ProjectState {
347 fn record_user_action(&mut self, action: UserActionRecord) {
348 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
349 self.user_actions.pop_front();
350 }
351 self.user_actions.push_back(action);
352 }
353
354 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
355 self.events
356 .iter()
357 .cloned()
358 .chain(self.last_event.as_ref().iter().flat_map(|event| {
359 let (one, two) = event.split_by_pause();
360 let one = one.finalize(&self.license_detection_watchers, cx);
361 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
362 one.into_iter().chain(two)
363 }))
364 .collect()
365 }
366
367 fn cancel_pending_prediction(
368 &mut self,
369 pending_prediction: PendingPrediction,
370 cx: &mut Context<EditPredictionStore>,
371 ) {
372 self.cancelled_predictions.insert(pending_prediction.id);
373
374 if pending_prediction.drop_on_cancel {
375 drop(pending_prediction.task);
376 } else {
377 cx.spawn(async move |this, cx| {
378 let Some(prediction_id) = pending_prediction.task.await else {
379 return;
380 };
381
382 this.update(cx, |this, cx| {
383 this.reject_prediction(
384 prediction_id,
385 EditPredictionRejectReason::Canceled,
386 false,
387 None,
388 None,
389 cx,
390 );
391 })
392 .ok();
393 })
394 .detach()
395 }
396 }
397
398 fn active_buffer(
399 &self,
400 project: &Entity<Project>,
401 cx: &App,
402 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
403 let project = project.read(cx);
404 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
405 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
406 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
407 Some((active_buffer, registered_buffer.last_position))
408 }
409}
410
/// The prediction currently offered for a project, plus bookkeeping about how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What requested this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed.
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency recorded for this prediction.
    pub e2e_latency: std::time::Duration,
}
419
420impl CurrentEditPrediction {
421 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
422 let Some(new_edits) = self
423 .prediction
424 .interpolate(&self.prediction.buffer.read(cx))
425 else {
426 return false;
427 };
428
429 if self.prediction.buffer != old_prediction.prediction.buffer {
430 return true;
431 }
432
433 let Some(old_edits) = old_prediction
434 .prediction
435 .interpolate(&old_prediction.prediction.buffer.read(cx))
436 else {
437 return true;
438 };
439
440 let requested_by_buffer_id = self.requested_by.buffer_id();
441
442 // This reduces the occurrence of UI thrash from replacing edits
443 //
444 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
445 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
446 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
447 && old_edits.len() == 1
448 && new_edits.len() == 1
449 {
450 let (old_range, old_text) = &old_edits[0];
451 let (new_range, new_text) = &new_edits[0];
452 new_range == old_range && new_text.starts_with(old_text.as_ref())
453 } else {
454 true
455 }
456 }
457}
458
459#[derive(Debug, Clone)]
460enum PredictionRequestedBy {
461 DiagnosticsUpdate,
462 Buffer(EntityId),
463}
464
465impl PredictionRequestedBy {
466 pub fn buffer_id(&self) -> Option<EntityId> {
467 match self {
468 PredictionRequestedBy::DiagnosticsUpdate => None,
469 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
470 }
471 }
472}
473
// NOTE(review): not used in this chunk; presumably the number of lines searched around a
// position for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search (presumably near the cursor vs. project-wide — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
481
/// An in-flight prediction request tracked by `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Id used to mark the request in `ProjectState::cancelled_predictions`.
    id: usize,
    /// Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
490
491/// A prediction from the perspective of a buffer.
492#[derive(Debug)]
493enum BufferEditPrediction<'a> {
494 Local { prediction: &'a EditPrediction },
495 Jump { prediction: &'a EditPrediction },
496}
497
498#[cfg(test)]
499impl std::ops::Deref for BufferEditPrediction<'_> {
500 type Target = EditPrediction;
501
502 fn deref(&self) -> &Self::Target {
503 match self {
504 BufferEditPrediction::Local { prediction } => prediction,
505 BufferEditPrediction::Jump { prediction } => prediction,
506 }
507 }
508}
509
/// A prediction awaiting the "settled" determination for its buffer
/// (stored in `RegisteredBuffer::pending_predictions`).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Editable region of the prediction, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When this entry was queued.
    enqueued_at: Instant,
    /// Time of the most recent edit observed for this entry (presumably used against
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`).
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
520
/// Per-buffer state for buffers registered with a project's edit-prediction state.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Last snapshot seen for the buffer.
    snapshot: TextBufferSnapshot,
    /// Predictions waiting to be reported as settled.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position (returned by `ProjectState::active_buffer`).
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
528
/// The in-progress (not yet finalized) edit event for a project.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    /// File associated with `old_snapshot` (used to derive `old_path`).
    old_file: Option<Arc<dyn File>>,
    /// File associated with `new_snapshot` (used to derive `path`).
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering all edits of this event; the diff in `finalize` is computed around it.
    total_edit_range: Range<Anchor>,
    /// Value of `total_edit_range` at the last editing pause; used by `split_by_pause`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Copied into `BufferChange::predicted` when finalized.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
542
543impl LastEvent {
544 pub fn finalize(
545 &self,
546 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
547 cx: &App,
548 ) -> Option<StoredEvent> {
549 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
550 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
551
552 let in_open_source_repo =
553 [self.new_file.as_ref(), self.old_file.as_ref()]
554 .iter()
555 .all(|file| {
556 file.is_some_and(|file| {
557 license_detection_watchers
558 .get(&file.worktree_id(cx))
559 .is_some_and(|watcher| watcher.is_project_open_source())
560 })
561 });
562
563 let (diff, edit_range) = compute_diff_between_snapshots_in_range(
564 &self.old_snapshot,
565 &self.new_snapshot,
566 &self.total_edit_range,
567 )?;
568
569 if path == old_path && diff.is_empty() {
570 None
571 } else {
572 Some(StoredEvent {
573 event: Arc::new(zeta_prompt::Event::BufferChange {
574 old_path,
575 path,
576 diff,
577 in_open_source_repo,
578 predicted: self.predicted,
579 }),
580 old_snapshot: self.old_snapshot.clone(),
581 new_snapshot_version: self.new_snapshot.version.clone(),
582 total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
583 ..self.new_snapshot.anchor_before(edit_range.end),
584 })
585 }
586 }
587
588 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
589 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
590 return (self.clone(), None);
591 };
592
593 let total_edit_range_before_pause = self
594 .total_edit_range_at_last_pause_boundary
595 .clone()
596 .unwrap_or_else(|| self.total_edit_range.clone());
597
598 let Some(total_edit_range_after_pause) =
599 compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
600 else {
601 return (self.clone(), None);
602 };
603
604 let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
605 let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
606
607 let before = LastEvent {
608 old_snapshot: self.old_snapshot.clone(),
609 new_snapshot: boundary_snapshot.clone(),
610 old_file: self.old_file.clone(),
611 new_file: self.new_file.clone(),
612 latest_edit_range: latest_edit_range_before_pause,
613 total_edit_range: total_edit_range_before_pause,
614 total_edit_range_at_last_pause_boundary: None,
615 predicted: self.predicted,
616 snapshot_after_last_editing_pause: None,
617 last_edit_time: self.last_edit_time,
618 };
619
620 let after = LastEvent {
621 old_snapshot: boundary_snapshot.clone(),
622 new_snapshot: self.new_snapshot.clone(),
623 old_file: self.old_file.clone(),
624 new_file: self.new_file.clone(),
625 latest_edit_range: latest_edit_range_after_pause,
626 total_edit_range: total_edit_range_after_pause,
627 total_edit_range_at_last_pause_boundary: None,
628 predicted: self.predicted,
629 snapshot_after_last_editing_pause: None,
630 last_edit_time: self.last_edit_time,
631 };
632
633 (before, Some(after))
634 }
635}
636
637fn compute_total_edit_range_between_snapshots(
638 old_snapshot: &TextBufferSnapshot,
639 new_snapshot: &TextBufferSnapshot,
640) -> Option<Range<Anchor>> {
641 let edits: Vec<Edit<usize>> = new_snapshot
642 .edits_since::<usize>(&old_snapshot.version)
643 .collect();
644
645 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
646 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
647 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
648
649 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
650}
651
652fn compute_old_range_for_new_range(
653 old_snapshot: &TextBufferSnapshot,
654 new_snapshot: &TextBufferSnapshot,
655 total_edit_range: &Range<Anchor>,
656) -> Option<Range<Point>> {
657 let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
658 let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
659
660 let edits: Vec<Edit<usize>> = new_snapshot
661 .edits_since::<usize>(&old_snapshot.version)
662 .collect();
663 let mut old_start_offset = None;
664 let mut old_end_offset = None;
665 let mut delta: isize = 0;
666
667 for edit in &edits {
668 if old_start_offset.is_none() && new_start_offset <= edit.new.end {
669 old_start_offset = Some(if new_start_offset < edit.new.start {
670 new_start_offset.checked_add_signed(-delta)?
671 } else {
672 edit.old.start
673 });
674 }
675
676 if old_end_offset.is_none() && new_end_offset <= edit.new.end {
677 old_end_offset = Some(if new_end_offset < edit.new.start {
678 new_end_offset.checked_add_signed(-delta)?
679 } else {
680 edit.old.end
681 });
682 }
683
684 delta += edit.new.len() as isize - edit.old.len() as isize;
685 }
686
687 let old_start_offset =
688 old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
689 let old_end_offset =
690 old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
691
692 Some(
693 old_snapshot.offset_to_point(old_start_offset)
694 ..old_snapshot.offset_to_point(old_end_offset),
695 )
696}
697
698fn compute_diff_between_snapshots_in_range(
699 old_snapshot: &TextBufferSnapshot,
700 new_snapshot: &TextBufferSnapshot,
701 total_edit_range: &Range<Anchor>,
702) -> Option<(String, Range<Point>)> {
703 let new_start_point = total_edit_range.start.to_point(new_snapshot);
704 let new_end_point = total_edit_range.end.to_point(new_snapshot);
705 let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
706 let old_start_point = old_range.start;
707 let old_end_point = old_range.end;
708
709 const CONTEXT_LINES: u32 = 3;
710
711 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
712 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
713 let old_context_end_row =
714 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
715 let new_context_end_row =
716 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
717
718 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
719 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
720 let old_end_line_offset = old_snapshot
721 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
722 let new_end_line_offset = new_snapshot
723 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
724 let old_edit_range = old_start_line_offset..old_end_line_offset;
725 let new_edit_range = new_start_line_offset..new_end_line_offset;
726
727 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
728 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
729
730 let diff = language::unified_diff_with_offsets(
731 &old_region_text,
732 &new_region_text,
733 old_context_start_row,
734 new_context_start_row,
735 );
736
737 Some((diff, new_start_point..new_end_point))
738}
739
740fn buffer_path_with_id_fallback(
741 file: Option<&Arc<dyn File>>,
742 snapshot: &TextBufferSnapshot,
743 cx: &App,
744) -> Arc<Path> {
745 if let Some(file) = file {
746 file.full_path(cx).into()
747 } else {
748 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
749 }
750}
751
752impl EditPredictionStore {
753 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
754 cx.try_global::<EditPredictionStoreGlobal>()
755 .map(|global| global.0.clone())
756 }
757
758 pub fn global(
759 client: &Arc<Client>,
760 user_store: &Entity<UserStore>,
761 cx: &mut App,
762 ) -> Entity<Self> {
763 cx.try_global::<EditPredictionStoreGlobal>()
764 .map(|global| global.0.clone())
765 .unwrap_or_else(|| {
766 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
767 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
768 ep_store
769 })
770 }
771
772 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
773 let data_collection_choice = Self::load_data_collection_choice(cx);
774
775 let llm_token = LlmApiToken::global(cx);
776
777 let (reject_tx, reject_rx) = mpsc::unbounded();
778 cx.background_spawn({
779 let client = client.clone();
780 let llm_token = llm_token.clone();
781 let app_version = AppVersion::global(cx);
782 let background_executor = cx.background_executor().clone();
783 async move {
784 Self::handle_rejected_predictions(
785 reject_rx,
786 client,
787 llm_token,
788 app_version,
789 background_executor,
790 )
791 .await
792 }
793 })
794 .detach();
795
796 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
797 cx.spawn(async move |this, cx| {
798 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
799 })
800 .detach();
801
802 let mut current_user = user_store.read(cx).watch_current_user();
803 let fetch_experiments_task = cx.spawn(async move |this, cx| {
804 while current_user.borrow().is_none() {
805 current_user.next().await;
806 }
807 this.update(cx, |this, cx| {
808 this.refresh_available_experiments(cx);
809 })
810 .log_err();
811 });
812
813 let this = Self {
814 projects: HashMap::default(),
815 client,
816 user_store,
817 llm_token,
818 _fetch_experiments_task: fetch_experiments_task,
819 update_required: false,
820 edit_prediction_model: EditPredictionModel::Zeta,
821 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
822 preferred_experiment: None,
823 available_experiments: Vec::new(),
824 sweep_ai: SweepAi::new(cx),
825 mercury: Mercury::new(cx),
826
827 data_collection_choice,
828 reject_predictions_tx: reject_tx,
829 settled_predictions_tx,
830 rated_predictions: Default::default(),
831 shown_predictions: Default::default(),
832 #[cfg(test)]
833 settled_event_callback: None,
834 };
835
836 this
837 }
838
839 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
840 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
841 let format = ZetaFormat::parse(&version_str).ok()?;
842 let model_id = env::var("ZED_ZETA_MODEL").ok();
843 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
844 Some(Zeta2RawConfig {
845 model_id,
846 environment,
847 format,
848 })
849 }
850
    /// Selects which model backs edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Sets the raw Zeta2 endpoint configuration, replacing any existing one.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 endpoint configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// Returns the explicitly preferred experiment, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    /// Sets the preferred experiment, or clears it when given `None`.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Returns the experiments last fetched by `refresh_available_experiments`.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
874
875 pub fn active_experiment(&self) -> Option<&str> {
876 self.preferred_experiment.as_deref().or_else(|| {
877 self.shown_predictions
878 .iter()
879 .find_map(|p| p.model_version.as_ref())
880 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
881 })
882 }
883
884 pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
885 let client = self.client.clone();
886 let llm_token = self.llm_token.clone();
887 let app_version = AppVersion::global(cx);
888 let organization_id = self
889 .user_store
890 .read(cx)
891 .current_organization()
892 .map(|organization| organization.id.clone());
893
894 cx.spawn(async move |this, cx| {
895 let experiments = cx
896 .background_spawn(async move {
897 let http_client = client.http_client();
898 let token = llm_token.acquire(&client, organization_id).await?;
899 let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
900 let request = http_client::Request::builder()
901 .method(Method::GET)
902 .uri(url.as_ref())
903 .header("Authorization", format!("Bearer {}", token))
904 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
905 .body(Default::default())?;
906 let mut response = http_client.send(request).await?;
907 if response.status().is_success() {
908 let mut body = Vec::new();
909 response.body_mut().read_to_end(&mut body).await?;
910 let experiments: Vec<String> = serde_json::from_slice(&body)?;
911 Ok(experiments)
912 } else {
913 let mut body = String::new();
914 response.body_mut().read_to_string(&mut body).await?;
915 anyhow::bail!(
916 "Failed to fetch experiments: {:?}\nBody: {}",
917 response.status(),
918 body
919 );
920 }
921 })
922 .await?;
923 this.update(cx, |this, cx| {
924 this.available_experiments = experiments;
925 cx.notify();
926 })?;
927 anyhow::Ok(())
928 })
929 .detach_and_log_err(cx);
930 }
931
932 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
933 use ui::IconName;
934 match self.edit_prediction_model {
935 EditPredictionModel::Sweep => {
936 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
937 .with_disabled(IconName::SweepAiDisabled)
938 .with_up(IconName::SweepAiUp)
939 .with_down(IconName::SweepAiDown)
940 .with_error(IconName::SweepAiError)
941 }
942 EditPredictionModel::Mercury => {
943 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
944 }
945 EditPredictionModel::Zeta => {
946 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
947 .with_disabled(IconName::ZedPredictDisabled)
948 .with_up(IconName::ZedPredictUp)
949 .with_down(IconName::ZedPredictDown)
950 .with_error(IconName::ZedPredictError)
951 }
952 EditPredictionModel::Fim { .. } => {
953 let settings = &all_language_settings(None, cx).edit_predictions;
954 match settings.provider {
955 EditPredictionProvider::Ollama => {
956 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
957 }
958 _ => {
959 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
960 }
961 }
962 }
963 }
964 }
965
    /// Returns whether an API key is available for Sweep AI.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Returns whether an API key is available for Mercury.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }

    /// Returns whether Mercury reports a payment-required error (delegates to `Mercury`).
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
977
978 pub fn clear_history(&mut self) {
979 for project_state in self.projects.values_mut() {
980 project_state.events.clear();
981 project_state.last_event.take();
982 }
983 }
984
985 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
986 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
987 project_state.events.clear();
988 project_state.last_event.take();
989 }
990 }
991
992 pub fn edit_history_for_project(
993 &self,
994 project: &Entity<Project>,
995 cx: &App,
996 ) -> Vec<StoredEvent> {
997 self.projects
998 .get(&project.entity_id())
999 .map(|project_state| project_state.events(cx))
1000 .unwrap_or_default()
1001 }
1002
1003 pub fn context_for_project<'a>(
1004 &'a self,
1005 project: &Entity<Project>,
1006 cx: &'a mut App,
1007 ) -> Vec<RelatedFile> {
1008 self.projects
1009 .get(&project.entity_id())
1010 .map(|project_state| {
1011 project_state.context.update(cx, |context, cx| {
1012 context
1013 .related_files_with_buffers(cx)
1014 .map(|(mut related_file, buffer)| {
1015 related_file.in_open_source_repo = buffer
1016 .read(cx)
1017 .file()
1018 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1019 related_file
1020 })
1021 .collect()
1022 })
1023 })
1024 .unwrap_or_default()
1025 }
1026
1027 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1028 self.projects
1029 .get(&project.entity_id())
1030 .and_then(|project| project.copilot.clone())
1031 }
1032
1033 pub fn start_copilot_for_project(
1034 &mut self,
1035 project: &Entity<Project>,
1036 cx: &mut Context<Self>,
1037 ) -> Option<Entity<Copilot>> {
1038 if DisableAiSettings::get(None, cx).disable_ai {
1039 return None;
1040 }
1041 let state = self.get_or_init_project(project, cx);
1042
1043 if state.copilot.is_some() {
1044 return state.copilot.clone();
1045 }
1046 let _project = project.clone();
1047 let project = project.read(cx);
1048
1049 let node = project.node_runtime().cloned();
1050 if let Some(node) = node {
1051 let next_id = project.languages().next_language_server_id();
1052 let fs = project.fs().clone();
1053
1054 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1055 state.copilot = Some(copilot.clone());
1056 Some(copilot)
1057 } else {
1058 None
1059 }
1060 }
1061
1062 pub fn context_for_project_with_buffers<'a>(
1063 &'a self,
1064 project: &Entity<Project>,
1065 cx: &'a mut App,
1066 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1067 self.projects
1068 .get(&project.entity_id())
1069 .map(|project| {
1070 project.context.update(cx, |context, cx| {
1071 context.related_files_with_buffers(cx).collect()
1072 })
1073 })
1074 .unwrap_or_default()
1075 }
1076
1077 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1078 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1079 self.user_store.read(cx).edit_prediction_usage()
1080 } else {
1081 None
1082 }
1083 }
1084
1085 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1086 self.get_or_init_project(project, cx);
1087 }
1088
1089 pub fn register_buffer(
1090 &mut self,
1091 buffer: &Entity<Buffer>,
1092 project: &Entity<Project>,
1093 cx: &mut Context<Self>,
1094 ) {
1095 let project_state = self.get_or_init_project(project, cx);
1096 Self::register_buffer_impl(project_state, buffer, project, cx);
1097 }
1098
/// Returns the tracking state for `project`, creating it on first access.
///
/// Everything inside `or_insert_with` runs only when no state exists yet, so the
/// related-excerpt store and the subscriptions below are created once per project entity.
1099 fn get_or_init_project(
1100 &mut self,
1101 project: &Entity<Project>,
1102 cx: &mut Context<Self>,
1103 ) -> &mut ProjectState {
1104 let entity_id = project.entity_id();
1105 self.projects
1106 .entry(entity_id)
1107 .or_insert_with(|| ProjectState {
// Per-project related-excerpt store; its events are forwarded to
// `handle_excerpt_store_event` tagged with this project's entity id.
1108 context: {
1109 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1110 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1111 this.handle_excerpt_store_event(entity_id, event);
1112 })
1113 .detach();
1114 related_excerpt_store
1115 },
1116 events: VecDeque::new(),
1117 last_event: None,
1118 recent_paths: VecDeque::new(),
1119 debug_tx: None,
1120 registered_buffers: HashMap::default(),
1121 current_prediction: None,
1122 cancelled_predictions: HashSet::default(),
1123 pending_predictions: ArrayVec::new(),
1124 next_pending_prediction_id: 0,
1125 last_edit_prediction_refresh: None,
1126 last_jump_prediction_refresh: None,
1127 license_detection_watchers: HashMap::default(),
1128 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
// Stored on the state (not detached): project events go to `handle_project_event`,
// and releasing the project entity removes this state and notifies observers.
1129 _subscriptions: [
1130 cx.subscribe(&project, Self::handle_project_event),
1131 cx.observe_release(&project, move |this, _, cx| {
1132 this.projects.remove(&entity_id);
1133 cx.notify();
1134 }),
1135 ],
1136 copilot: None,
1137 })
1138 }
1139
1140 pub fn remove_project(&mut self, project: &Entity<Project>) {
1141 self.projects.remove(&project.entity_id());
1142 }
1143
1144 fn handle_excerpt_store_event(
1145 &mut self,
1146 project_entity_id: EntityId,
1147 event: &RelatedExcerptStoreEvent,
1148 ) {
1149 if let Some(project_state) = self.projects.get(&project_entity_id) {
1150 if let Some(debug_tx) = project_state.debug_tx.clone() {
1151 match event {
1152 RelatedExcerptStoreEvent::StartedRefresh => {
1153 debug_tx
1154 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1155 ContextRetrievalStartedDebugEvent {
1156 project_entity_id: project_entity_id,
1157 timestamp: Instant::now(),
1158 search_prompt: String::new(),
1159 },
1160 ))
1161 .ok();
1162 }
1163 RelatedExcerptStoreEvent::FinishedRefresh {
1164 cache_hit_count,
1165 cache_miss_count,
1166 mean_definition_latency,
1167 max_definition_latency,
1168 } => {
1169 debug_tx
1170 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1171 ContextRetrievalFinishedDebugEvent {
1172 project_entity_id: project_entity_id,
1173 timestamp: Instant::now(),
1174 metadata: vec![
1175 (
1176 "Cache Hits",
1177 format!(
1178 "{}/{}",
1179 cache_hit_count,
1180 cache_hit_count + cache_miss_count
1181 )
1182 .into(),
1183 ),
1184 (
1185 "Max LSP Time",
1186 format!("{} ms", max_definition_latency.as_millis())
1187 .into(),
1188 ),
1189 (
1190 "Mean LSP Time",
1191 format!("{} ms", mean_definition_latency.as_millis())
1192 .into(),
1193 ),
1194 ],
1195 },
1196 ))
1197 .ok();
1198 }
1199 }
1200 }
1201 }
1202 }
1203
1204 pub fn debug_info(
1205 &mut self,
1206 project: &Entity<Project>,
1207 cx: &mut Context<Self>,
1208 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1209 let project_state = self.get_or_init_project(project, cx);
1210 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1211 project_state.debug_tx = Some(debug_watch_tx);
1212 debug_watch_rx
1213 }
1214
1215 fn handle_project_event(
1216 &mut self,
1217 project: Entity<Project>,
1218 event: &project::Event,
1219 cx: &mut Context<Self>,
1220 ) {
1221 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1222 return;
1223 }
1224 // TODO [zeta2] init with recent paths
1225 match event {
1226 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1227 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1228 return;
1229 };
1230 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1231 if let Some(path) = path {
1232 if let Some(ix) = project_state
1233 .recent_paths
1234 .iter()
1235 .position(|probe| probe == &path)
1236 {
1237 project_state.recent_paths.remove(ix);
1238 }
1239 project_state.recent_paths.push_front(path);
1240 }
1241 }
1242 project::Event::DiagnosticsUpdated { .. } => {
1243 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1244 self.refresh_prediction_from_diagnostics(
1245 project,
1246 DiagnosticSearchScope::Global,
1247 cx,
1248 );
1249 }
1250 }
1251 _ => (),
1252 }
1253 }
1254
/// Starts tracking `buffer` in `project_state` if it is not tracked yet, and returns its
/// `RegisteredBuffer` entry either way.
///
/// When the buffer has a file whose worktree is found in `project`, a license-detection
/// watcher is created for that worktree the first time it is seen. That watcher is removed
/// from the project state when the worktree entity is released. A newly tracked buffer:
/// - stores its current text snapshot and file;
/// - reports local and remote edits through `report_changes_for_buffer` while the project
///   is still alive;
/// - removes itself from the project state when the buffer entity is released.
1255 fn register_buffer_impl<'a>(
1256 project_state: &'a mut ProjectState,
1257 buffer: &Entity<Buffer>,
1258 project: &Entity<Project>,
1259 cx: &mut Context<Self>,
1260 ) -> &'a mut RegisteredBuffer {
1261 let buffer_id = buffer.entity_id();
1262
// One license-detection watcher per worktree, shared by all of its buffers.
1263 if let Some(file) = buffer.read(cx).file() {
1264 let worktree_id = file.worktree_id(cx);
1265 if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1266 project_state
1267 .license_detection_watchers
1268 .entry(worktree_id)
1269 .or_insert_with(|| {
1270 let project_entity_id = project.entity_id();
1271 cx.observe_release(&worktree, move |this, _worktree, _cx| {
1272 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1273 else {
1274 return;
1275 };
1276 project_state
1277 .license_detection_watchers
1278 .remove(&worktree_id);
1279 })
1280 .detach();
1281 Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1282 });
1283 }
1284 }
1285
1286 match project_state.registered_buffers.entry(buffer_id) {
1287 hash_map::Entry::Occupied(entry) => entry.into_mut(),
1288 hash_map::Entry::Vacant(entry) => {
1289 let buf = buffer.read(cx);
1290 let snapshot = buf.text_snapshot();
1291 let file = buf.file().cloned();
1292 let project_entity_id = project.entity_id();
1293 entry.insert(RegisteredBuffer {
1294 snapshot,
1295 file,
1296 last_position: None,
1297 pending_predictions: Vec::new(),
1298 _subscriptions: [
// Only a weak project handle is captured, so the subscription does not keep the
// project alive; edits are ignored once it can no longer be upgraded.
1299 cx.subscribe(buffer, {
1300 let project = project.downgrade();
1301 move |this, buffer, event, cx| {
1302 if let language::BufferEvent::Edited { is_local } = event
1303 && let Some(project) = project.upgrade()
1304 {
1305 this.report_changes_for_buffer(
1306 &buffer, &project, false, *is_local, cx,
1307 );
1308 }
1309 }
1310 }),
1311 cx.observe_release(buffer, move |this, _buffer, _cx| {
1312 let Some(project_state) = this.projects.get_mut(&project_entity_id)
1313 else {
1314 return;
1315 };
1316 project_state.registered_buffers.remove(&buffer_id);
1317 }),
1318 ],
1319 })
1320 }
1321 }
1322 }
1323
/// Records what changed in `buffer` since the snapshot last stored for it.
///
/// In order:
/// - registers the buffer if needed and swaps in its new file and text snapshot; returns
///   early when the version is unchanged or no edits are found;
/// - refreshes `last_edit_at` on pending settled predictions whose editable range overlaps
///   the edits;
/// - for local edits, appends a `UserActionRecord` classified from the edit totals;
/// - for edits that belong in history, either coalesces them into `last_event` or
///   finalizes that event into `events` and starts a new one. Local edits always belong;
///   remote edits only when `collaborator_edit_overlaps_locality_region` returns true.
///
/// `is_predicted` is `true` when called from `accept_current_prediction` and `false` from
/// the buffer-edit subscription; predicted and non-predicted edits are never coalesced
/// together. `is_local` comes from the buffer's `Edited` event.
1324 fn report_changes_for_buffer(
1325 &mut self,
1326 buffer: &Entity<Buffer>,
1327 project: &Entity<Project>,
1328 is_predicted: bool,
1329 is_local: bool,
1330 cx: &mut Context<Self>,
1331 ) {
1332 let project_state = self.get_or_init_project(project, cx);
1333 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1334
1335 let buf = buffer.read(cx);
1336 let new_file = buf.file().cloned();
1337 let new_snapshot = buf.text_snapshot();
// No new edits since the last recorded snapshot.
1338 if new_snapshot.version == registered_buffer.snapshot.version {
1339 return;
1340 }
1341 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1342 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1343 let mut num_edits = 0usize;
1344 let mut total_deleted = 0usize;
1345 let mut total_inserted = 0usize;
1346 let mut edit_range: Option<Range<Anchor>> = None;
1347 let mut last_offset: Option<usize> = None;
1348 let now = cx.background_executor().now();
1349
// Summarize all edits: counts, deleted/inserted lengths, and one anchor range from the
// first edit's start to the last edit's end (this presumes edits arrive in buffer order).
1350 for (edit, anchor_range) in
1351 new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1352 {
1353 num_edits += 1;
1354 total_deleted += edit.old.len();
1355 total_inserted += edit.new.len();
1356 edit_range = Some(match edit_range {
1357 None => anchor_range,
1358 Some(acc) => acc.start..anchor_range.end,
1359 });
1360 last_offset = Some(edit.new.end);
1361 }
1362
1363 let Some(edit_range) = edit_range else {
1364 return;
1365 };
1366
// Edits inside a pending prediction's editable region restart its quiescence clock
// (see `run_settled_predictions_worker`).
1367 for pending_prediction in &mut registered_buffer.pending_predictions {
1368 if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
1369 pending_prediction.last_edit_at = now;
1370 }
1371 }
1372
1373 let include_in_history = is_local
1374 || collaborator_edit_overlaps_locality_region(
1375 project_state,
1376 project,
1377 buffer,
1378 &buf.snapshot(),
1379 &edit_range,
1380 cx,
1381 );
1382
1383 if is_local {
// When the total inserted (or deleted) length equals the number of edits, i.e. about
// one byte per edit, treat it as a character-level action; otherwise selection-level.
1384 let action_type = match (total_deleted, total_inserted, num_edits) {
1385 (0, ins, n) if ins == n => UserActionType::InsertChar,
1386 (0, _, _) => UserActionType::InsertSelection,
1387 (del, 0, n) if del == n => UserActionType::DeleteChar,
1388 (_, 0, _) => UserActionType::DeleteSelection,
1389 (_, ins, n) if ins == n => UserActionType::InsertChar,
1390 (_, _, _) => UserActionType::InsertSelection,
1391 };
1392
1393 if let Some(offset) = last_offset {
1394 let point = new_snapshot.offset_to_point(offset);
// Wall-clock milliseconds since the Unix epoch; 0 if the clock reads before the epoch.
1395 let timestamp_epoch_ms = SystemTime::now()
1396 .duration_since(UNIX_EPOCH)
1397 .map(|d| d.as_millis() as u64)
1398 .unwrap_or(0);
1399 project_state.record_user_action(UserActionRecord {
1400 action_type,
1401 buffer_id: buffer.entity_id(),
1402 line_number: point.row,
1403 offset,
1404 timestamp_epoch_ms,
1405 });
1406 }
1407 }
1408
1409 if !include_in_history {
1410 return;
1411 }
1412
1413 let events = &mut project_state.events;
1414
// Coalesce into the in-progress event when this edit directly continues it: same buffer,
// its very next snapshot, same predicted/non-predicted source, and within
// CHANGE_GROUPING_LINE_SPAN lines of the event's latest edit.
1415 if let Some(last_event) = project_state.last_event.as_mut() {
1416 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1417 == last_event.new_snapshot.remote_id()
1418 && old_snapshot.version == last_event.new_snapshot.version;
1419
1420 let prediction_source_changed = is_predicted != last_event.predicted;
1421
1422 let should_coalesce = is_next_snapshot_of_same_buffer
1423 && !prediction_source_changed
1424 && lines_between_ranges(
1425 &edit_range.to_point(&new_snapshot),
1426 &last_event.latest_edit_range.to_point(&new_snapshot),
1427 ) <= CHANGE_GROUPING_LINE_SPAN;
1428
1429 if should_coalesce {
// After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the state just before
// this edit as the most recent editing-pause boundary.
1430 let pause_elapsed = last_event
1431 .last_edit_time
1432 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1433 .unwrap_or(false);
1434 if pause_elapsed {
1435 last_event.snapshot_after_last_editing_pause =
1436 Some(last_event.new_snapshot.clone());
1437 last_event.total_edit_range_at_last_pause_boundary =
1438 Some(last_event.total_edit_range.clone());
1439 }
1440
1441 last_event.latest_edit_range = edit_range.clone();
1442 last_event.total_edit_range =
1443 merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
1444 last_event.new_snapshot = new_snapshot;
1445 last_event.last_edit_time = Some(now);
1446 return;
1447 }
1448 }
1449
// Finalize the previous event into history. Evicting before the push keeps at most
// EVENT_COUNT_MAX - 1 finalized events (presumably leaving room for `last_event`).
1450 if let Some(event) = project_state.last_event.take() {
1451 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1452 if events.len() + 1 >= EVENT_COUNT_MAX {
1453 events.pop_front();
1454 }
1455 events.push_back(event);
1456 }
1457 }
1458
1459 merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1460
// Start a new in-progress event covering just this edit.
1461 project_state.last_event = Some(LastEvent {
1462 old_file,
1463 new_file,
1464 old_snapshot,
1465 new_snapshot,
1466 latest_edit_range: edit_range.clone(),
1467 total_edit_range: edit_range,
1468 total_edit_range_at_last_pause_boundary: None,
1469 predicted: is_predicted,
1470 snapshot_after_last_editing_pause: None,
1471 last_edit_time: Some(now),
1472 });
1473 }
1474
1475 fn prediction_at(
1476 &mut self,
1477 buffer: &Entity<Buffer>,
1478 position: Option<language::Anchor>,
1479 project: &Entity<Project>,
1480 cx: &App,
1481 ) -> Option<BufferEditPrediction<'_>> {
1482 let project_state = self.projects.get_mut(&project.entity_id())?;
1483 if let Some(position) = position
1484 && let Some(buffer) = project_state
1485 .registered_buffers
1486 .get_mut(&buffer.entity_id())
1487 {
1488 buffer.last_position = Some(position);
1489 }
1490
1491 let CurrentEditPrediction {
1492 requested_by,
1493 prediction,
1494 ..
1495 } = project_state.current_prediction.as_ref()?;
1496
1497 if prediction.targets_buffer(buffer.read(cx)) {
1498 Some(BufferEditPrediction::Local { prediction })
1499 } else {
1500 let show_jump = match requested_by {
1501 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1502 requested_by_buffer_id == &buffer.entity_id()
1503 }
1504 PredictionRequestedBy::DiagnosticsUpdate => true,
1505 };
1506
1507 if show_jump {
1508 Some(BufferEditPrediction::Jump { prediction })
1509 } else {
1510 None
1511 }
1512 }
1513 }
1514
1515 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1516 let Some(current_prediction) = self
1517 .projects
1518 .get_mut(&project.entity_id())
1519 .and_then(|project_state| project_state.current_prediction.take())
1520 else {
1521 return;
1522 };
1523
1524 self.report_changes_for_buffer(
1525 ¤t_prediction.prediction.buffer,
1526 project,
1527 true,
1528 true,
1529 cx,
1530 );
1531
1532 // can't hold &mut project_state ref across report_changes_for_buffer_call
1533 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1534 return;
1535 };
1536
1537 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1538 project_state.cancel_pending_prediction(pending_prediction, cx);
1539 }
1540
1541 match self.edit_prediction_model {
1542 EditPredictionModel::Sweep => {
1543 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1544 }
1545 EditPredictionModel::Mercury => {
1546 mercury::edit_prediction_accepted(
1547 current_prediction.prediction.id,
1548 self.client.http_client(),
1549 cx,
1550 );
1551 }
1552 EditPredictionModel::Zeta => {
1553 let is_cloud = !matches!(
1554 all_language_settings(None, cx).edit_predictions.provider,
1555 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1556 );
1557 if is_cloud {
1558 zeta::edit_prediction_accepted(self, current_prediction, cx)
1559 }
1560 }
1561 EditPredictionModel::Fim { .. } => {}
1562 }
1563 }
1564
/// Background loop that uploads prediction rejections to the `/predict_edits/reject`
/// endpoint in batches.
///
/// While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections are
/// batched, it keeps collecting as long as another arrives within `REJECT_REQUEST_DEBOUNCE`
/// of the previous one. Each request carries at most
/// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest batched rejections. Sent
/// entries are removed only when the request succeeds; otherwise they stay and are retried
/// after the next rejection arrives. The loop ends once `rx` is closed and drained.
///
/// NOTE(review): two issues to confirm as intended:
/// - the whole batch is sent with the `organization_id` of the most recently received
///   payload, even if earlier entries came from another organization;
/// - entries are never dropped after failures, so `batched` can grow without bound while
///   requests keep failing.
1565 async fn handle_rejected_predictions(
1566 rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1567 client: Arc<Client>,
1568 llm_token: LlmApiToken,
1569 app_version: Version,
1570 background_executor: BackgroundExecutor,
1571 ) {
1572 let mut rx = std::pin::pin!(rx.peekable());
1573 let mut batched = Vec::new();
1574
1575 while let Some(EditPredictionRejectionPayload {
1576 rejection,
1577 organization_id,
1578 }) = rx.next().await
1579 {
1580 batched.push(rejection);
1581
// Debounce: if another rejection shows up before the timer, loop back to receive it.
// A closed channel (peek yields None) or the timer firing falls through to sending.
1582 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1583 select_biased! {
1584 next = rx.as_mut().peek().fuse() => {
1585 if next.is_some() {
1586 continue;
1587 }
1588 }
1589 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1590 }
1591 }
1592
1593 let url = client
1594 .http_client()
1595 .build_zed_llm_url("/predict_edits/reject", &[])
1596 .unwrap();
1597
// Send only the newest entries that fit in one request.
1598 let flush_count = batched
1599 .len()
1600 // in case items have accumulated after failure
1601 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1602 let start = batched.len() - flush_count;
1603
1604 let body = RejectEditPredictionsBodyRef {
1605 rejections: &batched[start..],
1606 };
1607
1608 let result = Self::send_api_request::<()>(
1609 |builder| {
1610 let req = builder
1611 .uri(url.as_ref())
1612 .body(serde_json::to_string(&body)?.into());
1613 anyhow::Ok(req?)
1614 },
1615 client.clone(),
1616 llm_token.clone(),
1617 organization_id,
1618 app_version.clone(),
1619 true,
1620 )
1621 .await;
1622
// On success drop exactly the entries that were sent; on failure (logged) keep them.
1623 if result.log_err().is_some() {
1624 batched.drain(start..);
1625 }
1626 }
1627 }
1628
/// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry for pending
/// predictions once their editable region has gone `EDIT_PREDICTION_SETTLED_QUIESCENCE`
/// without edits.
///
/// `rx` delivers the enqueue time of each newly pending prediction (sent by
/// `enqueue_settled_prediction`). The loop alternates between waiting for such a
/// notification and sleeping until the next computed wake time. It then scans every
/// registered buffer of every project:
/// - predictions older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without an event;
/// - quiet ones are reported and removed;
/// - the earliest `last_edit_at` among the rest schedules the next wake.
///
/// It exits when `rx` closes or the store entity has been dropped.
1629 async fn run_settled_predictions_worker(
1630 this: WeakEntity<Self>,
1631 mut rx: UnboundedReceiver<Instant>,
1632 cx: &mut AsyncApp,
1633 ) {
1634 let mut next_wake_time: Option<Instant> = None;
1635 loop {
1636 let now = cx.background_executor().now();
// Sleep until the scheduled wake time. Assuming `Instant` is `std::time::Instant`,
// `duration_since` saturates to zero if that time has already passed.
1637 if let Some(wake_time) = next_wake_time.take() {
1638 cx.background_executor()
1639 .timer(wake_time.duration_since(now))
1640 .await;
1641 } else {
// Idle: wait for an enqueue, schedule a check one quiescence period after it, and
// drain any other notifications already queued.
1642 let Some(new_enqueue_time) = rx.next().await else {
1643 break;
1644 };
1645 next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1646 while rx.next().now_or_never().flatten().is_some() {}
1647 continue;
1648 }
1649
1650 let Some(this) = this.upgrade() else {
1651 break;
1652 };
1653
1654 let now = cx.background_executor().now();
1655
1656 let mut oldest_edited_at = None;
1657
1658 this.update(cx, |this, _| {
1659 for (_, project_state) in this.projects.iter_mut() {
1660 for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1661 registered_buffer
1662 .pending_predictions
1663 .retain_mut(|pending_prediction| {
1664 let age =
1665 now.saturating_duration_since(pending_prediction.enqueued_at);
1666 if age >= EDIT_PREDICTION_SETTLED_TTL {
1667 return false;
1668 }
1669
1670 let quiet_for =
1671 now.saturating_duration_since(pending_prediction.last_edit_at);
1672 if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
// Report the region's text as of the latest snapshot recorded for the buffer.
1673 let settled_editable_region = registered_buffer
1674 .snapshot
1675 .text_for_range(
1676 pending_prediction.editable_anchor_range.clone(),
1677 )
1678 .collect::<String>();
1679
1680 #[cfg(test)]
1681 if let Some(callback) = &this.settled_event_callback {
1682 callback(
1683 pending_prediction.request_id.clone(),
1684 settled_editable_region.clone(),
1685 );
1686 }
1687
1688 telemetry::event!(
1689 EDIT_PREDICTION_SETTLED_EVENT,
1690 request_id = pending_prediction.request_id.0.clone(),
1691 settled_editable_region,
1692 example = pending_prediction.example.take(),
1693 e2e_latency = pending_prediction.e2e_latency.as_millis(),
1694 );
1695
1696 return false;
1697 }
1698
// Still pending: track the earliest last edit to schedule the next wake-up.
1699 if oldest_edited_at
1700 .is_none_or(|t| pending_prediction.last_edit_at < t)
1701 {
1702 oldest_edited_at = Some(pending_prediction.last_edit_at);
1703 }
1704
1705 true
1706 });
1707 }
1708 }
1709 });
1710
1711 next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1712 }
1713 }
1714
1715 pub(crate) fn enqueue_settled_prediction(
1716 &mut self,
1717 request_id: EditPredictionId,
1718 project: &Entity<Project>,
1719 edited_buffer: &Entity<Buffer>,
1720 edited_buffer_snapshot: &BufferSnapshot,
1721 editable_offset_range: Range<usize>,
1722 example: Option<ExampleSpec>,
1723 e2e_latency: std::time::Duration,
1724 cx: &mut Context<Self>,
1725 ) {
1726 let this = &mut *self;
1727 let project_state = this.get_or_init_project(project, cx);
1728 if let Some(buffer) = project_state
1729 .registered_buffers
1730 .get_mut(&edited_buffer.entity_id())
1731 {
1732 let now = cx.background_executor().now();
1733 buffer.pending_predictions.push(PendingSettledPrediction {
1734 request_id: request_id,
1735 editable_anchor_range: edited_buffer_snapshot
1736 .anchor_range_around(editable_offset_range),
1737 example,
1738 e2e_latency,
1739 enqueued_at: now,
1740 last_edit_at: now,
1741 });
1742 this.settled_predictions_tx.unbounded_send(now).ok();
1743 }
1744 }
1745
1746 fn reject_current_prediction(
1747 &mut self,
1748 reason: EditPredictionRejectReason,
1749 project: &Entity<Project>,
1750 cx: &App,
1751 ) {
1752 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1753 project_state.pending_predictions.clear();
1754 if let Some(prediction) = project_state.current_prediction.take() {
1755 let model_version = prediction.prediction.model_version.clone();
1756 self.reject_prediction(
1757 prediction.prediction.id,
1758 reason,
1759 prediction.was_shown,
1760 model_version,
1761 Some(prediction.e2e_latency),
1762 cx,
1763 );
1764 }
1765 };
1766 }
1767
1768 fn did_show_current_prediction(
1769 &mut self,
1770 project: &Entity<Project>,
1771 display_type: edit_prediction_types::SuggestionDisplayType,
1772 cx: &mut Context<Self>,
1773 ) {
1774 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1775 return;
1776 };
1777
1778 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1779 return;
1780 };
1781
1782 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1783 let previous_shown_with = current_prediction.shown_with;
1784
1785 if previous_shown_with.is_none() || !is_jump {
1786 current_prediction.shown_with = Some(display_type);
1787 }
1788
1789 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1790
1791 if is_first_non_jump_show {
1792 current_prediction.was_shown = true;
1793 }
1794
1795 let display_type_changed = previous_shown_with != Some(display_type);
1796
1797 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1798 sweep_ai::edit_prediction_shown(
1799 &self.sweep_ai,
1800 self.client.clone(),
1801 ¤t_prediction.prediction,
1802 display_type,
1803 cx,
1804 );
1805 }
1806
1807 if is_first_non_jump_show {
1808 self.shown_predictions
1809 .push_front(current_prediction.prediction.clone());
1810 if self.shown_predictions.len() > 50 {
1811 let completion = self.shown_predictions.pop_back().unwrap();
1812 self.rated_predictions.remove(&completion.id);
1813 }
1814 }
1815 }
1816
1817 fn reject_prediction(
1818 &mut self,
1819 prediction_id: EditPredictionId,
1820 reason: EditPredictionRejectReason,
1821 was_shown: bool,
1822 model_version: Option<String>,
1823 e2e_latency: Option<std::time::Duration>,
1824 cx: &App,
1825 ) {
1826 match self.edit_prediction_model {
1827 EditPredictionModel::Zeta => {
1828 let is_cloud = !matches!(
1829 all_language_settings(None, cx).edit_predictions.provider,
1830 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1831 );
1832
1833 if is_cloud {
1834 let organization_id = self
1835 .user_store
1836 .read(cx)
1837 .current_organization()
1838 .map(|organization| organization.id.clone());
1839
1840 self.reject_predictions_tx
1841 .unbounded_send(EditPredictionRejectionPayload {
1842 rejection: EditPredictionRejection {
1843 request_id: prediction_id.to_string(),
1844 reason,
1845 was_shown,
1846 model_version,
1847 e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1848 },
1849 organization_id,
1850 })
1851 .log_err();
1852 }
1853 }
1854 EditPredictionModel::Mercury => {
1855 mercury::edit_prediction_rejected(
1856 prediction_id,
1857 was_shown,
1858 reason,
1859 self.client.http_client(),
1860 cx,
1861 );
1862 }
1863 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1864 }
1865 }
1866
1867 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1868 self.projects
1869 .get(&project.entity_id())
1870 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1871 }
1872
1873 pub fn refresh_prediction_from_buffer(
1874 &mut self,
1875 project: Entity<Project>,
1876 buffer: Entity<Buffer>,
1877 position: language::Anchor,
1878 cx: &mut Context<Self>,
1879 ) {
1880 self.queue_prediction_refresh(
1881 project.clone(),
1882 PredictEditsRequestTrigger::Other,
1883 buffer.entity_id(),
1884 cx,
1885 move |this, cx| {
1886 let Some(request_task) = this
1887 .update(cx, |this, cx| {
1888 this.request_prediction(
1889 &project,
1890 &buffer,
1891 position,
1892 PredictEditsRequestTrigger::Other,
1893 cx,
1894 )
1895 })
1896 .log_err()
1897 else {
1898 return Task::ready(anyhow::Ok(None));
1899 };
1900
1901 cx.spawn(async move |_cx| {
1902 request_task.await.map(|prediction_result| {
1903 prediction_result.map(|prediction_result| {
1904 (
1905 prediction_result,
1906 PredictionRequestedBy::Buffer(buffer.entity_id()),
1907 )
1908 })
1909 })
1910 })
1911 },
1912 )
1913 }
1914
1915 pub fn refresh_prediction_from_diagnostics(
1916 &mut self,
1917 project: Entity<Project>,
1918 scope: DiagnosticSearchScope,
1919 cx: &mut Context<Self>,
1920 ) {
1921 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1922 return;
1923 }
1924
1925 if currently_following(&project, cx) {
1926 return;
1927 }
1928
1929 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1930 return;
1931 };
1932
1933 // Prefer predictions from buffer
1934 if project_state.current_prediction.is_some() {
1935 log::debug!(
1936 "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1937 );
1938 return;
1939 }
1940
1941 self.queue_prediction_refresh(
1942 project.clone(),
1943 PredictEditsRequestTrigger::Diagnostics,
1944 project.entity_id(),
1945 cx,
1946 move |this, cx| {
1947 let Some((active_buffer, snapshot, cursor_point)) = this
1948 .read_with(cx, |this, cx| {
1949 let project_state = this.projects.get(&project.entity_id())?;
1950 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1951 let snapshot = buffer.read(cx).snapshot();
1952
1953 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1954 return None;
1955 }
1956
1957 let cursor_point = position
1958 .map(|pos| pos.to_point(&snapshot))
1959 .unwrap_or_default();
1960
1961 Some((buffer, snapshot, cursor_point))
1962 })
1963 .log_err()
1964 .flatten()
1965 else {
1966 return Task::ready(anyhow::Ok(None));
1967 };
1968
1969 cx.spawn(async move |cx| {
1970 let diagnostic_search_range = match scope {
1971 DiagnosticSearchScope::Local => {
1972 let diagnostic_search_start =
1973 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1974 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1975 Point::new(diagnostic_search_start, 0)
1976 ..Point::new(diagnostic_search_end, 0)
1977 }
1978 DiagnosticSearchScope::Global => Default::default(),
1979 };
1980
1981 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1982 active_buffer,
1983 &snapshot,
1984 diagnostic_search_range,
1985 cursor_point,
1986 &project,
1987 cx,
1988 )
1989 .await?
1990 else {
1991 return anyhow::Ok(None);
1992 };
1993
1994 let Some(prediction_result) = this
1995 .update(cx, |this, cx| {
1996 this.request_prediction(
1997 &project,
1998 &jump_buffer,
1999 jump_position,
2000 PredictEditsRequestTrigger::Diagnostics,
2001 cx,
2002 )
2003 })?
2004 .await?
2005 else {
2006 return anyhow::Ok(None);
2007 };
2008
2009 this.update(cx, |this, cx| {
2010 Some((
2011 if this
2012 .get_or_init_project(&project, cx)
2013 .current_prediction
2014 .is_none()
2015 {
2016 prediction_result
2017 } else {
2018 EditPredictionResult {
2019 id: prediction_result.id,
2020 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
2021 e2e_latency: prediction_result.e2e_latency,
2022 }
2023 },
2024 PredictionRequestedBy::DiagnosticsUpdate,
2025 ))
2026 })
2027 })
2028 },
2029 );
2030 }
2031
2032 fn predictions_enabled_at(
2033 snapshot: &BufferSnapshot,
2034 position: Option<language::Anchor>,
2035 cx: &App,
2036 ) -> bool {
2037 let file = snapshot.file();
2038 let all_settings = all_language_settings(file, cx);
2039 if !all_settings.show_edit_predictions(snapshot.language(), cx)
2040 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2041 {
2042 return false;
2043 }
2044
2045 if let Some(last_position) = position {
2046 let settings = snapshot.settings_at(last_position, cx);
2047
2048 if !settings.edit_predictions_disabled_in.is_empty()
2049 && let Some(scope) = snapshot.language_scope_at(last_position)
2050 && let Some(scope_name) = scope.override_name()
2051 && settings
2052 .edit_predictions_disabled_in
2053 .iter()
2054 .any(|s| s == scope_name)
2055 {
2056 return false;
2057 }
2058 }
2059
2060 true
2061 }
2062
/// Throttle window for prediction refreshes. When a refresh for the same throttle entity
/// happened within this window, `queue_prediction_refresh` waits out the remainder of the
/// window before proceeding.
2063 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2064}
2065
2066fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2067 let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2068 return false;
2069 };
2070
2071 app_state
2072 .workspace_store
2073 .read(cx)
2074 .workspaces()
2075 .filter_map(|workspace| workspace.upgrade())
2076 .any(|workspace| {
2077 workspace.read(cx).project().entity_id() == project.entity_id()
2078 && workspace
2079 .read(cx)
2080 .leader_for_pane(workspace.read(cx).active_pane())
2081 .is_some()
2082 })
2083}
2084
2085fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2086 match provider {
2087 EditPredictionProvider::Zed
2088 | EditPredictionProvider::Sweep
2089 | EditPredictionProvider::Mercury
2090 | EditPredictionProvider::Ollama
2091 | EditPredictionProvider::OpenAiCompatibleApi
2092 | EditPredictionProvider::Experimental(_) => true,
2093 EditPredictionProvider::None
2094 | EditPredictionProvider::Copilot
2095 | EditPredictionProvider::Codestral => false,
2096 }
2097}
2098
2099impl EditPredictionStore {
2100 fn queue_prediction_refresh(
2101 &mut self,
2102 project: Entity<Project>,
2103 request_trigger: PredictEditsRequestTrigger,
2104 throttle_entity: EntityId,
2105 cx: &mut Context<Self>,
2106 do_refresh: impl FnOnce(
2107 WeakEntity<Self>,
2108 &mut AsyncApp,
2109 )
2110 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2111 + 'static,
2112 ) {
2113 fn select_throttle(
2114 project_state: &mut ProjectState,
2115 request_trigger: PredictEditsRequestTrigger,
2116 ) -> &mut Option<(EntityId, Instant)> {
2117 match request_trigger {
2118 PredictEditsRequestTrigger::Diagnostics => {
2119 &mut project_state.last_jump_prediction_refresh
2120 }
2121 _ => &mut project_state.last_edit_prediction_refresh,
2122 }
2123 }
2124
2125 let (needs_acceptance_tracking, max_pending_predictions) =
2126 match all_language_settings(None, cx).edit_predictions.provider {
2127 EditPredictionProvider::Zed
2128 | EditPredictionProvider::Sweep
2129 | EditPredictionProvider::Mercury
2130 | EditPredictionProvider::Experimental(_) => (true, 2),
2131 EditPredictionProvider::Ollama => (false, 1),
2132 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2133 EditPredictionProvider::None
2134 | EditPredictionProvider::Copilot
2135 | EditPredictionProvider::Codestral => {
2136 log::error!("queue_prediction_refresh called with non-store provider");
2137 return;
2138 }
2139 };
2140
2141 let drop_on_cancel = !needs_acceptance_tracking;
2142 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2143 let project_state = self.get_or_init_project(&project, cx);
2144 let pending_prediction_id = project_state.next_pending_prediction_id;
2145 project_state.next_pending_prediction_id += 1;
2146 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2147
2148 let task = cx.spawn(async move |this, cx| {
2149 let throttle_wait = this
2150 .update(cx, |this, cx| {
2151 let project_state = this.get_or_init_project(&project, cx);
2152 let throttle = *select_throttle(project_state, request_trigger);
2153
2154 let now = cx.background_executor().now();
2155 throttle.and_then(|(last_entity, last_timestamp)| {
2156 if throttle_entity != last_entity {
2157 return None;
2158 }
2159 (last_timestamp + throttle_timeout).checked_duration_since(now)
2160 })
2161 })
2162 .ok()
2163 .flatten();
2164
2165 if let Some(timeout) = throttle_wait {
2166 cx.background_executor().timer(timeout).await;
2167 }
2168
2169 // If this task was cancelled before the throttle timeout expired,
2170 // do not perform a request. Also skip if another task already
2171 // proceeded since we were enqueued (duplicate).
2172 let mut is_cancelled = true;
2173 this.update(cx, |this, cx| {
2174 let project_state = this.get_or_init_project(&project, cx);
2175 let was_cancelled = project_state
2176 .cancelled_predictions
2177 .remove(&pending_prediction_id);
2178 if was_cancelled {
2179 return;
2180 }
2181
2182 // Another request has been already sent since this was enqueued
2183 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2184 return;
2185 }
2186
2187 let new_refresh = (throttle_entity, cx.background_executor().now());
2188 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2189 is_cancelled = false;
2190 })
2191 .ok();
2192 if is_cancelled {
2193 return None;
2194 }
2195
2196 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2197 let new_prediction_id = new_prediction_result
2198 .as_ref()
2199 .map(|(prediction, _)| prediction.id.clone());
2200
2201 // When a prediction completes, remove it from the pending list, and cancel
2202 // any pending predictions that were enqueued before it.
2203 this.update(cx, |this, cx| {
2204 let project_state = this.get_or_init_project(&project, cx);
2205
2206 let is_cancelled = project_state
2207 .cancelled_predictions
2208 .remove(&pending_prediction_id);
2209
2210 let new_current_prediction = if !is_cancelled
2211 && let Some((prediction_result, requested_by)) = new_prediction_result
2212 {
2213 match prediction_result.prediction {
2214 Ok(prediction) => {
2215 let new_prediction = CurrentEditPrediction {
2216 requested_by,
2217 prediction,
2218 was_shown: false,
2219 shown_with: None,
2220 e2e_latency: prediction_result.e2e_latency,
2221 };
2222
2223 if let Some(current_prediction) =
2224 project_state.current_prediction.as_ref()
2225 {
2226 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2227 {
2228 this.reject_current_prediction(
2229 EditPredictionRejectReason::Replaced,
2230 &project,
2231 cx,
2232 );
2233
2234 Some(new_prediction)
2235 } else {
2236 this.reject_prediction(
2237 new_prediction.prediction.id,
2238 EditPredictionRejectReason::CurrentPreferred,
2239 false,
2240 new_prediction.prediction.model_version,
2241 Some(new_prediction.e2e_latency),
2242 cx,
2243 );
2244 None
2245 }
2246 } else {
2247 Some(new_prediction)
2248 }
2249 }
2250 Err(reject_reason) => {
2251 this.reject_prediction(
2252 prediction_result.id,
2253 reject_reason,
2254 false,
2255 None,
2256 Some(prediction_result.e2e_latency),
2257 cx,
2258 );
2259 None
2260 }
2261 }
2262 } else {
2263 None
2264 };
2265
2266 let project_state = this.get_or_init_project(&project, cx);
2267
2268 if let Some(new_prediction) = new_current_prediction {
2269 project_state.current_prediction = Some(new_prediction);
2270 }
2271
2272 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2273 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2274 if pending_prediction.id == pending_prediction_id {
2275 pending_predictions.remove(ix);
2276 for pending_prediction in pending_predictions.drain(0..ix) {
2277 project_state.cancel_pending_prediction(pending_prediction, cx)
2278 }
2279 break;
2280 }
2281 }
2282 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2283 cx.notify();
2284 })
2285 .ok();
2286
2287 new_prediction_id
2288 });
2289
2290 if project_state.pending_predictions.len() < max_pending_predictions {
2291 project_state.pending_predictions.push(PendingPrediction {
2292 id: pending_prediction_id,
2293 task,
2294 drop_on_cancel,
2295 });
2296 } else {
2297 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2298 project_state.pending_predictions.push(PendingPrediction {
2299 id: pending_prediction_id,
2300 task,
2301 drop_on_cancel,
2302 });
2303 project_state.cancel_pending_prediction(pending_prediction, cx);
2304 }
2305 }
2306
2307 pub fn request_prediction(
2308 &mut self,
2309 project: &Entity<Project>,
2310 active_buffer: &Entity<Buffer>,
2311 position: language::Anchor,
2312 trigger: PredictEditsRequestTrigger,
2313 cx: &mut Context<Self>,
2314 ) -> Task<Result<Option<EditPredictionResult>>> {
2315 self.request_prediction_internal(
2316 project.clone(),
2317 active_buffer.clone(),
2318 position,
2319 trigger,
2320 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2321 cx,
2322 )
2323 }
2324
2325 fn request_prediction_internal(
2326 &mut self,
2327 project: Entity<Project>,
2328 active_buffer: Entity<Buffer>,
2329 position: language::Anchor,
2330 trigger: PredictEditsRequestTrigger,
2331 allow_jump: bool,
2332 cx: &mut Context<Self>,
2333 ) -> Task<Result<Option<EditPredictionResult>>> {
2334 self.get_or_init_project(&project, cx);
2335 let project_state = self.projects.get(&project.entity_id()).unwrap();
2336 let stored_events = project_state.events(cx);
2337 let has_events = !stored_events.is_empty();
2338 let events: Vec<Arc<zeta_prompt::Event>> =
2339 stored_events.iter().map(|e| e.event.clone()).collect();
2340 let debug_tx = project_state.debug_tx.clone();
2341
2342 let snapshot = active_buffer.read(cx).snapshot();
2343 let cursor_point = position.to_point(&snapshot);
2344 let current_offset = position.to_offset(&snapshot);
2345
2346 let mut user_actions: Vec<UserActionRecord> =
2347 project_state.user_actions.iter().cloned().collect();
2348
2349 if let Some(last_action) = user_actions.last() {
2350 if last_action.buffer_id == active_buffer.entity_id()
2351 && current_offset != last_action.offset
2352 {
2353 let timestamp_epoch_ms = SystemTime::now()
2354 .duration_since(UNIX_EPOCH)
2355 .map(|d| d.as_millis() as u64)
2356 .unwrap_or(0);
2357 user_actions.push(UserActionRecord {
2358 action_type: UserActionType::CursorMovement,
2359 buffer_id: active_buffer.entity_id(),
2360 line_number: cursor_point.row,
2361 offset: current_offset,
2362 timestamp_epoch_ms,
2363 });
2364 }
2365 }
2366 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2367 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2368 let diagnostic_search_range =
2369 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2370
2371 let related_files = self.context_for_project(&project, cx);
2372
2373 let is_open_source = snapshot
2374 .file()
2375 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2376 && events.iter().all(|event| event.in_open_source_repo())
2377 && related_files.iter().all(|file| file.in_open_source_repo);
2378
2379 let can_collect_data = !cfg!(test)
2380 && is_open_source
2381 && self.is_data_collection_enabled(cx)
2382 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2383
2384 let recent_paths = project_state.recent_paths.clone();
2385
2386 let inputs = EditPredictionModelInput {
2387 project: project.clone(),
2388 buffer: active_buffer,
2389 snapshot,
2390 position,
2391 events,
2392 related_files,
2393 recent_paths,
2394 trigger,
2395 diagnostic_search_range: diagnostic_search_range,
2396 debug_tx,
2397 user_actions,
2398 can_collect_data,
2399 is_open_source,
2400 };
2401
2402 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2403
2404 let task = match self.edit_prediction_model {
2405 EditPredictionModel::Zeta => {
2406 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2407 }
2408 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2409 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2410 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2411 };
2412
2413 cx.spawn(async move |this, cx| {
2414 let prediction = task.await?;
2415
2416 // Only fall back to diagnostics-based prediction if we got a
2417 // the model had nothing to suggest for the buffer
2418 if prediction.is_none()
2419 && allow_jump
2420 && has_events
2421 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2422 {
2423 this.update(cx, |this, cx| {
2424 this.refresh_prediction_from_diagnostics(
2425 project,
2426 DiagnosticSearchScope::Local,
2427 cx,
2428 );
2429 })?;
2430 return anyhow::Ok(None);
2431 }
2432
2433 Ok(prediction)
2434 })
2435 }
2436
2437 pub(crate) async fn next_diagnostic_location(
2438 active_buffer: Entity<Buffer>,
2439 active_buffer_snapshot: &BufferSnapshot,
2440 active_buffer_diagnostic_search_range: Range<Point>,
2441 active_buffer_cursor_point: Point,
2442 project: &Entity<Project>,
2443 cx: &mut AsyncApp,
2444 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2445 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2446 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2447 .flat_map(|(_, _, _, selections)| {
2448 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2449 })
2450 .collect();
2451
2452 let mut jump_location = active_buffer_snapshot
2453 .diagnostic_groups(None)
2454 .into_iter()
2455 .filter_map(|(_, group)| {
2456 let range = &group.entries[group.primary_ix]
2457 .range
2458 .to_point(&active_buffer_snapshot);
2459 if range.overlaps(&active_buffer_diagnostic_search_range) {
2460 return None;
2461 }
2462 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2463 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2464 });
2465 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2466 <= DIAGNOSTIC_LINES_RANGE;
2467 if near_collaborator && !near_local {
2468 return None;
2469 }
2470 Some(range.start)
2471 })
2472 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2473 .map(|position| {
2474 (
2475 active_buffer.clone(),
2476 active_buffer_snapshot.anchor_before(position),
2477 )
2478 });
2479
2480 if jump_location.is_none() {
2481 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2482 let file = buffer.file()?;
2483
2484 Some(ProjectPath {
2485 worktree_id: file.worktree_id(cx),
2486 path: file.path().clone(),
2487 })
2488 });
2489
2490 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2491 project
2492 .diagnostic_summaries(false, cx)
2493 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2494 .map(|(path, _, _)| {
2495 let shared_prefix = path
2496 .path
2497 .components()
2498 .zip(
2499 active_buffer_path
2500 .as_ref()
2501 .map(|p| p.path.components())
2502 .unwrap_or_default(),
2503 )
2504 .take_while(|(a, b)| a == b)
2505 .count();
2506 (path, shared_prefix)
2507 })
2508 .collect()
2509 });
2510
2511 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2512
2513 for (path, _) in candidates {
2514 let candidate_buffer = project
2515 .update(cx, |project, cx| project.open_buffer(path, cx))
2516 .await?;
2517
2518 let (has_collaborators, diagnostic_position) =
2519 candidate_buffer.read_with(cx, |buffer, _cx| {
2520 let snapshot = buffer.snapshot();
2521 let has_collaborators = snapshot
2522 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2523 .next()
2524 .is_some();
2525 let position = buffer
2526 .buffer_diagnostics(None)
2527 .into_iter()
2528 .min_by_key(|entry| entry.diagnostic.severity)
2529 .map(|entry| entry.range.start);
2530 (has_collaborators, position)
2531 });
2532
2533 if has_collaborators {
2534 continue;
2535 }
2536
2537 if let Some(position) = diagnostic_position {
2538 jump_location = Some((candidate_buffer, position));
2539 break;
2540 }
2541 }
2542 }
2543
2544 anyhow::Ok(jump_location)
2545 }
2546
2547 async fn send_raw_llm_request(
2548 request: RawCompletionRequest,
2549 client: Arc<Client>,
2550 custom_url: Option<Arc<Url>>,
2551 llm_token: LlmApiToken,
2552 organization_id: Option<OrganizationId>,
2553 app_version: Version,
2554 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2555 let url = if let Some(custom_url) = custom_url {
2556 custom_url.as_ref().clone()
2557 } else {
2558 client
2559 .http_client()
2560 .build_zed_llm_url("/predict_edits/raw", &[])?
2561 };
2562
2563 Self::send_api_request(
2564 |builder| {
2565 let req = builder
2566 .uri(url.as_ref())
2567 .body(serde_json::to_string(&request)?.into());
2568 Ok(req?)
2569 },
2570 client,
2571 llm_token,
2572 organization_id,
2573 app_version,
2574 true,
2575 )
2576 .await
2577 }
2578
2579 pub(crate) async fn send_v3_request(
2580 input: ZetaPromptInput,
2581 client: Arc<Client>,
2582 llm_token: LlmApiToken,
2583 organization_id: Option<OrganizationId>,
2584 app_version: Version,
2585 trigger: PredictEditsRequestTrigger,
2586 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2587 let url = client
2588 .http_client()
2589 .build_zed_llm_url("/predict_edits/v3", &[])?;
2590
2591 let request = PredictEditsV3Request { input, trigger };
2592
2593 let json_bytes = serde_json::to_vec(&request)?;
2594 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2595
2596 Self::send_api_request(
2597 |builder| {
2598 let req = builder
2599 .uri(url.as_ref())
2600 .header("Content-Encoding", "zstd")
2601 .body(compressed.clone().into());
2602 Ok(req?)
2603 },
2604 client,
2605 llm_token,
2606 organization_id,
2607 app_version,
2608 true,
2609 )
2610 .await
2611 }
2612
2613 async fn send_api_request<Res>(
2614 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2615 client: Arc<Client>,
2616 llm_token: LlmApiToken,
2617 organization_id: Option<OrganizationId>,
2618 app_version: Version,
2619 require_auth: bool,
2620 ) -> Result<(Res, Option<EditPredictionUsage>)>
2621 where
2622 Res: DeserializeOwned,
2623 {
2624 let http_client = client.http_client();
2625
2626 let mut token = if require_auth {
2627 Some(llm_token.acquire(&client, organization_id.clone()).await?)
2628 } else {
2629 llm_token
2630 .acquire(&client, organization_id.clone())
2631 .await
2632 .ok()
2633 };
2634 let mut did_retry = false;
2635
2636 loop {
2637 let request_builder = http_client::Request::builder().method(Method::POST);
2638
2639 let mut request_builder = request_builder
2640 .header("Content-Type", "application/json")
2641 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2642
2643 // Only add Authorization header if we have a token
2644 if let Some(ref token_value) = token {
2645 request_builder =
2646 request_builder.header("Authorization", format!("Bearer {}", token_value));
2647 }
2648
2649 let request = build(request_builder)?;
2650
2651 let mut response = http_client.send(request).await?;
2652
2653 if let Some(minimum_required_version) = response
2654 .headers()
2655 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2656 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2657 {
2658 anyhow::ensure!(
2659 app_version >= minimum_required_version,
2660 ZedUpdateRequiredError {
2661 minimum_version: minimum_required_version
2662 }
2663 );
2664 }
2665
2666 if response.status().is_success() {
2667 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2668
2669 let mut body = Vec::new();
2670 response.body_mut().read_to_end(&mut body).await?;
2671 return Ok((serde_json::from_slice(&body)?, usage));
2672 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2673 did_retry = true;
2674 token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2675 } else {
2676 let mut body = String::new();
2677 response.body_mut().read_to_string(&mut body).await?;
2678 anyhow::bail!(
2679 "Request failed with status: {:?}\nBody: {}",
2680 response.status(),
2681 body
2682 );
2683 }
2684 }
2685 }
2686
2687 pub fn refresh_context(
2688 &mut self,
2689 project: &Entity<Project>,
2690 buffer: &Entity<language::Buffer>,
2691 cursor_position: language::Anchor,
2692 cx: &mut Context<Self>,
2693 ) {
2694 self.get_or_init_project(project, cx)
2695 .context
2696 .update(cx, |store, cx| {
2697 store.refresh(buffer.clone(), cursor_position, cx);
2698 });
2699 }
2700
2701 #[cfg(feature = "cli-support")]
2702 pub fn set_context_for_buffer(
2703 &mut self,
2704 project: &Entity<Project>,
2705 related_files: Vec<RelatedFile>,
2706 cx: &mut Context<Self>,
2707 ) {
2708 self.get_or_init_project(project, cx)
2709 .context
2710 .update(cx, |store, cx| {
2711 store.set_related_files(related_files, cx);
2712 });
2713 }
2714
2715 #[cfg(feature = "cli-support")]
2716 pub fn set_recent_paths_for_project(
2717 &mut self,
2718 project: &Entity<Project>,
2719 paths: impl IntoIterator<Item = project::ProjectPath>,
2720 cx: &mut Context<Self>,
2721 ) {
2722 let project_state = self.get_or_init_project(project, cx);
2723 project_state.recent_paths = paths.into_iter().collect();
2724 }
2725
2726 fn is_file_open_source(
2727 &self,
2728 project: &Entity<Project>,
2729 file: &Arc<dyn File>,
2730 cx: &App,
2731 ) -> bool {
2732 if !file.is_local() || file.is_private() {
2733 return false;
2734 }
2735 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2736 return false;
2737 };
2738 project_state
2739 .license_detection_watchers
2740 .get(&file.worktree_id(cx))
2741 .as_ref()
2742 .is_some_and(|watcher| watcher.is_project_open_source())
2743 }
2744
2745 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2746 self.data_collection_choice.is_enabled(cx)
2747 }
2748
2749 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2750 let choice = KeyValueStore::global(cx)
2751 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2752 .log_err()
2753 .flatten();
2754
2755 match choice.as_deref() {
2756 Some("true") => DataCollectionChoice::Enabled,
2757 Some("false") => DataCollectionChoice::Disabled,
2758 Some(_) => {
2759 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2760 DataCollectionChoice::NotAnswered
2761 }
2762 None => DataCollectionChoice::NotAnswered,
2763 }
2764 }
2765
2766 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2767 self.data_collection_choice = self.data_collection_choice.toggle();
2768 let new_choice = self.data_collection_choice;
2769 let is_enabled = new_choice.is_enabled(cx);
2770 let kvp = KeyValueStore::global(cx);
2771 db::write_and_log(cx, move || async move {
2772 kvp.write_kvp(
2773 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2774 is_enabled.to_string(),
2775 )
2776 .await
2777 });
2778 }
2779
2780 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2781 self.shown_predictions.iter()
2782 }
2783
2784 pub fn shown_completions_len(&self) -> usize {
2785 self.shown_predictions.len()
2786 }
2787
2788 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2789 self.rated_predictions.contains(id)
2790 }
2791
2792 pub fn rate_prediction(
2793 &mut self,
2794 prediction: &EditPrediction,
2795 rating: EditPredictionRating,
2796 feedback: String,
2797 cx: &mut Context<Self>,
2798 ) {
2799 let organization = self.user_store.read(cx).current_organization();
2800
2801 self.rated_predictions.insert(prediction.id.clone());
2802
2803 cx.background_spawn({
2804 let client = self.client.clone();
2805 let prediction_id = prediction.id.to_string();
2806 let inputs = serde_json::to_value(&prediction.inputs);
2807 let output = prediction
2808 .edit_preview
2809 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2810 async move {
2811 client
2812 .cloud_client()
2813 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2814 organization_id: organization.map(|organization| organization.id.clone()),
2815 request_id: prediction_id,
2816 rating: match rating {
2817 EditPredictionRating::Positive => "positive".to_string(),
2818 EditPredictionRating::Negative => "negative".to_string(),
2819 },
2820 inputs: inputs?,
2821 output,
2822 feedback,
2823 })
2824 .await?;
2825
2826 anyhow::Ok(())
2827 }
2828 })
2829 .detach_and_log_err(cx);
2830
2831 cx.notify();
2832 }
2833}
2834
2835fn collaborator_edit_overlaps_locality_region(
2836 project_state: &ProjectState,
2837 project: &Entity<Project>,
2838 buffer: &Entity<Buffer>,
2839 snapshot: &BufferSnapshot,
2840 edit_range: &Range<Anchor>,
2841 cx: &App,
2842) -> bool {
2843 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2844 return false;
2845 };
2846
2847 if active_buffer.entity_id() != buffer.entity_id() {
2848 return false;
2849 }
2850
2851 let locality_point_range = expand_context_syntactically_then_linewise(
2852 snapshot,
2853 (position..position).to_point(snapshot),
2854 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2855 );
2856 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2857
2858 edit_range.overlaps(&locality_anchor_range, snapshot)
2859}
2860
/// Collapses the trailing run of mergeable events in `events` into one
/// `BufferChange` event.
///
/// Preconditions: nothing is merged unless the newest event is for the same
/// buffer as `latest_snapshot` (same `remote_id`) and `latest_snapshot.version`
/// has observed that event's `new_snapshot_version`.
///
/// The run is the longest suffix of `events` in which every event `can_merge`
/// with the event that follows it (evaluated against `latest_snapshot` and
/// `latest_edit_range`). A run of a single event is left as-is.
///
/// The replacement event:
/// - diffs the oldest event's `old_snapshot` against `end_snapshot`, restricted
///   to the range covering every merged event's `total_edit_range` (see
///   `merge_anchor_ranges`);
/// - keeps the oldest event's `old_path`, `path` and `in_open_source_repo`;
/// - is `predicted` only if all merged events were.
///
/// If `compute_diff_between_snapshots_in_range` returns `None`, `events` is
/// left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        // Only merge events for the buffer that `latest_snapshot` belongs to...
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        // ...and only once `latest_snapshot` has observed the newest event's version.
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk from the newest event backwards, counting events for as long as each
    // one can be merged with the (newer) event that follows it.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not advance the iterator, so the `all` call below still
    // visits the oldest event of the run.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Grow the oldest event's edit range to cover every newer event's range.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2943
2944fn merge_anchor_ranges(
2945 left: &Range<Anchor>,
2946 right: &Range<Anchor>,
2947 snapshot: &TextBufferSnapshot,
2948) -> Range<Anchor> {
2949 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2950 left.start
2951 } else {
2952 right.start
2953 };
2954 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2955 left.end
2956 } else {
2957 right.end
2958 };
2959 start..end
2960}
2961
2962#[derive(Error, Debug)]
2963#[error(
2964 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2965)]
2966pub struct ZedUpdateRequiredError {
2967 minimum_version: Version,
2968}
2969
/// The user's answer to whether edit prediction data may be collected.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2976
2977impl DataCollectionChoice {
2978 pub fn is_enabled(self, cx: &App) -> bool {
2979 if cx.is_staff() {
2980 return true;
2981 }
2982 match self {
2983 Self::Enabled => true,
2984 Self::NotAnswered | Self::Disabled => false,
2985 }
2986 }
2987
2988 #[must_use]
2989 pub fn toggle(&self) -> DataCollectionChoice {
2990 match self {
2991 Self::Enabled => Self::Disabled,
2992 Self::Disabled => Self::Enabled,
2993 Self::NotAnswered => Self::Enabled,
2994 }
2995 }
2996}
2997
2998impl From<bool> for DataCollectionChoice {
2999 fn from(value: bool) -> Self {
3000 match value {
3001 true => DataCollectionChoice::Enabled,
3002 false => DataCollectionChoice::Disabled,
3003 }
3004 }
3005}
3006
3007struct ZedPredictUpsell;
3008
3009impl Dismissable for ZedPredictUpsell {
3010 const KEY: &'static str = "dismissed-edit-predict-upsell";
3011
3012 fn dismissed(cx: &App) -> bool {
3013 // To make this backwards compatible with older versions of Zed, we
3014 // check if the user has seen the previous Edit Prediction Onboarding
3015 // before, by checking the data collection choice which was written to
3016 // the database once the user clicked on "Accept and Enable"
3017 let kvp = KeyValueStore::global(cx);
3018 if kvp
3019 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3020 .log_err()
3021 .is_some_and(|s| s.is_some())
3022 {
3023 return true;
3024 }
3025
3026 kvp.read_kvp(Self::KEY)
3027 .log_err()
3028 .is_some_and(|s| s.is_some())
3029 }
3030}
3031
3032pub fn should_show_upsell_modal(cx: &App) -> bool {
3033 !ZedPredictUpsell::dismissed(cx)
3034}
3035
3036pub fn init(cx: &mut App) {
3037 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3038 workspace.register_action(
3039 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3040 ZedPredictModal::toggle(
3041 workspace,
3042 workspace.user_store().clone(),
3043 workspace.client().clone(),
3044 window,
3045 cx,
3046 )
3047 },
3048 );
3049
3050 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3051 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3052 settings
3053 .project
3054 .all_languages
3055 .edit_predictions
3056 .get_or_insert_default()
3057 .provider = Some(EditPredictionProvider::None)
3058 });
3059 });
3060 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3061 EditPredictionStore::try_global(cx).and_then(|store| {
3062 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3063 })
3064 }
3065
3066 workspace.register_action(|workspace, _: &SignIn, window, cx| {
3067 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3068 copilot_ui::initiate_sign_in(copilot, window, cx);
3069 }
3070 });
3071 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3072 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3073 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3074 }
3075 });
3076 workspace.register_action(|workspace, _: &SignOut, window, cx| {
3077 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3078 copilot_ui::initiate_sign_out(copilot, window, cx);
3079 }
3080 });
3081 })
3082 .detach();
3083}