1use anyhow::Result;
2use client::{
3 Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore,
4 global_llm_token as global_llm_api_token,
5};
6use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
7use cloud_llm_client::predict_edits_v3::{
8 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
9};
10use cloud_llm_client::{
11 EditPredictionRejectReason, EditPredictionRejection,
12 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
13 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
14};
15use collections::{HashMap, HashSet};
16use copilot::{Copilot, Reinstall, SignIn, SignOut};
17use credentials_provider::CredentialsProvider;
18use db::kvp::{Dismissable, KeyValueStore};
19use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
21use futures::{
22 AsyncReadExt as _, FutureExt as _, StreamExt as _,
23 channel::mpsc::{self, UnboundedReceiver},
24 select_biased,
25};
26use gpui::BackgroundExecutor;
27use gpui::http_client::Url;
28use gpui::{
29 App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
30 http_client::{self, AsyncBody, Method},
31 prelude::*,
32};
33use heapless::Vec as ArrayVec;
34use language::language_settings::all_language_settings;
35use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
36use language::{BufferSnapshot, OffsetRangeExt};
37use language_model::LlmApiToken;
38use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
39use release_channel::AppVersion;
40use semver::Version;
41use serde::de::DeserializeOwned;
42use settings::{
43 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
44};
45use std::collections::{VecDeque, hash_map};
46use std::env;
47use text::{AnchorRangeExt, Edit};
48use workspace::{AppState, Workspace};
49use zeta_prompt::{ZetaFormat, ZetaPromptInput};
50
51use std::mem;
52use std::ops::Range;
53use std::path::Path;
54use std::rc::Rc;
55use std::str::FromStr as _;
56use std::sync::Arc;
57use std::time::{Duration, Instant};
58
59use thiserror::Error;
60use util::{RangeExt as _, ResultExt as _};
61
62pub mod cursor_excerpt;
63pub mod example_spec;
64pub mod fim;
65mod license_detection;
66pub mod mercury;
67pub mod ollama;
68mod onboarding_modal;
69pub mod open_ai_response;
70mod prediction;
71
72pub mod udiff;
73
74mod capture_example;
75pub mod open_ai_compatible;
76mod zed_edit_prediction_delegate;
77pub mod zeta;
78
79#[cfg(test)]
80mod edit_prediction_tests;
81
82use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
83use crate::example_spec::ExampleSpec;
84use crate::license_detection::LicenseDetectionWatcher;
85use crate::mercury::Mercury;
86use crate::onboarding_modal::ZedPredictModal;
87pub use crate::prediction::EditPrediction;
88pub use crate::prediction::EditPredictionId;
89use crate::prediction::EditPredictionResult;
90pub use capture_example::capture_example;
91pub use language_model::ApiKeyState;
92pub use telemetry_events::EditPredictionRating;
93pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
94
// Registers this crate's actions in the `edit_prediction` namespace with GPUI.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
104
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Edits within this many lines of one another may be coalesced into the same
/// event (see `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): presumably a token budget for context around collaborator
// edits — usage is outside this chunk; confirm.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): presumably groups edits closer in time than this into one
// event — usage is outside this chunk; confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key-value store key persisting the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably debounces batched rejection reports to the server.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Telemetry event name for a prediction that has "settled".
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
116
/// Feature flag gating "jump" edit predictions (see
/// `BufferEditPrediction::Jump` — predictions targeting a location away from
/// the cursor).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
122
/// GPUI global holding the app-wide [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
127
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    // Optional model override (`ZED_ZETA_MODEL`).
    pub model_id: Option<String>,
    // Optional environment override (`ZED_ZETA_ENVIRONMENT`).
    pub environment: Option<String>,
    // Prompt format to construct (`ZED_ZETA_FORMAT`).
    pub format: ZetaFormat,
}
137
/// App-wide store coordinating edit predictions: per-project state, model
/// selection, experiment configuration, and background reporting of rejected
/// and settled predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Waits for sign-in, then fetches experiments for staff (see `new`).
    _fetch_experiments_task: Task<()>,
    // Keyed by `Project` entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): presumably set when the server requires a newer client
    // version (see MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — confirm.
    update_required: bool,
    edit_prediction_model: EditPredictionModel,
    // Raw Zeta2 endpoint overrides, usually sourced from `ZED_ZETA_*` env vars.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Feeds the background worker that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    // Feeds the worker recording "settled" telemetry for shown predictions.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    // Predictions the user has already rated, to avoid double-counting.
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
159
/// A rejected prediction queued for batched reporting to the server.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    // Organization the request was billed to, if any.
    organization_id: Option<OrganizationId>,
}
164
/// Which backing model produces edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's hosted Zeta model.
    Zeta,
    /// A fill-in-the-middle model with a configurable prompt format.
    Fim { format: EditPredictionPromptFormat },
    /// Inception's Mercury model.
    Mercury,
}
171
/// Everything a model needs to produce one edit prediction.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Cursor position the prediction is anchored to.
    position: Anchor,
    // Recent edit-history events included in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    // Range around the cursor to scan for diagnostics.
    diagnostic_search_range: Range<Point>,
    // When present, debug events are emitted for the prediction UI.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
186
/// Events emitted over `debug_tx` for the prediction debug UI.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Context retrieval finished for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Free-form key/value pairs for display in the debug UI.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request was sent for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // Prompt text, when constructed client-side.
    pub prompt: Option<String>,
}

/// A prediction request completed.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
222
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    // Snapshot of the buffer before this event's edits.
    pub old_snapshot: TextBufferSnapshot,
    // Version vector of the buffer after this event's edits.
    pub new_snapshot_version: clock::Global,
    // Range (anchored after the event) covering all of the event's edits.
    pub total_edit_range: Range<Anchor>,
}
231
impl StoredEvent {
    /// Returns whether this event can be merged with `next_old_event`, the
    /// event that immediately follows it in the history.
    ///
    /// Merging is only considered for events from *different* sources (one
    /// predicted, one manual) — same-source neighbors would already have been
    /// coalesced. Both events must belong to the same buffer, be contiguous
    /// in version terms, lie close to each other, and stay clear of the
    /// latest edit.
    fn can_merge(
        &self,
        next_old_event: &StoredEvent,
        latest_snapshot: &TextBufferSnapshot,
        latest_edit_range: &Range<Anchor>,
    ) -> bool {
        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
            return false;
        }
        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return false;
        }
        // `next_old_event` must begin exactly where this event ended.
        if self.new_snapshot_version != next_old_event.old_snapshot.version {
            return false;
        }
        // The latest snapshot must already contain all of `next_old_event`'s
        // edits, otherwise anchors can't be resolved against it.
        if !latest_snapshot
            .version
            .observed_all(&next_old_event.new_snapshot_version)
        {
            return false;
        }

        let a_is_predicted = matches!(
            self.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );
        let b_is_predicted = matches!(
            next_old_event.event.as_ref(),
            zeta_prompt::Event::BufferChange {
                predicted: true,
                ..
            }
        );

        // If events come from the same source (both predicted or both manual) then
        // we would have coalesced them already.
        if a_is_predicted == b_is_predicted {
            return false;
        }

        let left_range = self.total_edit_range.to_point(latest_snapshot);
        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
        let latest_range = latest_edit_range.to_point(latest_snapshot);

        // Events near to the latest edit are not merged if their sources differ.
        if lines_between_ranges(&left_range, &latest_range)
            .min(lines_between_ranges(&right_range, &latest_range))
            <= CHANGE_GROUPING_LINE_SPAN
        {
            return false;
        }

        // Events that are distant from each other are not merged.
        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
            return false;
        }

        true
    }
}
297
298fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
299 if left.start > right.end {
300 return left.start.row - right.end.row;
301 }
302 if right.start > left.end {
303 return right.start.row - left.end.row;
304 }
305 0
306}
307
/// Per-project bookkeeping for edit prediction.
struct ProjectState {
    // Finalized edit-history events, oldest first.
    events: VecDeque<StoredEvent>,
    // In-progress event that hasn't been finalized into `events` yet.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Keyed by `Buffer` entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled (see
    // `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // License watchers per worktree, used to decide open-source status.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
325
impl ProjectState {
    /// Returns the finalized history events plus the still-open `last_event`,
    /// which is split at its last editing pause and finalized into up to two
    /// additional events.
    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
        self.events
            .iter()
            .cloned()
            .chain(self.last_event.as_ref().iter().flat_map(|event| {
                let (one, two) = event.split_by_pause();
                let one = one.finalize(&self.license_detection_watchers, cx);
                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
                one.into_iter().chain(two)
            }))
            .collect()
    }

    /// Cancels a pending prediction. With `drop_on_cancel`, the task is
    /// dropped immediately (aborting the request); otherwise the task is
    /// awaited so the completed prediction can be reported as rejected.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        if pending_prediction.drop_on_cancel {
            drop(pending_prediction.task);
        } else {
            cx.spawn(async move |this, cx| {
                let Some(prediction_id) = pending_prediction.task.await else {
                    return;
                };

                this.update(cx, |this, cx| {
                    this.reject_prediction(
                        prediction_id,
                        EditPredictionRejectReason::Canceled,
                        false,
                        None,
                        None,
                        cx,
                    );
                })
                .ok();
            })
            .detach()
        }
    }

    /// Returns the buffer for the project's active entry — if that buffer is
    /// registered — along with the last known cursor position in it.
    fn active_buffer(
        &self,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
        let project = project.read(cx);
        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
        Some((active_buffer, registered_buffer.last_position))
    }
}
383
/// The prediction currently held for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has been displayed to the user.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    // End-to-end latency from request to availability.
    pub e2e_latency: std::time::Duration,
}
392
393impl CurrentEditPrediction {
394 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
395 let Some(new_edits) = self
396 .prediction
397 .interpolate(&self.prediction.buffer.read(cx))
398 else {
399 return false;
400 };
401
402 if self.prediction.buffer != old_prediction.prediction.buffer {
403 return true;
404 }
405
406 let Some(old_edits) = old_prediction
407 .prediction
408 .interpolate(&old_prediction.prediction.buffer.read(cx))
409 else {
410 return true;
411 };
412
413 let requested_by_buffer_id = self.requested_by.buffer_id();
414
415 // This reduces the occurrence of UI thrash from replacing edits
416 //
417 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
418 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
419 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
420 && old_edits.len() == 1
421 && new_edits.len() == 1
422 {
423 let (old_range, old_text) = &old_edits[0];
424 let (new_range, new_text) = &new_edits[0];
425 new_range == old_range && new_text.starts_with(old_text.as_ref())
426 } else {
427 true
428 }
429 }
430}
431
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update, not tied to a particular buffer.
    DiagnosticsUpdate,
    /// An edit or cursor movement in the given buffer.
    Buffer(EntityId),
}
437
438impl PredictionRequestedBy {
439 pub fn buffer_id(&self) -> Option<EntityId> {
440 match self {
441 PredictionRequestedBy::DiagnosticsUpdate => None,
442 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
443 }
444 }
445}
446
// NOTE(review): presumably the number of lines around the cursor searched for
// diagnostics — usage is outside this chunk; confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope for diagnostic searches: near the cursor or across the project.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
454
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    // Resolves to the prediction's id, or `None` if no prediction resulted.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
463
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer; this one shows a jump target.
    Jump { prediction: &'a EditPrediction },
}
470
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        // Both variants carry the same payload, so bind it with an or-pattern.
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
482
/// A shown prediction awaiting "settled" telemetry: after a quiet period the
/// final state of its edited range is reported (see
/// `EDIT_PREDICTION_SETTLED_*` constants).
#[derive(Clone)]

struct PendingSettledPrediction {
    request_id: EditPredictionId,
    // Range of the buffer the prediction was allowed to edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
493
/// State tracked for each buffer registered for edit prediction.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    // Last snapshot observed; used to diff against new edits.
    snapshot: TextBufferSnapshot,
    // Predictions in this buffer awaiting settlement telemetry.
    pending_predictions: Vec<PendingSettledPrediction>,
    // Last known cursor position in this buffer, if any.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
501
/// The in-progress (not yet finalized) edit event for a project.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Range of the most recent edit only.
    latest_edit_range: Range<Anchor>,
    // Range covering all edits accumulated in this event.
    total_edit_range: Range<Anchor>,
    // `total_edit_range` as it was at the last pause boundary, if any.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    // Whether the edits came from an accepted prediction.
    predicted: bool,
    // Snapshot taken after the last editing pause; presence of this enables
    // `split_by_pause` to divide the event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
515
impl LastEvent {
    /// Converts this in-progress event into a [`StoredEvent`] by diffing the
    /// old and new snapshots over the accumulated edit range. Returns `None`
    /// when nothing changed (same path and empty diff) or when the diff
    /// exceeds the history size limit.
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // Open source only if both the old and new file locations live in
        // worktrees whose license watcher detected an open source license.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
            &self.old_snapshot,
            &self.new_snapshot,
            &self.total_edit_range,
        )?;

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                old_snapshot: self.old_snapshot.clone(),
                new_snapshot_version: self.new_snapshot.version.clone(),
                // Re-anchor the computed edit range in the new snapshot.
                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
            })
        }
    }

    /// Splits this event at the last recorded editing pause, yielding a
    /// "before" event that ends at the pause-boundary snapshot and an "after"
    /// event that starts from it. Returns the event unsplit when no pause was
    /// recorded or the post-pause edit range can't be computed.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let total_edit_range_before_pause = self
            .total_edit_range_at_last_pause_boundary
            .clone()
            .unwrap_or_else(|| self.total_edit_range.clone());

        let Some(total_edit_range_after_pause) =
            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
        else {
            return (self.clone(), None);
        };

        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_before_pause,
            total_edit_range: total_edit_range_before_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_after_pause,
            total_edit_range: total_edit_range_after_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
609
610fn compute_total_edit_range_between_snapshots(
611 old_snapshot: &TextBufferSnapshot,
612 new_snapshot: &TextBufferSnapshot,
613) -> Option<Range<Anchor>> {
614 let edits: Vec<Edit<usize>> = new_snapshot
615 .edits_since::<usize>(&old_snapshot.version)
616 .collect();
617
618 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
619 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
620 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
621
622 Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
623}
624
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to the
/// corresponding point range in `old_snapshot` by walking the edits made
/// since `old_snapshot` and tracking the running offset delta.
///
/// Returns `None` if translating an offset would underflow, which would
/// indicate inconsistent edit data.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Running difference between new and old offsets for the text that lies
    // before the current edit.
    let mut delta: isize = 0;

    for edit in &edits {
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Position falls before this edit: translate by the delta.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Position falls inside this edit: clamp to the edit's old start.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                // Inside this edit: clamp to the edit's old end.
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Positions after the last edit translate by the total accumulated delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
670
/// Produces a unified diff between the two snapshots restricted to
/// `total_edit_range` (plus a few lines of context), along with the edited
/// point range in `new_snapshot`. Returns `None` when either side of the
/// diffed region exceeds `EDIT_HISTORY_DIFF_SIZE_LIMIT`, or when the old
/// range can't be computed.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Expand both regions by CONTEXT_LINES rows, clamped to buffer bounds.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Snap the regions to whole lines by converting row boundaries to offsets.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Keep oversized diffs out of the prompt budget.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
718
719fn buffer_path_with_id_fallback(
720 file: Option<&Arc<dyn File>>,
721 snapshot: &TextBufferSnapshot,
722 cx: &App,
723) -> Arc<Path> {
724 if let Some(file) = file {
725 file.full_path(cx).into()
726 } else {
727 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
728 }
729}
730
731impl EditPredictionStore {
732 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
733 cx.try_global::<EditPredictionStoreGlobal>()
734 .map(|global| global.0.clone())
735 }
736
737 pub fn global(
738 client: &Arc<Client>,
739 user_store: &Entity<UserStore>,
740 cx: &mut App,
741 ) -> Entity<Self> {
742 cx.try_global::<EditPredictionStoreGlobal>()
743 .map(|global| global.0.clone())
744 .unwrap_or_else(|| {
745 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
746 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
747 ep_store
748 })
749 }
750
751 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
752 let data_collection_choice = Self::load_data_collection_choice(cx);
753
754 let llm_token = global_llm_api_token(cx);
755
756 let (reject_tx, reject_rx) = mpsc::unbounded();
757 cx.background_spawn({
758 let client = client.clone();
759 let llm_token = llm_token.clone();
760 let app_version = AppVersion::global(cx);
761 let background_executor = cx.background_executor().clone();
762 async move {
763 Self::handle_rejected_predictions(
764 reject_rx,
765 client,
766 llm_token,
767 app_version,
768 background_executor,
769 )
770 .await
771 }
772 })
773 .detach();
774
775 let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
776 cx.spawn(async move |this, cx| {
777 Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
778 })
779 .detach();
780
781 let mut current_user = user_store.read(cx).watch_current_user();
782 let fetch_experiments_task = cx.spawn(async move |this, cx| {
783 while current_user.borrow().is_none() {
784 current_user.next().await;
785 }
786
787 this.update(cx, |this, cx| {
788 if cx.is_staff() {
789 this.refresh_available_experiments(cx);
790 }
791 })
792 .log_err();
793 });
794
795 let credentials_provider = zed_credentials_provider::global(cx);
796
797 let this = Self {
798 projects: HashMap::default(),
799 client,
800 user_store,
801 llm_token,
802 _fetch_experiments_task: fetch_experiments_task,
803 update_required: false,
804 edit_prediction_model: EditPredictionModel::Zeta,
805 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
806 preferred_experiment: None,
807 available_experiments: Vec::new(),
808 mercury: Mercury::new(cx),
809
810 data_collection_choice,
811 reject_predictions_tx: reject_tx,
812 settled_predictions_tx,
813 rated_predictions: Default::default(),
814 shown_predictions: Default::default(),
815 #[cfg(test)]
816 settled_event_callback: None,
817
818 credentials_provider,
819 };
820
821 this
822 }
823
824 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
825 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
826 let format = ZetaFormat::parse(&version_str).ok()?;
827 let model_id = env::var("ZED_ZETA_MODEL").ok();
828 let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
829 Some(Zeta2RawConfig {
830 model_id,
831 environment,
832 format,
833 })
834 }
835
    /// Selects which backing model produces predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
839
    /// Overrides the raw Zeta2 endpoint configuration.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
843
    /// The raw Zeta2 endpoint configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
847
    /// The experiment the user explicitly selected, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
851
    /// Sets (or clears) the user's preferred experiment.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
855
    /// Experiments fetched from the server (see `refresh_available_experiments`).
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
859
860 pub fn active_experiment(&self) -> Option<&str> {
861 self.preferred_experiment.as_deref().or_else(|| {
862 self.shown_predictions
863 .iter()
864 .find_map(|p| p.model_version.as_ref())
865 .and_then(|model_version| model_version.strip_prefix("zeta2:"))
866 })
867 }
868
    /// Fetches the list of edit prediction experiments from the LLM service
    /// and stores it in `available_experiments`, notifying observers.
    /// Errors are logged, not surfaced.
    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let organization_id = self
            .user_store
            .read(cx)
            .current_organization()
            .map(|organization| organization.id.clone());

        cx.spawn(async move |this, cx| {
            // Perform the token acquisition and HTTP request off the main thread.
            let experiments = cx
                .background_spawn(async move {
                    let http_client = client.http_client();
                    let token = client
                        .acquire_llm_token(&llm_token, organization_id.clone())
                        .await?;
                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                    let request = http_client::Request::builder()
                        .method(Method::GET)
                        .uri(url.as_ref())
                        .header("Authorization", format!("Bearer {}", token))
                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                        .body(Default::default())?;
                    let mut response = http_client.send(request).await?;
                    if response.status().is_success() {
                        let mut body = Vec::new();
                        response.body_mut().read_to_end(&mut body).await?;
                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
                        Ok(experiments)
                    } else {
                        // Include the response body to aid debugging failures.
                        let mut body = String::new();
                        response.body_mut().read_to_string(&mut body).await?;
                        anyhow::bail!(
                            "Failed to fetch experiments: {:?}\nBody: {}",
                            response.status(),
                            body
                        );
                    }
                })
                .await?;
            this.update(cx, |this, cx| {
                this.available_experiments = experiments;
                cx.notify();
            })?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
918
919 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
920 use ui::IconName;
921 match self.edit_prediction_model {
922 EditPredictionModel::Mercury => {
923 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
924 }
925 EditPredictionModel::Zeta => {
926 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
927 .with_disabled(IconName::ZedPredictDisabled)
928 .with_up(IconName::ZedPredictUp)
929 .with_down(IconName::ZedPredictDown)
930 .with_error(IconName::ZedPredictError)
931 }
932 EditPredictionModel::Fim { .. } => {
933 let settings = &all_language_settings(None, cx).edit_predictions;
934 match settings.provider {
935 EditPredictionProvider::Ollama => {
936 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
937 }
938 _ => {
939 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
940 }
941 }
942 }
943 }
944 }
945
    /// Whether a Mercury API token is configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
949
    /// Whether Mercury last failed with a payment-required error.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
953
954 pub fn clear_history(&mut self) {
955 for project_state in self.projects.values_mut() {
956 project_state.events.clear();
957 project_state.last_event.take();
958 }
959 }
960
961 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
962 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
963 project_state.events.clear();
964 project_state.last_event.take();
965 }
966 }
967
968 pub fn edit_history_for_project(
969 &self,
970 project: &Entity<Project>,
971 cx: &App,
972 ) -> Vec<StoredEvent> {
973 self.projects
974 .get(&project.entity_id())
975 .map(|project_state| project_state.events(cx))
976 .unwrap_or_default()
977 }
978
979 pub fn context_for_project<'a>(
980 &'a self,
981 project: &Entity<Project>,
982 cx: &'a mut App,
983 ) -> Vec<RelatedFile> {
984 self.projects
985 .get(&project.entity_id())
986 .map(|project_state| {
987 project_state.context.update(cx, |context, cx| {
988 context
989 .related_files_with_buffers(cx)
990 .map(|(mut related_file, buffer)| {
991 related_file.in_open_source_repo = buffer
992 .read(cx)
993 .file()
994 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
995 related_file
996 })
997 .collect()
998 })
999 })
1000 .unwrap_or_default()
1001 }
1002
1003 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1004 self.projects
1005 .get(&project.entity_id())
1006 .and_then(|project| project.copilot.clone())
1007 }
1008
1009 pub fn start_copilot_for_project(
1010 &mut self,
1011 project: &Entity<Project>,
1012 cx: &mut Context<Self>,
1013 ) -> Option<Entity<Copilot>> {
1014 if DisableAiSettings::get(None, cx).disable_ai {
1015 return None;
1016 }
1017 let state = self.get_or_init_project(project, cx);
1018
1019 if state.copilot.is_some() {
1020 return state.copilot.clone();
1021 }
1022 let _project = project.clone();
1023 let project = project.read(cx);
1024
1025 let node = project.node_runtime().cloned();
1026 if let Some(node) = node {
1027 let next_id = project.languages().next_language_server_id();
1028 let fs = project.fs().clone();
1029
1030 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1031 state.copilot = Some(copilot.clone());
1032 Some(copilot)
1033 } else {
1034 None
1035 }
1036 }
1037
1038 pub fn context_for_project_with_buffers<'a>(
1039 &'a self,
1040 project: &Entity<Project>,
1041 cx: &'a mut App,
1042 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1043 self.projects
1044 .get(&project.entity_id())
1045 .map(|project| {
1046 project.context.update(cx, |context, cx| {
1047 context.related_files_with_buffers(cx).collect()
1048 })
1049 })
1050 .unwrap_or_default()
1051 }
1052
1053 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1054 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1055 self.user_store.read(cx).edit_prediction_usage()
1056 } else {
1057 None
1058 }
1059 }
1060
    /// Ensures per-project state exists for `project`, creating it and wiring
    /// its subscriptions if this is the first time the project is seen.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1064
    /// Begins tracking `buffer` within `project`, initializing project state
    /// as needed and subscribing to the buffer's edits and release.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1074
    /// Returns the mutable per-project state for `project`, lazily creating
    /// it on first access.
    ///
    /// Creation wires up three subscriptions: one forwarding
    /// `RelatedExcerptStore` events into `handle_excerpt_store_event`, one
    /// routing project events to `handle_project_event`, and a release
    /// observer that removes this state entry when the project entity drops.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward excerpt-store events keyed by this project's id.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state as soon as the entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1114
    /// Stops tracking `project`, discarding all of its per-project state.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
1118
1119 fn handle_excerpt_store_event(
1120 &mut self,
1121 project_entity_id: EntityId,
1122 event: &RelatedExcerptStoreEvent,
1123 ) {
1124 if let Some(project_state) = self.projects.get(&project_entity_id) {
1125 if let Some(debug_tx) = project_state.debug_tx.clone() {
1126 match event {
1127 RelatedExcerptStoreEvent::StartedRefresh => {
1128 debug_tx
1129 .unbounded_send(DebugEvent::ContextRetrievalStarted(
1130 ContextRetrievalStartedDebugEvent {
1131 project_entity_id: project_entity_id,
1132 timestamp: Instant::now(),
1133 search_prompt: String::new(),
1134 },
1135 ))
1136 .ok();
1137 }
1138 RelatedExcerptStoreEvent::FinishedRefresh {
1139 cache_hit_count,
1140 cache_miss_count,
1141 mean_definition_latency,
1142 max_definition_latency,
1143 } => {
1144 debug_tx
1145 .unbounded_send(DebugEvent::ContextRetrievalFinished(
1146 ContextRetrievalFinishedDebugEvent {
1147 project_entity_id: project_entity_id,
1148 timestamp: Instant::now(),
1149 metadata: vec![
1150 (
1151 "Cache Hits",
1152 format!(
1153 "{}/{}",
1154 cache_hit_count,
1155 cache_hit_count + cache_miss_count
1156 )
1157 .into(),
1158 ),
1159 (
1160 "Max LSP Time",
1161 format!("{} ms", max_definition_latency.as_millis())
1162 .into(),
1163 ),
1164 (
1165 "Mean LSP Time",
1166 format!("{} ms", mean_definition_latency.as_millis())
1167 .into(),
1168 ),
1169 ],
1170 },
1171 ))
1172 .ok();
1173 }
1174 }
1175 }
1176 }
1177 }
1178
1179 pub fn debug_info(
1180 &mut self,
1181 project: &Entity<Project>,
1182 cx: &mut Context<Self>,
1183 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1184 let project_state = self.get_or_init_project(project, cx);
1185 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1186 project_state.debug_tx = Some(debug_watch_tx);
1187 debug_watch_rx
1188 }
1189
    /// Responds to project events: maintains the most-recently-active path
    /// list and triggers jump predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // Ignore events entirely when the active provider isn't served by this store.
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any earlier
                    // occurrence so the deque stays duplicate-free.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                // Diagnostics-driven jump predictions are behind a feature flag.
                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(
                        project,
                        DiagnosticSearchScope::Global,
                        cx,
                    );
                }
            }
            _ => (),
        }
    }
1229
    /// Registers `buffer` in `project_state` (idempotent), returning its
    /// tracking entry.
    ///
    /// Side effects on first registration of a buffer's worktree: installs a
    /// `LicenseDetectionWatcher` that is torn down when the worktree is
    /// released. On first registration of the buffer itself: snapshots it and
    /// subscribes to its edits (feeding `report_changes_for_buffer`) and its
    /// release (which removes the entry).
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Remove the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Feed every local-or-remote edit into change reporting.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Drop the tracking entry once the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1298
    /// Ingests the edits made to `buffer` since its last recorded snapshot,
    /// updating pending-prediction bookkeeping and the project's edit history.
    ///
    /// `is_predicted` marks edits applied by accepting a prediction;
    /// `is_local` distinguishes the user's own edits from collaborators'.
    /// Consecutive nearby edits of the same kind are coalesced into a single
    /// history event; otherwise the previous event is finalized and a new one
    /// is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse all edits since the old snapshot into one covering anchor range.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits inside a pending prediction's editable region reset its quiescence timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Collaborator edits only enter history when near the user's own activity.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // A no-op diff still ends any in-progress event: finalize and stop.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Bounded history: evict the oldest event to make room.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only when this edit directly extends the last event's
            // snapshot chain in the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // If the user paused long enough, remember the pre-pause state
                // so it can be reported as a boundary within the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescible: finalize the previous event into history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a fresh in-progress event for this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1431
    /// Returns the current prediction as seen from `buffer`: `Local` when the
    /// prediction targets this buffer, `Jump` when it lives elsewhere but
    /// should be offered as a jump, or `None`.
    ///
    /// As a side effect, records `position` as the buffer's last cursor
    /// position when provided.
    fn prediction_at(
        &mut self,
        buffer: &Entity<Buffer>,
        position: Option<language::Anchor>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get_mut(&project.entity_id())?;
        if let Some(position) = position
            && let Some(buffer) = project_state
                .registered_buffers
                .get_mut(&buffer.entity_id())
        {
            buffer.last_position = Some(position);
        }

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only surface a jump in the buffer that asked for the prediction,
            // or anywhere when it came from a diagnostics update.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
1471
    /// Accepts the active prediction for `project`: records its edits into
    /// history (marked as predicted), cancels any in-flight prediction
    /// requests, and notifies the active model's backend of the acceptance.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        // Record the accepted edits as predicted, local changes.
        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // can't hold &mut project_state ref across report_changes_for_buffer_call
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Local providers (Ollama / OpenAI-compatible) have no
                // acceptance endpoint to notify.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1518
    /// Background worker that batches prediction rejections and reports them
    /// to the cloud `/predict_edits/reject` endpoint.
    ///
    /// Rejections are accumulated until either half the per-request cap is
    /// reached, the channel momentarily empties, or a debounce timer fires.
    /// On send failure the batch is retained and retried with the next flush;
    /// only the most recent `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    /// items are sent per request.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Below the batching threshold: wait for either another item or
            // the debounce timeout before flushing.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the flushed items only on success; otherwise keep them for retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1582
    /// Background worker that emits a "settled" telemetry event for each
    /// accepted prediction once edits around it have quiesced.
    ///
    /// `rx` receives an enqueue timestamp whenever a prediction is queued
    /// (see `enqueue_settled_prediction`). The worker sleeps until the next
    /// candidate could have quiesced, then sweeps all pending predictions:
    /// expired ones (older than `EDIT_PREDICTION_SETTLED_TTL`) are dropped,
    /// quiesced ones are reported and dropped, and the earliest remaining
    /// `last_edit_at` schedules the next wake-up.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: block until a new prediction is enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Drain any additional wake-ups that piled up; one sweep covers them.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Give up on predictions that never settle.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    // Quiesced: capture what the editable
                                    // region settled into and report it.
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
                                    );

                                    return false;
                                }

                                // Still waiting: track the earliest edit time
                                // to schedule the next sweep.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1668
1669 pub(crate) fn enqueue_settled_prediction(
1670 &mut self,
1671 request_id: EditPredictionId,
1672 project: &Entity<Project>,
1673 edited_buffer: &Entity<Buffer>,
1674 edited_buffer_snapshot: &BufferSnapshot,
1675 editable_offset_range: Range<usize>,
1676 example: Option<ExampleSpec>,
1677 e2e_latency: std::time::Duration,
1678 cx: &mut Context<Self>,
1679 ) {
1680 let this = &mut *self;
1681 let project_state = this.get_or_init_project(project, cx);
1682 if let Some(buffer) = project_state
1683 .registered_buffers
1684 .get_mut(&edited_buffer.entity_id())
1685 {
1686 let now = cx.background_executor().now();
1687 buffer.pending_predictions.push(PendingSettledPrediction {
1688 request_id: request_id,
1689 editable_anchor_range: edited_buffer_snapshot
1690 .anchor_range_around(editable_offset_range),
1691 example,
1692 e2e_latency,
1693 enqueued_at: now,
1694 last_edit_at: now,
1695 });
1696 this.settled_predictions_tx.unbounded_send(now).ok();
1697 }
1698 }
1699
1700 fn reject_current_prediction(
1701 &mut self,
1702 reason: EditPredictionRejectReason,
1703 project: &Entity<Project>,
1704 cx: &App,
1705 ) {
1706 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1707 project_state.pending_predictions.clear();
1708 if let Some(prediction) = project_state.current_prediction.take() {
1709 let model_version = prediction.prediction.model_version.clone();
1710 self.reject_prediction(
1711 prediction.prediction.id,
1712 reason,
1713 prediction.was_shown,
1714 model_version,
1715 Some(prediction.e2e_latency),
1716 cx,
1717 );
1718 }
1719 };
1720 }
1721
1722 fn did_show_current_prediction(
1723 &mut self,
1724 project: &Entity<Project>,
1725 display_type: edit_prediction_types::SuggestionDisplayType,
1726 _cx: &mut Context<Self>,
1727 ) {
1728 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1729 return;
1730 };
1731
1732 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1733 return;
1734 };
1735
1736 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1737 let previous_shown_with = current_prediction.shown_with;
1738
1739 if previous_shown_with.is_none() || !is_jump {
1740 current_prediction.shown_with = Some(display_type);
1741 }
1742
1743 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1744
1745 if is_first_non_jump_show {
1746 current_prediction.was_shown = true;
1747 }
1748
1749 if is_first_non_jump_show {
1750 self.shown_predictions
1751 .push_front(current_prediction.prediction.clone());
1752 if self.shown_predictions.len() > 50 {
1753 let completion = self.shown_predictions.pop_back().unwrap();
1754 self.rated_predictions.remove(&completion.id);
1755 }
1756 }
1757 }
1758
    /// Reports a single prediction rejection to the active model's backend.
    ///
    /// For cloud Zeta the rejection is forwarded to the batching worker (see
    /// `handle_rejected_predictions`); local providers and FIM models report
    /// nothing; Mercury notifies its endpoint directly.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        model_version: Option<String>,
        e2e_latency: Option<std::time::Duration>,
        cx: &App,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta => {
                // Local providers have no rejection endpoint.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );

                if is_cloud {
                    let organization_id = self
                        .user_store
                        .read(cx)
                        .current_organization()
                        .map(|organization| organization.id.clone());

                    self.reject_predictions_tx
                        .unbounded_send(EditPredictionRejectionPayload {
                            rejection: EditPredictionRejection {
                                request_id: prediction_id.to_string(),
                                reason,
                                was_shown,
                                model_version,
                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
                            },
                            organization_id,
                        })
                        .log_err();
                }
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_rejected(
                    prediction_id,
                    was_shown,
                    reason,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1808
1809 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1810 self.projects
1811 .get(&project.entity_id())
1812 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1813 }
1814
    /// Queues a throttled prediction refresh anchored at `position` in
    /// `buffer`, attributing the eventual result to that buffer.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Other,
            buffer.entity_id(),
            cx,
            move |this, cx| {
                let Some(request_task) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &buffer,
                            position,
                            PredictEditsRequestTrigger::Other,
                            cx,
                        )
                    })
                    .log_err()
                else {
                    // Store entity gone: nothing to predict.
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |_cx| {
                    request_task.await.map(|prediction_result| {
                        prediction_result.map(|prediction_result| {
                            (
                                prediction_result,
                                PredictionRequestedBy::Buffer(buffer.entity_id()),
                            )
                        })
                    })
                })
            },
        )
    }
1856
    /// Queues a throttled prediction refresh driven by a diagnostics update,
    /// targeting the next diagnostic location relative to the active cursor.
    ///
    /// Skipped when the provider isn't store-backed, when the user is
    /// following a collaborator, or when a prediction already exists (buffer
    /// predictions take precedence over diagnostic jumps).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Capture the active buffer and cursor; bail if predictions
                // are disabled at that location.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope restricts the search to a window of lines
                    // around the cursor; Global searches the whole project.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    this.update(cx, |this, cx| {
                        Some((
                            // If a buffer prediction appeared while this jump
                            // request was in flight, keep the existing one and
                            // mark this result as superseded.
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1973
1974 fn predictions_enabled_at(
1975 snapshot: &BufferSnapshot,
1976 position: Option<language::Anchor>,
1977 cx: &App,
1978 ) -> bool {
1979 let file = snapshot.file();
1980 let all_settings = all_language_settings(file, cx);
1981 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1982 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1983 {
1984 return false;
1985 }
1986
1987 if let Some(last_position) = position {
1988 let settings = snapshot.settings_at(last_position, cx);
1989
1990 if !settings.edit_predictions_disabled_in.is_empty()
1991 && let Some(scope) = snapshot.language_scope_at(last_position)
1992 && let Some(scope_name) = scope.override_name()
1993 && settings
1994 .edit_predictions_disabled_in
1995 .iter()
1996 .any(|s| s == scope_name)
1997 {
1998 return false;
1999 }
2000 }
2001
2002 true
2003 }
2004
    /// Minimum interval between prediction refreshes triggered by the same entity.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2006}
2007
2008fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2009 let Some(app_state) = AppState::try_global(cx) else {
2010 return false;
2011 };
2012
2013 app_state
2014 .workspace_store
2015 .read(cx)
2016 .workspaces()
2017 .filter_map(|workspace| workspace.upgrade())
2018 .any(|workspace| {
2019 workspace.read(cx).project().entity_id() == project.entity_id()
2020 && workspace
2021 .read(cx)
2022 .leader_for_pane(workspace.read(cx).active_pane())
2023 .is_some()
2024 })
2025}
2026
2027fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2028 match provider {
2029 EditPredictionProvider::Zed
2030 | EditPredictionProvider::Mercury
2031 | EditPredictionProvider::Ollama
2032 | EditPredictionProvider::OpenAiCompatibleApi
2033 | EditPredictionProvider::Experimental(_) => true,
2034 EditPredictionProvider::None
2035 | EditPredictionProvider::Copilot
2036 | EditPredictionProvider::Codestral => false,
2037 }
2038}
2039
2040impl EditPredictionStore {
2041 fn queue_prediction_refresh(
2042 &mut self,
2043 project: Entity<Project>,
2044 request_trigger: PredictEditsRequestTrigger,
2045 throttle_entity: EntityId,
2046 cx: &mut Context<Self>,
2047 do_refresh: impl FnOnce(
2048 WeakEntity<Self>,
2049 &mut AsyncApp,
2050 )
2051 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2052 + 'static,
2053 ) {
2054 fn select_throttle(
2055 project_state: &mut ProjectState,
2056 request_trigger: PredictEditsRequestTrigger,
2057 ) -> &mut Option<(EntityId, Instant)> {
2058 match request_trigger {
2059 PredictEditsRequestTrigger::Diagnostics => {
2060 &mut project_state.last_jump_prediction_refresh
2061 }
2062 _ => &mut project_state.last_edit_prediction_refresh,
2063 }
2064 }
2065
2066 let (needs_acceptance_tracking, max_pending_predictions) =
2067 match all_language_settings(None, cx).edit_predictions.provider {
2068 EditPredictionProvider::Zed
2069 | EditPredictionProvider::Mercury
2070 | EditPredictionProvider::Experimental(_) => (true, 2),
2071 EditPredictionProvider::Ollama => (false, 1),
2072 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2073 EditPredictionProvider::None
2074 | EditPredictionProvider::Copilot
2075 | EditPredictionProvider::Codestral => {
2076 log::error!("queue_prediction_refresh called with non-store provider");
2077 return;
2078 }
2079 };
2080
2081 let drop_on_cancel = !needs_acceptance_tracking;
2082 let throttle_timeout = Self::THROTTLE_TIMEOUT;
2083 let project_state = self.get_or_init_project(&project, cx);
2084 let pending_prediction_id = project_state.next_pending_prediction_id;
2085 project_state.next_pending_prediction_id += 1;
2086 let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2087
2088 let task = cx.spawn(async move |this, cx| {
2089 let throttle_wait = this
2090 .update(cx, |this, cx| {
2091 let project_state = this.get_or_init_project(&project, cx);
2092 let throttle = *select_throttle(project_state, request_trigger);
2093
2094 let now = cx.background_executor().now();
2095 throttle.and_then(|(last_entity, last_timestamp)| {
2096 if throttle_entity != last_entity {
2097 return None;
2098 }
2099 (last_timestamp + throttle_timeout).checked_duration_since(now)
2100 })
2101 })
2102 .ok()
2103 .flatten();
2104
2105 if let Some(timeout) = throttle_wait {
2106 cx.background_executor().timer(timeout).await;
2107 }
2108
2109 // If this task was cancelled before the throttle timeout expired,
2110 // do not perform a request. Also skip if another task already
2111 // proceeded since we were enqueued (duplicate).
2112 let mut is_cancelled = true;
2113 this.update(cx, |this, cx| {
2114 let project_state = this.get_or_init_project(&project, cx);
2115 let was_cancelled = project_state
2116 .cancelled_predictions
2117 .remove(&pending_prediction_id);
2118 if was_cancelled {
2119 return;
2120 }
2121
2122 // Another request has been already sent since this was enqueued
2123 if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2124 return;
2125 }
2126
2127 let new_refresh = (throttle_entity, cx.background_executor().now());
2128 *select_throttle(project_state, request_trigger) = Some(new_refresh);
2129 is_cancelled = false;
2130 })
2131 .ok();
2132 if is_cancelled {
2133 return None;
2134 }
2135
2136 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2137 let new_prediction_id = new_prediction_result
2138 .as_ref()
2139 .map(|(prediction, _)| prediction.id.clone());
2140
2141 // When a prediction completes, remove it from the pending list, and cancel
2142 // any pending predictions that were enqueued before it.
2143 this.update(cx, |this, cx| {
2144 let project_state = this.get_or_init_project(&project, cx);
2145
2146 let is_cancelled = project_state
2147 .cancelled_predictions
2148 .remove(&pending_prediction_id);
2149
2150 let new_current_prediction = if !is_cancelled
2151 && let Some((prediction_result, requested_by)) = new_prediction_result
2152 {
2153 match prediction_result.prediction {
2154 Ok(prediction) => {
2155 let new_prediction = CurrentEditPrediction {
2156 requested_by,
2157 prediction,
2158 was_shown: false,
2159 shown_with: None,
2160 e2e_latency: prediction_result.e2e_latency,
2161 };
2162
2163 if let Some(current_prediction) =
2164 project_state.current_prediction.as_ref()
2165 {
2166 if new_prediction.should_replace_prediction(¤t_prediction, cx)
2167 {
2168 this.reject_current_prediction(
2169 EditPredictionRejectReason::Replaced,
2170 &project,
2171 cx,
2172 );
2173
2174 Some(new_prediction)
2175 } else {
2176 this.reject_prediction(
2177 new_prediction.prediction.id,
2178 EditPredictionRejectReason::CurrentPreferred,
2179 false,
2180 new_prediction.prediction.model_version,
2181 Some(new_prediction.e2e_latency),
2182 cx,
2183 );
2184 None
2185 }
2186 } else {
2187 Some(new_prediction)
2188 }
2189 }
2190 Err(reject_reason) => {
2191 this.reject_prediction(
2192 prediction_result.id,
2193 reject_reason,
2194 false,
2195 None,
2196 Some(prediction_result.e2e_latency),
2197 cx,
2198 );
2199 None
2200 }
2201 }
2202 } else {
2203 None
2204 };
2205
2206 let project_state = this.get_or_init_project(&project, cx);
2207
2208 if let Some(new_prediction) = new_current_prediction {
2209 project_state.current_prediction = Some(new_prediction);
2210 }
2211
2212 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2213 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2214 if pending_prediction.id == pending_prediction_id {
2215 pending_predictions.remove(ix);
2216 for pending_prediction in pending_predictions.drain(0..ix) {
2217 project_state.cancel_pending_prediction(pending_prediction, cx)
2218 }
2219 break;
2220 }
2221 }
2222 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2223 cx.notify();
2224 })
2225 .ok();
2226
2227 new_prediction_id
2228 });
2229
2230 if project_state.pending_predictions.len() < max_pending_predictions {
2231 project_state
2232 .pending_predictions
2233 .push(PendingPrediction {
2234 id: pending_prediction_id,
2235 task,
2236 drop_on_cancel,
2237 })
2238 .unwrap();
2239 } else {
2240 let pending_prediction = project_state.pending_predictions.pop().unwrap();
2241 project_state
2242 .pending_predictions
2243 .push(PendingPrediction {
2244 id: pending_prediction_id,
2245 task,
2246 drop_on_cancel,
2247 })
2248 .unwrap();
2249 project_state.cancel_pending_prediction(pending_prediction, cx);
2250 }
2251 }
2252
2253 pub fn request_prediction(
2254 &mut self,
2255 project: &Entity<Project>,
2256 active_buffer: &Entity<Buffer>,
2257 position: language::Anchor,
2258 trigger: PredictEditsRequestTrigger,
2259 cx: &mut Context<Self>,
2260 ) -> Task<Result<Option<EditPredictionResult>>> {
2261 self.request_prediction_internal(
2262 project.clone(),
2263 active_buffer.clone(),
2264 position,
2265 trigger,
2266 cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2267 cx,
2268 )
2269 }
2270
2271 fn request_prediction_internal(
2272 &mut self,
2273 project: Entity<Project>,
2274 active_buffer: Entity<Buffer>,
2275 position: language::Anchor,
2276 trigger: PredictEditsRequestTrigger,
2277 allow_jump: bool,
2278 cx: &mut Context<Self>,
2279 ) -> Task<Result<Option<EditPredictionResult>>> {
2280 self.get_or_init_project(&project, cx);
2281 let project_state = self.projects.get(&project.entity_id()).unwrap();
2282 let stored_events = project_state.events(cx);
2283 let has_events = !stored_events.is_empty();
2284 let events: Vec<Arc<zeta_prompt::Event>> =
2285 stored_events.iter().map(|e| e.event.clone()).collect();
2286 let debug_tx = project_state.debug_tx.clone();
2287
2288 let snapshot = active_buffer.read(cx).snapshot();
2289 let cursor_point = position.to_point(&snapshot);
2290 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2291 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2292 let diagnostic_search_range =
2293 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2294
2295 let related_files = self.context_for_project(&project, cx);
2296
2297 let is_open_source = snapshot
2298 .file()
2299 .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2300 && events.iter().all(|event| event.in_open_source_repo())
2301 && related_files.iter().all(|file| file.in_open_source_repo);
2302
2303 let can_collect_data = !cfg!(test)
2304 && is_open_source
2305 && self.is_data_collection_enabled(cx)
2306 && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2307
2308 let inputs = EditPredictionModelInput {
2309 project: project.clone(),
2310 buffer: active_buffer,
2311 snapshot,
2312 position,
2313 events,
2314 related_files,
2315 trigger,
2316 diagnostic_search_range: diagnostic_search_range,
2317 debug_tx,
2318 can_collect_data,
2319 is_open_source,
2320 };
2321
2322 let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2323
2324 let task = match self.edit_prediction_model {
2325 EditPredictionModel::Zeta => {
2326 zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2327 }
2328 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2329 EditPredictionModel::Mercury => {
2330 self.mercury
2331 .request_prediction(inputs, self.credentials_provider.clone(), cx)
2332 }
2333 };
2334
2335 cx.spawn(async move |this, cx| {
2336 let prediction = task.await?;
2337
2338 // Only fall back to diagnostics-based prediction if we got a
2339 // the model had nothing to suggest for the buffer
2340 if prediction.is_none()
2341 && allow_jump
2342 && has_events
2343 && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2344 {
2345 this.update(cx, |this, cx| {
2346 this.refresh_prediction_from_diagnostics(
2347 project,
2348 DiagnosticSearchScope::Local,
2349 cx,
2350 );
2351 })?;
2352 return anyhow::Ok(None);
2353 }
2354
2355 Ok(prediction)
2356 })
2357 }
2358
    /// Finds the nearest diagnostic to jump to for a diagnostics-triggered
    /// prediction.
    ///
    /// First scans the active buffer for a diagnostic group outside the
    /// already-searched `active_buffer_diagnostic_search_range`; failing
    /// that, opens other project files that have diagnostics (preferring
    /// paths sharing the longest component prefix with the active buffer's
    /// path) and returns a diagnostic location there.
    ///
    /// Returns `Ok(None)` when no suitable location exists.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows where collaborators currently have cursors; used below to
        // avoid jumping into a region someone else is actively editing.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics in the range that was already searched.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Skip diagnostics near a collaborator's cursor unless they
                // are also near the local cursor.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Prefer the diagnostic closest to the local cursor row.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // NOTE(review): entity `read_with` calls in async contexts
            // typically return `Result`; confirm this doesn't need a `?`.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Rank other diagnostic-bearing files by how many leading path
            // components they share with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Picks the diagnostic with the smallest severity
                        // value — presumably the most severe, since LSP
                        // orders Error lowest; confirm.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Don't jump into a buffer someone else is working in.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2468
2469 async fn send_raw_llm_request(
2470 request: RawCompletionRequest,
2471 client: Arc<Client>,
2472 custom_url: Option<Arc<Url>>,
2473 llm_token: LlmApiToken,
2474 organization_id: Option<OrganizationId>,
2475 app_version: Version,
2476 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2477 let url = if let Some(custom_url) = custom_url {
2478 custom_url.as_ref().clone()
2479 } else {
2480 client
2481 .http_client()
2482 .build_zed_llm_url("/predict_edits/raw", &[])?
2483 };
2484
2485 Self::send_api_request(
2486 |builder| {
2487 let req = builder
2488 .uri(url.as_ref())
2489 .body(serde_json::to_string(&request)?.into());
2490 Ok(req?)
2491 },
2492 client,
2493 llm_token,
2494 organization_id,
2495 app_version,
2496 true,
2497 )
2498 .await
2499 }
2500
2501 pub(crate) async fn send_v3_request(
2502 input: ZetaPromptInput,
2503 client: Arc<Client>,
2504 llm_token: LlmApiToken,
2505 organization_id: Option<OrganizationId>,
2506 app_version: Version,
2507 trigger: PredictEditsRequestTrigger,
2508 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2509 let url = client
2510 .http_client()
2511 .build_zed_llm_url("/predict_edits/v3", &[])?;
2512
2513 let request = PredictEditsV3Request { input, trigger };
2514
2515 let json_bytes = serde_json::to_vec(&request)?;
2516 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2517
2518 Self::send_api_request(
2519 |builder| {
2520 let req = builder
2521 .uri(url.as_ref())
2522 .header("Content-Encoding", "zstd")
2523 .body(compressed.clone().into());
2524 Ok(req?)
2525 },
2526 client,
2527 llm_token,
2528 organization_id,
2529 app_version,
2530 true,
2531 )
2532 .await
2533 }
2534
    /// Shared POST helper for the LLM endpoints.
    ///
    /// `build` receives a request builder pre-populated with the method,
    /// content-type, Zed version header, and (when a token is available)
    /// authorization header, and must attach the URI and body. On a response
    /// that signals an expired token, the token is refreshed once and the
    /// request retried — so `build` may be called twice.
    ///
    /// # Errors
    /// Fails on token acquisition (when `require_auth`), transport errors,
    /// a server-reported minimum-version violation, or any non-success
    /// status after the single retry.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        // When auth is optional, a token-acquisition failure downgrades the
        // call to an unauthenticated request instead of failing it.
        let mut token = if require_auth {
            Some(
                client
                    .acquire_llm_token(&llm_token, organization_id.clone())
                    .await?,
            )
        } else {
            client
                .acquire_llm_token(&llm_token, organization_id.clone())
                .await
                .ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The server can demand a minimum client version via a response
            // header; surface that as a dedicated error so the UI can prompt
            // the user to update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Token expired: refresh once, then loop to retry.
                did_retry = true;
                token = Some(
                    client
                        .refresh_llm_token(&llm_token, organization_id.clone())
                        .await?,
                );
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2615
2616 pub fn refresh_context(
2617 &mut self,
2618 project: &Entity<Project>,
2619 buffer: &Entity<language::Buffer>,
2620 cursor_position: language::Anchor,
2621 cx: &mut Context<Self>,
2622 ) {
2623 self.get_or_init_project(project, cx)
2624 .context
2625 .update(cx, |store, cx| {
2626 store.refresh(buffer.clone(), cursor_position, cx);
2627 });
2628 }
2629
2630 #[cfg(feature = "cli-support")]
2631 pub fn set_context_for_buffer(
2632 &mut self,
2633 project: &Entity<Project>,
2634 related_files: Vec<RelatedFile>,
2635 cx: &mut Context<Self>,
2636 ) {
2637 self.get_or_init_project(project, cx)
2638 .context
2639 .update(cx, |store, cx| {
2640 store.set_related_files(related_files, cx);
2641 });
2642 }
2643
2644 #[cfg(feature = "cli-support")]
2645 pub fn set_recent_paths_for_project(
2646 &mut self,
2647 project: &Entity<Project>,
2648 paths: impl IntoIterator<Item = project::ProjectPath>,
2649 cx: &mut Context<Self>,
2650 ) {
2651 let project_state = self.get_or_init_project(project, cx);
2652 project_state.recent_paths = paths.into_iter().collect();
2653 }
2654
2655 fn is_file_open_source(
2656 &self,
2657 project: &Entity<Project>,
2658 file: &Arc<dyn File>,
2659 cx: &App,
2660 ) -> bool {
2661 if !file.is_local() || file.is_private() {
2662 return false;
2663 }
2664 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2665 return false;
2666 };
2667 project_state
2668 .license_detection_watchers
2669 .get(&file.worktree_id(cx))
2670 .as_ref()
2671 .is_some_and(|watcher| watcher.is_project_open_source())
2672 }
2673
2674 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2675 self.data_collection_choice.is_enabled(cx)
2676 }
2677
2678 fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2679 let choice = KeyValueStore::global(cx)
2680 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2681 .log_err()
2682 .flatten();
2683
2684 match choice.as_deref() {
2685 Some("true") => DataCollectionChoice::Enabled,
2686 Some("false") => DataCollectionChoice::Disabled,
2687 Some(_) => {
2688 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2689 DataCollectionChoice::NotAnswered
2690 }
2691 None => DataCollectionChoice::NotAnswered,
2692 }
2693 }
2694
2695 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2696 self.data_collection_choice = self.data_collection_choice.toggle();
2697 let new_choice = self.data_collection_choice;
2698 let is_enabled = new_choice.is_enabled(cx);
2699 let kvp = KeyValueStore::global(cx);
2700 db::write_and_log(cx, move || async move {
2701 kvp.write_kvp(
2702 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2703 is_enabled.to_string(),
2704 )
2705 .await
2706 });
2707 }
2708
2709 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2710 self.shown_predictions.iter()
2711 }
2712
2713 pub fn shown_completions_len(&self) -> usize {
2714 self.shown_predictions.len()
2715 }
2716
2717 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2718 self.rated_predictions.contains(id)
2719 }
2720
    /// Submits user feedback (a positive/negative rating plus free-form
    /// text) for a shown prediction to the cloud feedback endpoint, and
    /// marks the prediction as rated so the UI won't ask again.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        let organization = self.user_store.read(cx).current_organization();

        // Record the rating locally before the network call so the UI
        // updates immediately.
        self.rated_predictions.insert(prediction.id.clone());

        cx.background_spawn({
            let client = self.client.clone();
            let prediction_id = prediction.id.to_string();
            // Serialization may fail; the `Result` is propagated inside the
            // spawned task via `inputs?` below.
            let inputs = serde_json::to_value(&prediction.inputs);
            let output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
            async move {
                client
                    .cloud_client()
                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
                        organization_id: organization.map(|organization| organization.id.clone()),
                        request_id: prediction_id,
                        rating: match rating {
                            EditPredictionRating::Positive => "positive".to_string(),
                            EditPredictionRating::Negative => "negative".to_string(),
                        },
                        inputs: inputs?,
                        output,
                        feedback,
                    })
                    .await?;

                anyhow::Ok(())
            }
        })
        .detach_and_log_err(cx);

        cx.notify();
    }
2762}
2763
2764fn collaborator_edit_overlaps_locality_region(
2765 project_state: &ProjectState,
2766 project: &Entity<Project>,
2767 buffer: &Entity<Buffer>,
2768 snapshot: &BufferSnapshot,
2769 edit_range: &Range<Anchor>,
2770 cx: &App,
2771) -> bool {
2772 let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2773 return false;
2774 };
2775
2776 if active_buffer.entity_id() != buffer.entity_id() {
2777 return false;
2778 }
2779
2780 let locality_point_range = expand_context_syntactically_then_linewise(
2781 snapshot,
2782 (position..position).to_point(snapshot),
2783 COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2784 );
2785 let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2786
2787 edit_range.overlaps(&locality_anchor_range, snapshot)
2788}
2789
/// Coalesces the run of trailing mergeable events in `events` into a single
/// buffer-change event covering their combined edit range.
///
/// Bails out unless the newest stored event belongs to the same buffer as
/// `latest_snapshot` and that snapshot has observed all of that event's
/// edits. Walks backwards collecting consecutive events that `can_merge`
/// with their successor, then replaces them with one event whose diff spans
/// from the oldest merged snapshot to `end_snapshot`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count, from the back, how many consecutive events can merge with the
    // event that follows them.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Nothing to coalesce when at most one event qualifies.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    // Union of all edit ranges across the events being merged.
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // merged event was itself predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Drop the merged events and append their single replacement.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2872
2873fn merge_anchor_ranges(
2874 left: &Range<Anchor>,
2875 right: &Range<Anchor>,
2876 snapshot: &TextBufferSnapshot,
2877) -> Range<Anchor> {
2878 let start = if left.start.cmp(&right.start, snapshot).is_le() {
2879 left.start
2880 } else {
2881 right.start
2882 };
2883 let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2884 left.end
2885 } else {
2886 right.end
2887 };
2888 start..end
2889}
2890
/// Error raised when the server reports (via the minimum-required-version
/// response header) that this Zed build is too old to use edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2898
/// The user's explicit choice about edit-prediction data collection, as
/// persisted in the key-value store.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // The user has not answered the data-collection prompt yet.
    NotAnswered,
    // The user opted in.
    Enabled,
    // The user opted out.
    Disabled,
}
2905
2906impl DataCollectionChoice {
2907 pub fn is_enabled(self, cx: &App) -> bool {
2908 if cx.is_staff() {
2909 return true;
2910 }
2911 match self {
2912 Self::Enabled => true,
2913 Self::NotAnswered | Self::Disabled => false,
2914 }
2915 }
2916
2917 #[must_use]
2918 pub fn toggle(&self) -> DataCollectionChoice {
2919 match self {
2920 Self::Enabled => Self::Disabled,
2921 Self::Disabled => Self::Enabled,
2922 Self::NotAnswered => Self::Enabled,
2923 }
2924 }
2925}
2926
2927impl From<bool> for DataCollectionChoice {
2928 fn from(value: bool) -> Self {
2929 match value {
2930 true => DataCollectionChoice::Enabled,
2931 false => DataCollectionChoice::Disabled,
2932 }
2933 }
2934}
2935
/// Marker type for the edit-prediction upsell; its `Dismissable` impl below
/// persists whether the user has dismissed (or implicitly seen) it.
struct ZedPredictUpsell;
2937
2938impl Dismissable for ZedPredictUpsell {
2939 const KEY: &'static str = "dismissed-edit-predict-upsell";
2940
2941 fn dismissed(cx: &App) -> bool {
2942 // To make this backwards compatible with older versions of Zed, we
2943 // check if the user has seen the previous Edit Prediction Onboarding
2944 // before, by checking the data collection choice which was written to
2945 // the database once the user clicked on "Accept and Enable"
2946 let kvp = KeyValueStore::global(cx);
2947 if kvp
2948 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2949 .log_err()
2950 .is_some_and(|s| s.is_some())
2951 {
2952 return true;
2953 }
2954
2955 kvp.read_kvp(Self::KEY)
2956 .log_err()
2957 .is_some_and(|s| s.is_some())
2958 }
2959}
2960
2961pub fn should_show_upsell_modal(cx: &App) -> bool {
2962 !ZedPredictUpsell::dismissed(cx)
2963}
2964
/// Registers edit-prediction workspace actions on every new workspace: the
/// onboarding modal, onboarding reset, and the Copilot sign-in / reinstall /
/// sign-out actions.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding writes the provider back to `None` in the
        // user's settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Starts (or retrieves) the Copilot instance for this project via the
        // global store, when one exists.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}