1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::Edit;
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
// Declares the actions this crate exposes under the `edit_prediction`
// namespace. The `///` lines inside the macro are part of the macro input
// (presumably surfaced as action descriptions), so they are left as-is.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line-distance threshold used when grouping edits: `StoredEvent::can_merge`
/// compares `lines_between_ranges` results against this value.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in this part of the file; presumably the time
// window within which consecutive edits are grouped together — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key under which the user's data-collection
// choice is persisted (e.g. in `KEY_VALUE_STORE`) — confirm at its use site.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the batching delay for rejection reports sent by
// `handle_rejected_predictions` — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
106
107pub struct Zeta2FeatureFlag;
108pub struct EditPredictionJumpsFeatureFlag;
109
110impl FeatureFlag for Zeta2FeatureFlag {
111 const NAME: &'static str = "zeta2";
112}
113
114impl FeatureFlag for EditPredictionJumpsFeatureFlag {
115 const NAME: &'static str = "edit_prediction_jumps";
116}
117
/// App-level GPUI global holding the shared [`EditPredictionStore`] entity.
///
/// Installed by `EditPredictionStore::global` and looked up by
/// `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
122
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version (the `format` field) is also used as the Baseten environment
/// name (lowercased).
///
/// Built from the `ZED_ZETA_FORMAT` / `ZED_ZETA_MODEL` environment variables by
/// `EditPredictionStore::zeta2_raw_config_from_env`, or set explicitly with
/// `EditPredictionStore::set_zeta2_raw_config`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (`ZED_ZETA_MODEL` when read from the environment).
    pub model_id: Option<String>,
    /// Prompt format / version (`ZED_ZETA_FORMAT` when read from the environment).
    pub format: ZetaFormat,
}
131
/// Central edit-prediction state: the selected model backend, per-project
/// edit history and context, and the channels used to report outcomes.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription created in `new` alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): initialized to `false` in `new`; presumably set when the
    // server demands a newer client version — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` defaults it to `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, if any (initially read from the environment).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds rejected predictions to the background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably the recently shown predictions and the ids the
    // user has already rated — confirm against their use sites.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
148
/// The backend that produces edit predictions.
///
/// Chosen via `EditPredictionStore::set_edit_prediction_model`; the store
/// starts out with `Zeta2`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// FIM (fill-in-the-middle) model using the given prompt format. Its icon
    /// depends on the configured provider (Ollama vs. OpenAI-compatible).
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
157
/// Inputs gathered for a single prediction request and handed to a model backend.
// NOTE(review): the construction site is outside this part of the file; the
// field notes below are inferred from names and types — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is for, and the snapshot taken for this request.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Presumably the cursor position the request was made at.
    position: Anchor,
    /// Recent edit events (cf. `ProjectState::events`).
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Debug channel, if a listener was attached via `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
172
/// Diagnostic events sent to the receiver returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
180
/// Emitted when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently always empty when emitted from `handle_excerpt_store_event`.
    pub search_prompt: String,
}
187
/// Emitted when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable label/value pairs (e.g. "Cache Hits", "Max LSP Time").
    pub metadata: Vec<(&'static str, SharedString)>,
}
194
/// Debug event marking the start of a prediction request for `buffer`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent to the model, when available (presumably absent for
    /// backends that build the prompt server-side — confirm).
    pub prompt: Option<String>,
}
201
/// Debug event marking the completion of a prediction request for `buffer`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
208
/// Capacity of `ProjectState::user_actions`; `record_user_action` evicts the
/// oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
210
/// One entry in a project's recent user-action history (see
/// `ProjectState::record_user_action`).
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    // presumably the row and buffer offset where the action happened — confirm.
    pub line_number: u32,
    pub offset: usize,
    // presumably milliseconds since the Unix epoch, per the field name — confirm.
    pub timestamp_epoch_ms: u64,
}
219
/// Kind of user action recorded in a [`UserActionRecord`].
// NOTE(review): variant meanings inferred from names — confirm at the
// recording site.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (for buffer changes: paths, diff, flags).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors covering the edited region in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
236
237impl StoredEvent {
238 fn can_merge(
239 &self,
240 next_old_event: &&&StoredEvent,
241 new_snapshot: &TextBufferSnapshot,
242 last_edit_range: &Range<Anchor>,
243 ) -> bool {
244 // Events must be for the same buffer
245 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
246 return false;
247 }
248 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
249 return false;
250 }
251
252 let a_is_predicted = matches!(
253 self.event.as_ref(),
254 zeta_prompt::Event::BufferChange {
255 predicted: true,
256 ..
257 }
258 );
259 let b_is_predicted = matches!(
260 next_old_event.event.as_ref(),
261 zeta_prompt::Event::BufferChange {
262 predicted: true,
263 ..
264 }
265 );
266
267 // If events come from the same source (both predicted or both manual) then
268 // we would have coalesced them already.
269 if a_is_predicted == b_is_predicted {
270 return false;
271 }
272
273 let left_range = self.edit_range.to_point(new_snapshot);
274 let right_range = next_old_event.edit_range.to_point(new_snapshot);
275 let latest_range = last_edit_range.to_point(&new_snapshot);
276
277 // Events near to the latest edit are not merged if their sources differ.
278 if lines_between_ranges(&left_range, &latest_range)
279 .min(lines_between_ranges(&right_range, &latest_range))
280 <= CHANGE_GROUPING_LINE_SPAN
281 {
282 return false;
283 }
284
285 // Events that are distant from each other are not merged.
286 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
287 return false;
288 }
289
290 true
291 }
292}
293
294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
295 if left.start > right.end {
296 return left.start.row - right.end.row;
297 }
298 if right.start > left.end {
299 return right.start.row - left.end.row;
300 }
301 0
302}
303
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit events (cleared by `clear_history`).
    events: VecDeque<StoredEvent>,
    /// In-progress event, finalized on demand by `ProjectState::events`.
    last_event: Option<LastEvent>,
    /// Recently active paths, most recent first (maintained by `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug channel installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // presumably (buffer id, time) of the latest local / jump refresh — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store supplying context for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
322
323impl ProjectState {
324 fn record_user_action(&mut self, action: UserActionRecord) {
325 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
326 self.user_actions.pop_front();
327 }
328 self.user_actions.push_back(action);
329 }
330
331 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
332 self.events
333 .iter()
334 .cloned()
335 .chain(self.last_event.as_ref().iter().flat_map(|event| {
336 let (one, two) = event.split_by_pause();
337 let one = one.finalize(&self.license_detection_watchers, cx);
338 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
339 one.into_iter().chain(two)
340 }))
341 .collect()
342 }
343
344 fn cancel_pending_prediction(
345 &mut self,
346 pending_prediction: PendingPrediction,
347 cx: &mut Context<EditPredictionStore>,
348 ) {
349 self.cancelled_predictions.insert(pending_prediction.id);
350
351 if pending_prediction.drop_on_cancel {
352 drop(pending_prediction.task);
353 } else {
354 cx.spawn(async move |this, cx| {
355 let Some(prediction_id) = pending_prediction.task.await else {
356 return;
357 };
358
359 this.update(cx, |this, cx| {
360 this.reject_prediction(
361 prediction_id,
362 EditPredictionRejectReason::Canceled,
363 false,
364 None,
365 cx,
366 );
367 })
368 .ok();
369 })
370 .detach()
371 }
372 }
373
374 fn active_buffer(
375 &self,
376 project: &Entity<Project>,
377 cx: &App,
378 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
379 let project = project.read(cx);
380 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
381 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
382 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
383 Some((active_buffer, registered_buffer.last_position))
384 }
385}
386
/// The prediction currently offered for a project, with how it was requested
/// and whether/how it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request; consulted by `should_replace_prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    // presumably how it was displayed, if it was — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
394
395impl CurrentEditPrediction {
396 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
397 let Some(new_edits) = self
398 .prediction
399 .interpolate(&self.prediction.buffer.read(cx))
400 else {
401 return false;
402 };
403
404 if self.prediction.buffer != old_prediction.prediction.buffer {
405 return true;
406 }
407
408 let Some(old_edits) = old_prediction
409 .prediction
410 .interpolate(&old_prediction.prediction.buffer.read(cx))
411 else {
412 return true;
413 };
414
415 let requested_by_buffer_id = self.requested_by.buffer_id();
416
417 // This reduces the occurrence of UI thrash from replacing edits
418 //
419 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
420 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
421 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
422 && old_edits.len() == 1
423 && new_edits.len() == 1
424 {
425 let (old_range, old_text) = &old_edits[0];
426 let (new_range, new_text) = &new_edits[0];
427 new_range == old_range && new_text.starts_with(old_text.as_ref())
428 } else {
429 true
430 }
431 }
432}
433
434#[derive(Debug, Clone)]
435enum PredictionRequestedBy {
436 DiagnosticsUpdate,
437 Buffer(EntityId),
438}
439
440impl PredictionRequestedBy {
441 pub fn buffer_id(&self) -> Option<EntityId> {
442 match self {
443 PredictionRequestedBy::DiagnosticsUpdate => None,
444 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
445 }
446 }
447}
448
// NOTE(review): not referenced in this part of the file; presumably the number
// of lines around the cursor searched for diagnostics (cf.
// `EditPredictionModelInput::diagnostic_search_range`) — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
450
/// How widely to look for diagnostics when refreshing a prediction from
/// diagnostics. Project-wide diagnostics updates use `Global` (see
/// `EditPredictionStore::handle_project_event`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    // presumably restricted to the area near the cursor — confirm.
    Local,
    Global,
}
456
/// An in-flight prediction request held in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier inserted into `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the produced prediction's id, if any; when it yields `Some`
    /// after a non-dropping cancel, that prediction is reported as rejected.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
465
466/// A prediction from the perspective of a buffer.
467#[derive(Debug)]
468enum BufferEditPrediction<'a> {
469 Local { prediction: &'a EditPrediction },
470 Jump { prediction: &'a EditPrediction },
471}
472
473#[cfg(test)]
474impl std::ops::Deref for BufferEditPrediction<'_> {
475 type Target = EditPrediction;
476
477 fn deref(&self) -> &Self::Target {
478 match self {
479 BufferEditPrediction::Local { prediction } => prediction,
480 BufferEditPrediction::Jump { prediction } => prediction,
481 }
482 }
483}
484
/// Per-buffer tracking state created by `EditPredictionStore::register_buffer_impl`.
struct RegisteredBuffer {
    /// The buffer's file, captured at registration.
    file: Option<Arc<dyn File>>,
    /// Snapshot captured at registration; its version is compared with the
    /// buffer's current snapshot in `report_changes_for_buffer`.
    snapshot: TextBufferSnapshot,
    /// Returned by `ProjectState::active_buffer`; presumably the last cursor
    /// position seen in this buffer — confirm.
    last_position: Option<Anchor>,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
491
/// The most recent edit event for a buffer that has not yet been stored.
///
/// Converted into a [`StoredEvent`] by `finalize`, optionally after being
/// split at the last editing pause by `split_by_pause`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer contents before the event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer contents after the event.
    new_snapshot: TextBufferSnapshot,
    /// The buffer's file before and after the event (used for event paths).
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // NOTE(review): not read by `finalize`/`split_by_pause`, which recompute the
    // range from the snapshots — confirm where this is consumed.
    edit_range: Option<Range<Anchor>>,
    /// Whether this is a predicted (rather than manual) change; copied into
    /// `zeta_prompt::Event::BufferChange::predicted`.
    predicted: bool,
    /// Snapshot at the last editing pause; the split point for `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    // presumably the time of the most recent edit folded into this event — confirm.
    last_edit_time: Option<Instant>,
}
503
504impl LastEvent {
505 pub fn finalize(
506 &self,
507 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
508 cx: &App,
509 ) -> Option<StoredEvent> {
510 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
511 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
512
513 let in_open_source_repo =
514 [self.new_file.as_ref(), self.old_file.as_ref()]
515 .iter()
516 .all(|file| {
517 file.is_some_and(|file| {
518 license_detection_watchers
519 .get(&file.worktree_id(cx))
520 .is_some_and(|watcher| watcher.is_project_open_source())
521 })
522 });
523
524 let (diff, edit_range) =
525 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
526
527 if path == old_path && diff.is_empty() {
528 None
529 } else {
530 Some(StoredEvent {
531 event: Arc::new(zeta_prompt::Event::BufferChange {
532 old_path,
533 path,
534 diff,
535 in_open_source_repo,
536 predicted: self.predicted,
537 }),
538 edit_range: self.new_snapshot.anchor_before(edit_range.start)
539 ..self.new_snapshot.anchor_before(edit_range.end),
540 old_snapshot: self.old_snapshot.clone(),
541 })
542 }
543 }
544
545 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
546 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
547 return (self.clone(), None);
548 };
549
550 let before = LastEvent {
551 old_snapshot: self.old_snapshot.clone(),
552 new_snapshot: boundary_snapshot.clone(),
553 old_file: self.old_file.clone(),
554 new_file: self.new_file.clone(),
555 edit_range: None,
556 predicted: self.predicted,
557 snapshot_after_last_editing_pause: None,
558 last_edit_time: self.last_edit_time,
559 };
560
561 let after = LastEvent {
562 old_snapshot: boundary_snapshot.clone(),
563 new_snapshot: self.new_snapshot.clone(),
564 old_file: self.old_file.clone(),
565 new_file: self.new_file.clone(),
566 edit_range: None,
567 predicted: self.predicted,
568 snapshot_after_last_editing_pause: None,
569 last_edit_time: self.last_edit_time,
570 };
571
572 (before, Some(after))
573 }
574}
575
/// Computes a unified diff between `old_snapshot` and `new_snapshot`,
/// restricted to the rows touched by the edits between them plus surrounding
/// context rows.
///
/// Returns `None` when `new_snapshot` has no edits since `old_snapshot`'s
/// version. Otherwise returns the diff text and the span, in `new_snapshot`
/// points, from the start of the first edit to the end of the last edit.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // The first and last edits bound the whole changed region; no edits → None.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // First context row: CONTEXT_LINES above the edit, clamped at row 0.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Last context row (inclusive), clamped to the buffer's last row.
    // NOTE(review): together with the `+ 1` applied when converting to an end
    // offset below, the extracted region reaches row `end_row + 4`, i.e.
    // CONTEXT_LINES + 1 trailing rows versus CONTEXT_LINES leading rows —
    // confirm whether this asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Offsets spanning from the start of the first context row to the start of
    // the row after the last context row (or the end of the buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Diff only the extracted regions; the start rows are passed along,
    // presumably so hunk headers use whole-buffer line numbers.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
621
622fn buffer_path_with_id_fallback(
623 file: Option<&Arc<dyn File>>,
624 snapshot: &TextBufferSnapshot,
625 cx: &App,
626) -> Arc<Path> {
627 if let Some(file) = file {
628 file.full_path(cx).into()
629 } else {
630 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
631 }
632}
633
634impl EditPredictionStore {
635 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
636 cx.try_global::<EditPredictionStoreGlobal>()
637 .map(|global| global.0.clone())
638 }
639
640 pub fn global(
641 client: &Arc<Client>,
642 user_store: &Entity<UserStore>,
643 cx: &mut App,
644 ) -> Entity<Self> {
645 cx.try_global::<EditPredictionStoreGlobal>()
646 .map(|global| global.0.clone())
647 .unwrap_or_else(|| {
648 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
649 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
650 ep_store
651 })
652 }
653
654 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
655 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
656 let data_collection_choice = Self::load_data_collection_choice();
657
658 let llm_token = LlmApiToken::default();
659
660 let (reject_tx, reject_rx) = mpsc::unbounded();
661 cx.background_spawn({
662 let client = client.clone();
663 let llm_token = llm_token.clone();
664 let app_version = AppVersion::global(cx);
665 let background_executor = cx.background_executor().clone();
666 async move {
667 Self::handle_rejected_predictions(
668 reject_rx,
669 client,
670 llm_token,
671 app_version,
672 background_executor,
673 )
674 .await
675 }
676 })
677 .detach();
678
679 let this = Self {
680 projects: HashMap::default(),
681 client,
682 user_store,
683 llm_token,
684 _llm_token_subscription: cx.subscribe(
685 &refresh_llm_token_listener,
686 |this, _listener, _event, cx| {
687 let client = this.client.clone();
688 let llm_token = this.llm_token.clone();
689 cx.spawn(async move |_this, _cx| {
690 llm_token.refresh(&client).await?;
691 anyhow::Ok(())
692 })
693 .detach_and_log_err(cx);
694 },
695 ),
696 update_required: false,
697 edit_prediction_model: EditPredictionModel::Zeta2,
698 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
699 sweep_ai: SweepAi::new(cx),
700 mercury: Mercury::new(cx),
701
702 data_collection_choice,
703 reject_predictions_tx: reject_tx,
704 rated_predictions: Default::default(),
705 shown_predictions: Default::default(),
706 };
707
708 this
709 }
710
711 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
712 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
713 let format = ZetaFormat::parse(&version_str).ok()?;
714 let model_id = env::var("ZED_ZETA_MODEL").ok();
715 Some(Zeta2RawConfig { model_id, format })
716 }
717
718 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
719 self.edit_prediction_model = model;
720 }
721
722 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
723 self.zeta2_raw_config = Some(config);
724 }
725
726 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
727 self.zeta2_raw_config.as_ref()
728 }
729
730 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
731 use ui::IconName;
732 match self.edit_prediction_model {
733 EditPredictionModel::Sweep => {
734 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
735 .with_disabled(IconName::SweepAiDisabled)
736 .with_up(IconName::SweepAiUp)
737 .with_down(IconName::SweepAiDown)
738 .with_error(IconName::SweepAiError)
739 }
740 EditPredictionModel::Mercury => {
741 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
742 }
743 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
744 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
745 .with_disabled(IconName::ZedPredictDisabled)
746 .with_up(IconName::ZedPredictUp)
747 .with_down(IconName::ZedPredictDown)
748 .with_error(IconName::ZedPredictError)
749 }
750 EditPredictionModel::Fim { .. } => {
751 let settings = &all_language_settings(None, cx).edit_predictions;
752 match settings.provider {
753 EditPredictionProvider::Ollama => {
754 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
755 }
756 _ => {
757 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
758 }
759 }
760 }
761 }
762 }
763
764 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
765 self.sweep_ai.api_token.read(cx).has_key()
766 }
767
768 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
769 self.mercury.api_token.read(cx).has_key()
770 }
771
772 pub fn clear_history(&mut self) {
773 for project_state in self.projects.values_mut() {
774 project_state.events.clear();
775 project_state.last_event.take();
776 }
777 }
778
779 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
780 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
781 project_state.events.clear();
782 project_state.last_event.take();
783 }
784 }
785
786 pub fn edit_history_for_project(
787 &self,
788 project: &Entity<Project>,
789 cx: &App,
790 ) -> Vec<StoredEvent> {
791 self.projects
792 .get(&project.entity_id())
793 .map(|project_state| project_state.events(cx))
794 .unwrap_or_default()
795 }
796
797 pub fn context_for_project<'a>(
798 &'a self,
799 project: &Entity<Project>,
800 cx: &'a mut App,
801 ) -> Vec<RelatedFile> {
802 self.projects
803 .get(&project.entity_id())
804 .map(|project_state| {
805 project_state.context.update(cx, |context, cx| {
806 context
807 .related_files_with_buffers(cx)
808 .map(|(mut related_file, buffer)| {
809 related_file.in_open_source_repo = buffer
810 .read(cx)
811 .file()
812 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
813 related_file
814 })
815 .collect()
816 })
817 })
818 .unwrap_or_default()
819 }
820
821 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
822 self.projects
823 .get(&project.entity_id())
824 .and_then(|project| project.copilot.clone())
825 }
826
827 pub fn start_copilot_for_project(
828 &mut self,
829 project: &Entity<Project>,
830 cx: &mut Context<Self>,
831 ) -> Option<Entity<Copilot>> {
832 if DisableAiSettings::get(None, cx).disable_ai {
833 return None;
834 }
835 let state = self.get_or_init_project(project, cx);
836
837 if state.copilot.is_some() {
838 return state.copilot.clone();
839 }
840 let _project = project.clone();
841 let project = project.read(cx);
842
843 let node = project.node_runtime().cloned();
844 if let Some(node) = node {
845 let next_id = project.languages().next_language_server_id();
846 let fs = project.fs().clone();
847
848 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
849 state.copilot = Some(copilot.clone());
850 Some(copilot)
851 } else {
852 None
853 }
854 }
855
856 pub fn context_for_project_with_buffers<'a>(
857 &'a self,
858 project: &Entity<Project>,
859 cx: &'a mut App,
860 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
861 self.projects
862 .get(&project.entity_id())
863 .map(|project| {
864 project.context.update(cx, |context, cx| {
865 context.related_files_with_buffers(cx).collect()
866 })
867 })
868 .unwrap_or_default()
869 }
870
871 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
872 if matches!(
873 self.edit_prediction_model,
874 EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
875 ) {
876 self.user_store.read(cx).edit_prediction_usage()
877 } else {
878 None
879 }
880 }
881
882 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
883 self.get_or_init_project(project, cx);
884 }
885
886 pub fn register_buffer(
887 &mut self,
888 buffer: &Entity<Buffer>,
889 project: &Entity<Project>,
890 cx: &mut Context<Self>,
891 ) {
892 let project_state = self.get_or_init_project(project, cx);
893 Self::register_buffer_impl(project_state, buffer, project, cx);
894 }
895
896 fn get_or_init_project(
897 &mut self,
898 project: &Entity<Project>,
899 cx: &mut Context<Self>,
900 ) -> &mut ProjectState {
901 let entity_id = project.entity_id();
902 self.projects
903 .entry(entity_id)
904 .or_insert_with(|| ProjectState {
905 context: {
906 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
907 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
908 this.handle_excerpt_store_event(entity_id, event);
909 })
910 .detach();
911 related_excerpt_store
912 },
913 events: VecDeque::new(),
914 last_event: None,
915 recent_paths: VecDeque::new(),
916 debug_tx: None,
917 registered_buffers: HashMap::default(),
918 current_prediction: None,
919 cancelled_predictions: HashSet::default(),
920 pending_predictions: ArrayVec::new(),
921 next_pending_prediction_id: 0,
922 last_edit_prediction_refresh: None,
923 last_jump_prediction_refresh: None,
924 license_detection_watchers: HashMap::default(),
925 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
926 _subscriptions: [
927 cx.subscribe(&project, Self::handle_project_event),
928 cx.observe_release(&project, move |this, _, cx| {
929 this.projects.remove(&entity_id);
930 cx.notify();
931 }),
932 ],
933 copilot: None,
934 })
935 }
936
937 pub fn remove_project(&mut self, project: &Entity<Project>) {
938 self.projects.remove(&project.entity_id());
939 }
940
941 fn handle_excerpt_store_event(
942 &mut self,
943 project_entity_id: EntityId,
944 event: &RelatedExcerptStoreEvent,
945 ) {
946 if let Some(project_state) = self.projects.get(&project_entity_id) {
947 if let Some(debug_tx) = project_state.debug_tx.clone() {
948 match event {
949 RelatedExcerptStoreEvent::StartedRefresh => {
950 debug_tx
951 .unbounded_send(DebugEvent::ContextRetrievalStarted(
952 ContextRetrievalStartedDebugEvent {
953 project_entity_id: project_entity_id,
954 timestamp: Instant::now(),
955 search_prompt: String::new(),
956 },
957 ))
958 .ok();
959 }
960 RelatedExcerptStoreEvent::FinishedRefresh {
961 cache_hit_count,
962 cache_miss_count,
963 mean_definition_latency,
964 max_definition_latency,
965 } => {
966 debug_tx
967 .unbounded_send(DebugEvent::ContextRetrievalFinished(
968 ContextRetrievalFinishedDebugEvent {
969 project_entity_id: project_entity_id,
970 timestamp: Instant::now(),
971 metadata: vec![
972 (
973 "Cache Hits",
974 format!(
975 "{}/{}",
976 cache_hit_count,
977 cache_hit_count + cache_miss_count
978 )
979 .into(),
980 ),
981 (
982 "Max LSP Time",
983 format!("{} ms", max_definition_latency.as_millis())
984 .into(),
985 ),
986 (
987 "Mean LSP Time",
988 format!("{} ms", mean_definition_latency.as_millis())
989 .into(),
990 ),
991 ],
992 },
993 ))
994 .ok();
995 }
996 }
997 }
998 }
999 }
1000
1001 pub fn debug_info(
1002 &mut self,
1003 project: &Entity<Project>,
1004 cx: &mut Context<Self>,
1005 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1006 let project_state = self.get_or_init_project(project, cx);
1007 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1008 project_state.debug_tx = Some(debug_watch_tx);
1009 debug_watch_rx
1010 }
1011
1012 fn handle_project_event(
1013 &mut self,
1014 project: Entity<Project>,
1015 event: &project::Event,
1016 cx: &mut Context<Self>,
1017 ) {
1018 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1019 return;
1020 }
1021 // TODO [zeta2] init with recent paths
1022 match event {
1023 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1024 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1025 return;
1026 };
1027 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1028 if let Some(path) = path {
1029 if let Some(ix) = project_state
1030 .recent_paths
1031 .iter()
1032 .position(|probe| probe == &path)
1033 {
1034 project_state.recent_paths.remove(ix);
1035 }
1036 project_state.recent_paths.push_front(path);
1037 }
1038 }
1039 project::Event::DiagnosticsUpdated { .. } => {
1040 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1041 self.refresh_prediction_from_diagnostics(
1042 project,
1043 DiagnosticSearchScope::Global,
1044 cx,
1045 );
1046 }
1047 }
1048 _ => (),
1049 }
1050 }
1051
/// Registers `buffer` in `project_state` (if not already registered) and
/// returns its tracking entry.
///
/// If the buffer has a file whose worktree exists in `project`, a
/// `LicenseDetectionWatcher` is created for that worktree the first time it is
/// seen; the watcher is removed from the project state when the worktree is
/// released.
///
/// A newly registered buffer records its current snapshot and file, and holds
/// two subscriptions:
/// - `BufferEvent::Edited` calls `report_changes_for_buffer` with
///   `is_predicted = false` (skipped once the project is gone, since only a
///   weak project handle is kept);
/// - buffer release removes the entry from `registered_buffers`.
fn register_buffer_impl<'a>(
    project_state: &'a mut ProjectState,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
    let buffer_id = buffer.entity_id();

    // Track license status once per worktree that backs a registered buffer.
    if let Some(file) = buffer.read(cx).file() {
        let worktree_id = file.worktree_id(cx);
        if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
            project_state
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| {
                    let project_entity_id = project.entity_id();
                    // Drop the watcher when its worktree goes away.
                    cx.observe_release(&worktree, move |this, _worktree, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state
                            .license_detection_watchers
                            .remove(&worktree_id);
                    })
                    .detach();
                    Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                });
        }
    }

    match project_state.registered_buffers.entry(buffer_id) {
        hash_map::Entry::Occupied(entry) => entry.into_mut(),
        hash_map::Entry::Vacant(entry) => {
            let buf = buffer.read(cx);
            let snapshot = buf.text_snapshot();
            let file = buf.file().cloned();
            let project_entity_id = project.entity_id();
            entry.insert(RegisteredBuffer {
                snapshot,
                file,
                last_position: None,
                _subscriptions: [
                    // A weak handle avoids keeping the project alive from the subscription.
                    cx.subscribe(buffer, {
                        let project = project.downgrade();
                        move |this, buffer, event, cx| {
                            if let language::BufferEvent::Edited = event
                                && let Some(project) = project.upgrade()
                            {
                                this.report_changes_for_buffer(&buffer, &project, false, cx);
                            }
                        }
                    }),
                    // Forget the buffer once it is released.
                    cx.observe_release(buffer, move |this, _buffer, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state.registered_buffers.remove(&buffer_id);
                    }),
                ],
            })
        }
    }
}
1117
    /// Records the changes made to `buffer` since its last registered snapshot.
    ///
    /// The buffer is registered if needed, and the method returns early when the buffer's
    /// version is unchanged. Otherwise it:
    /// - updates the registered snapshot and file;
    /// - records a `UserActionRecord` describing the edit;
    /// - either coalesces the edit into the project's `last_event`, or finalizes the previous
    ///   `last_event` into `events` and starts a new one.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction. Predicted and
    /// non-predicted edits are never coalesced into the same event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Runs from the start of the first edit to the end of the last edit visited.
        let mut edit_range: Option<Range<Anchor>> = None;
        // Offset (in the new snapshot) where the last visited edit ends.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits since the old version: nothing to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits that inserted (or deleted) exactly one offset unit each count as
        // single-character actions; larger ones count as selection actions. Edits that both
        // delete and insert are classified by what they inserted.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock milliseconds since the Unix epoch; 0 if the clock reads earlier.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit directly continues the pending event: same buffer, and it starts
            // from the snapshot the pending event ended on.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Coalesce only consecutive edits from the same source whose range lies within
            // `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least `LAST_CHANGE_GROUPING_TIME`, remember the snapshot
                // as it was right before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event so `events` holds at most `EVENT_COUNT_MAX - 1`
                // entries after the push (presumably leaving room for `last_event`).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new pending event for this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1245
1246 fn prediction_at(
1247 &mut self,
1248 buffer: &Entity<Buffer>,
1249 position: Option<language::Anchor>,
1250 project: &Entity<Project>,
1251 cx: &App,
1252 ) -> Option<BufferEditPrediction<'_>> {
1253 let project_state = self.projects.get_mut(&project.entity_id())?;
1254 if let Some(position) = position
1255 && let Some(buffer) = project_state
1256 .registered_buffers
1257 .get_mut(&buffer.entity_id())
1258 {
1259 buffer.last_position = Some(position);
1260 }
1261
1262 let CurrentEditPrediction {
1263 requested_by,
1264 prediction,
1265 ..
1266 } = project_state.current_prediction.as_ref()?;
1267
1268 if prediction.targets_buffer(buffer.read(cx)) {
1269 Some(BufferEditPrediction::Local { prediction })
1270 } else {
1271 let show_jump = match requested_by {
1272 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1273 requested_by_buffer_id == &buffer.entity_id()
1274 }
1275 PredictionRequestedBy::DiagnosticsUpdate => true,
1276 };
1277
1278 if show_jump {
1279 Some(BufferEditPrediction::Jump { prediction })
1280 } else {
1281 None
1282 }
1283 }
1284 }
1285
1286 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1287 let Some(current_prediction) = self
1288 .projects
1289 .get_mut(&project.entity_id())
1290 .and_then(|project_state| project_state.current_prediction.take())
1291 else {
1292 return;
1293 };
1294
1295 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1296
1297 // can't hold &mut project_state ref across report_changes_for_buffer_call
1298 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1299 return;
1300 };
1301
1302 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1303 project_state.cancel_pending_prediction(pending_prediction, cx);
1304 }
1305
1306 match self.edit_prediction_model {
1307 EditPredictionModel::Sweep => {
1308 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1309 }
1310 EditPredictionModel::Mercury => {
1311 mercury::edit_prediction_accepted(
1312 current_prediction.prediction.id,
1313 self.client.http_client(),
1314 cx,
1315 );
1316 }
1317 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1318 let is_cloud = !matches!(
1319 all_language_settings(None, cx).edit_predictions.provider,
1320 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1321 );
1322 if is_cloud {
1323 zeta::edit_prediction_accepted(self, current_prediction, cx)
1324 }
1325 }
1326 EditPredictionModel::Fim { .. } => {}
1327 }
1328 }
1329
1330 async fn handle_rejected_predictions(
1331 rx: UnboundedReceiver<EditPredictionRejection>,
1332 client: Arc<Client>,
1333 llm_token: LlmApiToken,
1334 app_version: Version,
1335 background_executor: BackgroundExecutor,
1336 ) {
1337 let mut rx = std::pin::pin!(rx.peekable());
1338 let mut batched = Vec::new();
1339
1340 while let Some(rejection) = rx.next().await {
1341 batched.push(rejection);
1342
1343 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1344 select_biased! {
1345 next = rx.as_mut().peek().fuse() => {
1346 if next.is_some() {
1347 continue;
1348 }
1349 }
1350 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1351 }
1352 }
1353
1354 let url = client
1355 .http_client()
1356 .build_zed_llm_url("/predict_edits/reject", &[])
1357 .unwrap();
1358
1359 let flush_count = batched
1360 .len()
1361 // in case items have accumulated after failure
1362 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1363 let start = batched.len() - flush_count;
1364
1365 let body = RejectEditPredictionsBodyRef {
1366 rejections: &batched[start..],
1367 };
1368
1369 let result = Self::send_api_request::<()>(
1370 |builder| {
1371 let req = builder
1372 .uri(url.as_ref())
1373 .body(serde_json::to_string(&body)?.into());
1374 anyhow::Ok(req?)
1375 },
1376 client.clone(),
1377 llm_token.clone(),
1378 app_version.clone(),
1379 true,
1380 )
1381 .await;
1382
1383 if result.log_err().is_some() {
1384 batched.drain(start..);
1385 }
1386 }
1387 }
1388
1389 fn reject_current_prediction(
1390 &mut self,
1391 reason: EditPredictionRejectReason,
1392 project: &Entity<Project>,
1393 cx: &App,
1394 ) {
1395 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1396 project_state.pending_predictions.clear();
1397 if let Some(prediction) = project_state.current_prediction.take() {
1398 let model_version = prediction.prediction.model_version.clone();
1399 self.reject_prediction(
1400 prediction.prediction.id,
1401 reason,
1402 prediction.was_shown,
1403 model_version,
1404 cx,
1405 );
1406 }
1407 };
1408 }
1409
1410 fn did_show_current_prediction(
1411 &mut self,
1412 project: &Entity<Project>,
1413 display_type: edit_prediction_types::SuggestionDisplayType,
1414 cx: &mut Context<Self>,
1415 ) {
1416 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1417 return;
1418 };
1419
1420 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1421 return;
1422 };
1423
1424 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1425 let previous_shown_with = current_prediction.shown_with;
1426
1427 if previous_shown_with.is_none() || !is_jump {
1428 current_prediction.shown_with = Some(display_type);
1429 }
1430
1431 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1432
1433 if is_first_non_jump_show {
1434 current_prediction.was_shown = true;
1435 }
1436
1437 let display_type_changed = previous_shown_with != Some(display_type);
1438
1439 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1440 sweep_ai::edit_prediction_shown(
1441 &self.sweep_ai,
1442 self.client.clone(),
1443 ¤t_prediction.prediction,
1444 display_type,
1445 cx,
1446 );
1447 }
1448
1449 if is_first_non_jump_show {
1450 self.shown_predictions
1451 .push_front(current_prediction.prediction.clone());
1452 if self.shown_predictions.len() > 50 {
1453 let completion = self.shown_predictions.pop_back().unwrap();
1454 self.rated_predictions.remove(&completion.id);
1455 }
1456 }
1457 }
1458
1459 fn reject_prediction(
1460 &mut self,
1461 prediction_id: EditPredictionId,
1462 reason: EditPredictionRejectReason,
1463 was_shown: bool,
1464 model_version: Option<String>,
1465 cx: &App,
1466 ) {
1467 match self.edit_prediction_model {
1468 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1469 let is_cloud = !matches!(
1470 all_language_settings(None, cx).edit_predictions.provider,
1471 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1472 );
1473 if is_cloud {
1474 self.reject_predictions_tx
1475 .unbounded_send(EditPredictionRejection {
1476 request_id: prediction_id.to_string(),
1477 reason,
1478 was_shown,
1479 model_version,
1480 })
1481 .log_err();
1482 }
1483 }
1484 EditPredictionModel::Mercury => {
1485 mercury::edit_prediction_rejected(
1486 prediction_id,
1487 was_shown,
1488 reason,
1489 self.client.http_client(),
1490 cx,
1491 );
1492 }
1493 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1494 }
1495 }
1496
1497 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1498 self.projects
1499 .get(&project.entity_id())
1500 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1501 }
1502
1503 pub fn refresh_prediction_from_buffer(
1504 &mut self,
1505 project: Entity<Project>,
1506 buffer: Entity<Buffer>,
1507 position: language::Anchor,
1508 cx: &mut Context<Self>,
1509 ) {
1510 self.queue_prediction_refresh(
1511 project.clone(),
1512 PredictEditsRequestTrigger::Other,
1513 buffer.entity_id(),
1514 cx,
1515 move |this, cx| {
1516 let Some(request_task) = this
1517 .update(cx, |this, cx| {
1518 this.request_prediction(
1519 &project,
1520 &buffer,
1521 position,
1522 PredictEditsRequestTrigger::Other,
1523 cx,
1524 )
1525 })
1526 .log_err()
1527 else {
1528 return Task::ready(anyhow::Ok(None));
1529 };
1530
1531 cx.spawn(async move |_cx| {
1532 request_task.await.map(|prediction_result| {
1533 prediction_result.map(|prediction_result| {
1534 (
1535 prediction_result,
1536 PredictionRequestedBy::Buffer(buffer.entity_id()),
1537 )
1538 })
1539 })
1540 })
1541 },
1542 )
1543 }
1544
    /// Queues a prediction request at the next diagnostic location for `project`.
    ///
    /// Preconditions: this does nothing unless an edit-prediction-store provider is active,
    /// the project is already tracked, and it has no current prediction (predictions
    /// requested from a buffer take precedence). The refresh is throttled on the project
    /// entity under the diagnostics trigger.
    ///
    /// When it runs:
    /// - It uses the project's active buffer and cursor, and skips the request if
    ///   predictions are disabled there.
    /// - It asks `next_diagnostic_location` for a target. With `DiagnosticSearchScope::Local`,
    ///   diagnostics overlapping the rows within `DIAGNOSTIC_LINES_RANGE` of the cursor are
    ///   excluded from that search. `Global` passes a default (empty) range.
    /// - If a current prediction appeared while the request was in flight, the result is
    ///   converted into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Snapshot the active buffer and cursor, bailing out when there is no active
                // buffer or predictions are disabled at the cursor.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Rows around the cursor whose diagnostics are excluded from the jump search.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // If a prediction became current while this request was in flight, keep
                    // it and report this result as rejected in its favor.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1653
1654 fn predictions_enabled_at(
1655 snapshot: &BufferSnapshot,
1656 position: Option<language::Anchor>,
1657 cx: &App,
1658 ) -> bool {
1659 let file = snapshot.file();
1660 let all_settings = all_language_settings(file, cx);
1661 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1662 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1663 {
1664 return false;
1665 }
1666
1667 if let Some(last_position) = position {
1668 let settings = snapshot.settings_at(last_position, cx);
1669
1670 if !settings.edit_predictions_disabled_in.is_empty()
1671 && let Some(scope) = snapshot.language_scope_at(last_position)
1672 && let Some(scope_name) = scope.override_name()
1673 && settings
1674 .edit_predictions_disabled_in
1675 .iter()
1676 .any(|s| s == scope_name)
1677 {
1678 return false;
1679 }
1680 }
1681
1682 true
1683 }
1684
    /// Minimum spacing between two prediction refreshes of the same kind for the same throttle
    /// entity (used by `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1686}
1687
1688fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1689 match provider {
1690 EditPredictionProvider::Zed
1691 | EditPredictionProvider::Sweep
1692 | EditPredictionProvider::Mercury
1693 | EditPredictionProvider::Ollama
1694 | EditPredictionProvider::OpenAiCompatibleApi
1695 | EditPredictionProvider::Experimental(_) => true,
1696 EditPredictionProvider::None
1697 | EditPredictionProvider::Copilot
1698 | EditPredictionProvider::Supermaven
1699 | EditPredictionProvider::Codestral => false,
1700 }
1701}
1702
1703impl EditPredictionStore {
1704 fn queue_prediction_refresh(
1705 &mut self,
1706 project: Entity<Project>,
1707 request_trigger: PredictEditsRequestTrigger,
1708 throttle_entity: EntityId,
1709 cx: &mut Context<Self>,
1710 do_refresh: impl FnOnce(
1711 WeakEntity<Self>,
1712 &mut AsyncApp,
1713 )
1714 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1715 + 'static,
1716 ) {
1717 fn select_throttle(
1718 project_state: &mut ProjectState,
1719 request_trigger: PredictEditsRequestTrigger,
1720 ) -> &mut Option<(EntityId, Instant)> {
1721 match request_trigger {
1722 PredictEditsRequestTrigger::Diagnostics => {
1723 &mut project_state.last_jump_prediction_refresh
1724 }
1725 _ => &mut project_state.last_edit_prediction_refresh,
1726 }
1727 }
1728
1729 let (needs_acceptance_tracking, max_pending_predictions) =
1730 match all_language_settings(None, cx).edit_predictions.provider {
1731 EditPredictionProvider::Zed
1732 | EditPredictionProvider::Sweep
1733 | EditPredictionProvider::Mercury
1734 | EditPredictionProvider::Experimental(_) => (true, 2),
1735 EditPredictionProvider::Ollama => (false, 1),
1736 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1737 EditPredictionProvider::None
1738 | EditPredictionProvider::Copilot
1739 | EditPredictionProvider::Supermaven
1740 | EditPredictionProvider::Codestral => {
1741 log::error!("queue_prediction_refresh called with non-store provider");
1742 return;
1743 }
1744 };
1745
1746 let drop_on_cancel = !needs_acceptance_tracking;
1747 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1748 let project_state = self.get_or_init_project(&project, cx);
1749 let pending_prediction_id = project_state.next_pending_prediction_id;
1750 project_state.next_pending_prediction_id += 1;
1751 let last_request = *select_throttle(project_state, request_trigger);
1752
1753 let task = cx.spawn(async move |this, cx| {
1754 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1755 if throttle_entity != last_entity {
1756 return None;
1757 }
1758 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1759 }) {
1760 cx.background_executor().timer(timeout).await;
1761 }
1762
1763 // If this task was cancelled before the throttle timeout expired,
1764 // do not perform a request.
1765 let mut is_cancelled = true;
1766 this.update(cx, |this, cx| {
1767 let project_state = this.get_or_init_project(&project, cx);
1768 let was_cancelled = project_state
1769 .cancelled_predictions
1770 .remove(&pending_prediction_id);
1771 if !was_cancelled {
1772 let new_refresh = (throttle_entity, Instant::now());
1773 *select_throttle(project_state, request_trigger) = Some(new_refresh);
1774 is_cancelled = false;
1775 }
1776 })
1777 .ok();
1778 if is_cancelled {
1779 return None;
1780 }
1781
1782 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1783 let new_prediction_id = new_prediction_result
1784 .as_ref()
1785 .map(|(prediction, _)| prediction.id.clone());
1786
1787 // When a prediction completes, remove it from the pending list, and cancel
1788 // any pending predictions that were enqueued before it.
1789 this.update(cx, |this, cx| {
1790 let project_state = this.get_or_init_project(&project, cx);
1791
1792 let is_cancelled = project_state
1793 .cancelled_predictions
1794 .remove(&pending_prediction_id);
1795
1796 let new_current_prediction = if !is_cancelled
1797 && let Some((prediction_result, requested_by)) = new_prediction_result
1798 {
1799 match prediction_result.prediction {
1800 Ok(prediction) => {
1801 let new_prediction = CurrentEditPrediction {
1802 requested_by,
1803 prediction,
1804 was_shown: false,
1805 shown_with: None,
1806 };
1807
1808 if let Some(current_prediction) =
1809 project_state.current_prediction.as_ref()
1810 {
1811 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1812 {
1813 this.reject_current_prediction(
1814 EditPredictionRejectReason::Replaced,
1815 &project,
1816 cx,
1817 );
1818
1819 Some(new_prediction)
1820 } else {
1821 this.reject_prediction(
1822 new_prediction.prediction.id,
1823 EditPredictionRejectReason::CurrentPreferred,
1824 false,
1825 new_prediction.prediction.model_version,
1826 cx,
1827 );
1828 None
1829 }
1830 } else {
1831 Some(new_prediction)
1832 }
1833 }
1834 Err(reject_reason) => {
1835 this.reject_prediction(
1836 prediction_result.id,
1837 reject_reason,
1838 false,
1839 None,
1840 cx,
1841 );
1842 None
1843 }
1844 }
1845 } else {
1846 None
1847 };
1848
1849 let project_state = this.get_or_init_project(&project, cx);
1850
1851 if let Some(new_prediction) = new_current_prediction {
1852 project_state.current_prediction = Some(new_prediction);
1853 }
1854
1855 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1856 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1857 if pending_prediction.id == pending_prediction_id {
1858 pending_predictions.remove(ix);
1859 for pending_prediction in pending_predictions.drain(0..ix) {
1860 project_state.cancel_pending_prediction(pending_prediction, cx)
1861 }
1862 break;
1863 }
1864 }
1865 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1866 cx.notify();
1867 })
1868 .ok();
1869
1870 new_prediction_id
1871 });
1872
1873 if project_state.pending_predictions.len() < max_pending_predictions {
1874 project_state.pending_predictions.push(PendingPrediction {
1875 id: pending_prediction_id,
1876 task,
1877 drop_on_cancel,
1878 });
1879 } else {
1880 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1881 project_state.pending_predictions.push(PendingPrediction {
1882 id: pending_prediction_id,
1883 task,
1884 drop_on_cancel,
1885 });
1886 project_state.cancel_pending_prediction(pending_prediction, cx);
1887 }
1888 }
1889
1890 pub fn request_prediction(
1891 &mut self,
1892 project: &Entity<Project>,
1893 active_buffer: &Entity<Buffer>,
1894 position: language::Anchor,
1895 trigger: PredictEditsRequestTrigger,
1896 cx: &mut Context<Self>,
1897 ) -> Task<Result<Option<EditPredictionResult>>> {
1898 self.request_prediction_internal(
1899 project.clone(),
1900 active_buffer.clone(),
1901 position,
1902 trigger,
1903 cx.has_flag::<Zeta2FeatureFlag>(),
1904 cx,
1905 )
1906 }
1907
1908 fn request_prediction_internal(
1909 &mut self,
1910 project: Entity<Project>,
1911 active_buffer: Entity<Buffer>,
1912 position: language::Anchor,
1913 trigger: PredictEditsRequestTrigger,
1914 allow_jump: bool,
1915 cx: &mut Context<Self>,
1916 ) -> Task<Result<Option<EditPredictionResult>>> {
1917 self.get_or_init_project(&project, cx);
1918 let project_state = self.projects.get(&project.entity_id()).unwrap();
1919 let stored_events = project_state.events(cx);
1920 let has_events = !stored_events.is_empty();
1921 let events: Vec<Arc<zeta_prompt::Event>> =
1922 stored_events.into_iter().map(|e| e.event).collect();
1923 let debug_tx = project_state.debug_tx.clone();
1924
1925 let snapshot = active_buffer.read(cx).snapshot();
1926 let cursor_point = position.to_point(&snapshot);
1927 let current_offset = position.to_offset(&snapshot);
1928
1929 let mut user_actions: Vec<UserActionRecord> =
1930 project_state.user_actions.iter().cloned().collect();
1931
1932 if let Some(last_action) = user_actions.last() {
1933 if last_action.buffer_id == active_buffer.entity_id()
1934 && current_offset != last_action.offset
1935 {
1936 let timestamp_epoch_ms = SystemTime::now()
1937 .duration_since(UNIX_EPOCH)
1938 .map(|d| d.as_millis() as u64)
1939 .unwrap_or(0);
1940 user_actions.push(UserActionRecord {
1941 action_type: UserActionType::CursorMovement,
1942 buffer_id: active_buffer.entity_id(),
1943 line_number: cursor_point.row,
1944 offset: current_offset,
1945 timestamp_epoch_ms,
1946 });
1947 }
1948 }
1949 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1950 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1951 let diagnostic_search_range =
1952 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1953
1954 let related_files = self.context_for_project(&project, cx);
1955
1956 let inputs = EditPredictionModelInput {
1957 project: project.clone(),
1958 buffer: active_buffer,
1959 snapshot: snapshot,
1960 position,
1961 events,
1962 related_files,
1963 recent_paths: project_state.recent_paths.clone(),
1964 trigger,
1965 diagnostic_search_range: diagnostic_search_range,
1966 debug_tx,
1967 user_actions,
1968 };
1969
1970 let task = match self.edit_prediction_model {
1971 EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1972 self,
1973 inputs,
1974 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1975 cx,
1976 ),
1977 EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1978 self,
1979 inputs,
1980 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1981 cx,
1982 ),
1983 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1984 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1985 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1986 };
1987
1988 cx.spawn(async move |this, cx| {
1989 let prediction = task.await?;
1990
1991 if prediction.is_none() && allow_jump && has_events {
1992 this.update(cx, |this, cx| {
1993 this.refresh_prediction_from_diagnostics(
1994 project,
1995 DiagnosticSearchScope::Local,
1996 cx,
1997 );
1998 })?;
1999 return anyhow::Ok(None);
2000 }
2001
2002 Ok(prediction)
2003 })
2004 }
2005
2006 pub(crate) async fn next_diagnostic_location(
2007 active_buffer: Entity<Buffer>,
2008 active_buffer_snapshot: &BufferSnapshot,
2009 active_buffer_diagnostic_search_range: Range<Point>,
2010 active_buffer_cursor_point: Point,
2011 project: &Entity<Project>,
2012 cx: &mut AsyncApp,
2013 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2014 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2015 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2016 .flat_map(|(_, _, _, selections)| {
2017 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2018 })
2019 .collect();
2020
2021 let mut jump_location = active_buffer_snapshot
2022 .diagnostic_groups(None)
2023 .into_iter()
2024 .filter_map(|(_, group)| {
2025 let range = &group.entries[group.primary_ix]
2026 .range
2027 .to_point(&active_buffer_snapshot);
2028 if range.overlaps(&active_buffer_diagnostic_search_range) {
2029 return None;
2030 }
2031 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2032 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2033 });
2034 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2035 <= DIAGNOSTIC_LINES_RANGE;
2036 if near_collaborator && !near_local {
2037 return None;
2038 }
2039 Some(range.start)
2040 })
2041 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2042 .map(|position| {
2043 (
2044 active_buffer.clone(),
2045 active_buffer_snapshot.anchor_before(position),
2046 )
2047 });
2048
2049 if jump_location.is_none() {
2050 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2051 let file = buffer.file()?;
2052
2053 Some(ProjectPath {
2054 worktree_id: file.worktree_id(cx),
2055 path: file.path().clone(),
2056 })
2057 });
2058
2059 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2060 project
2061 .diagnostic_summaries(false, cx)
2062 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2063 .map(|(path, _, _)| {
2064 let shared_prefix = path
2065 .path
2066 .components()
2067 .zip(
2068 active_buffer_path
2069 .as_ref()
2070 .map(|p| p.path.components())
2071 .unwrap_or_default(),
2072 )
2073 .take_while(|(a, b)| a == b)
2074 .count();
2075 (path, shared_prefix)
2076 })
2077 .collect()
2078 });
2079
2080 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2081
2082 for (path, _) in candidates {
2083 let candidate_buffer = project
2084 .update(cx, |project, cx| project.open_buffer(path, cx))
2085 .await?;
2086
2087 let (has_collaborators, diagnostic_position) =
2088 candidate_buffer.read_with(cx, |buffer, _cx| {
2089 let snapshot = buffer.snapshot();
2090 let has_collaborators = snapshot
2091 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2092 .next()
2093 .is_some();
2094 let position = buffer
2095 .buffer_diagnostics(None)
2096 .into_iter()
2097 .min_by_key(|entry| entry.diagnostic.severity)
2098 .map(|entry| entry.range.start);
2099 (has_collaborators, position)
2100 });
2101
2102 if has_collaborators {
2103 continue;
2104 }
2105
2106 if let Some(position) = diagnostic_position {
2107 jump_location = Some((candidate_buffer, position));
2108 break;
2109 }
2110 }
2111 }
2112
2113 anyhow::Ok(jump_location)
2114 }
2115
2116 async fn send_raw_llm_request(
2117 request: RawCompletionRequest,
2118 client: Arc<Client>,
2119 custom_url: Option<Arc<Url>>,
2120 llm_token: LlmApiToken,
2121 app_version: Version,
2122 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2123 let url = if let Some(custom_url) = custom_url {
2124 custom_url.as_ref().clone()
2125 } else {
2126 client
2127 .http_client()
2128 .build_zed_llm_url("/predict_edits/raw", &[])?
2129 };
2130
2131 Self::send_api_request(
2132 |builder| {
2133 let req = builder
2134 .uri(url.as_ref())
2135 .body(serde_json::to_string(&request)?.into());
2136 Ok(req?)
2137 },
2138 client,
2139 llm_token,
2140 app_version,
2141 true,
2142 )
2143 .await
2144 }
2145
2146 pub(crate) async fn send_v3_request(
2147 input: ZetaPromptInput,
2148 client: Arc<Client>,
2149 llm_token: LlmApiToken,
2150 app_version: Version,
2151 trigger: PredictEditsRequestTrigger,
2152 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2153 let url = client
2154 .http_client()
2155 .build_zed_llm_url("/predict_edits/v3", &[])?;
2156
2157 let request = PredictEditsV3Request { input, trigger };
2158
2159 let json_bytes = serde_json::to_vec(&request)?;
2160 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2161
2162 Self::send_api_request(
2163 |builder| {
2164 let req = builder
2165 .uri(url.as_ref())
2166 .header("Content-Encoding", "zstd")
2167 .body(compressed.clone().into());
2168 Ok(req?)
2169 },
2170 client,
2171 llm_token,
2172 app_version,
2173 true,
2174 )
2175 .await
2176 }
2177
2178 fn handle_api_response<T>(
2179 this: &WeakEntity<Self>,
2180 response: Result<(T, Option<EditPredictionUsage>)>,
2181 cx: &mut gpui::AsyncApp,
2182 ) -> Result<T> {
2183 match response {
2184 Ok((data, usage)) => {
2185 if let Some(usage) = usage {
2186 this.update(cx, |this, cx| {
2187 this.user_store.update(cx, |user_store, cx| {
2188 user_store.update_edit_prediction_usage(usage, cx);
2189 });
2190 })
2191 .ok();
2192 }
2193 Ok(data)
2194 }
2195 Err(err) => {
2196 if err.is::<ZedUpdateRequiredError>() {
2197 cx.update(|cx| {
2198 this.update(cx, |this, _cx| {
2199 this.update_required = true;
2200 })
2201 .ok();
2202
2203 let error_message: SharedString = err.to_string().into();
2204 show_app_notification(
2205 NotificationId::unique::<ZedUpdateRequiredError>(),
2206 cx,
2207 move |cx| {
2208 cx.new(|cx| {
2209 ErrorMessagePrompt::new(error_message.clone(), cx)
2210 .with_link_button("Update Zed", "https://zed.dev/releases")
2211 })
2212 },
2213 );
2214 });
2215 }
2216 Err(err)
2217 }
2218 }
2219 }
2220
    /// Sends a POST request to the LLM backend and decodes its JSON response.
    ///
    /// `build` receives a builder that already has the `Content-Type`, Zed version and,
    /// when a token is held, `Authorization` headers. It must set the URI and the body.
    /// `build` may be called more than once: if a non-success response signals that the
    /// LLM token needs refreshing, the token is refreshed once and the request is rebuilt
    /// and resent.
    ///
    /// When `require_auth` is true, failing to acquire an LLM token is an error.
    /// Otherwise the request is sent without an `Authorization` header.
    ///
    /// Returns the decoded body together with usage parsed from the response headers, or
    /// `None` when those headers cannot be parsed.
    ///
    /// # Errors
    ///
    /// - [`ZedUpdateRequiredError`] if the response advertises a minimum required version
    ///   newer than `app_version`. This is checked before the status code.
    /// - An error with the status and body text for any other non-success response.
    /// - Token acquisition/refresh, transport, body-read, or JSON decoding failures.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        // A missing token is fatal only when authentication is required.
        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        // Limits token refresh + resend to a single attempt per call.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // If the server advertises a minimum client version, reject older clients
            // regardless of the response status. Unparseable header values are ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; failure to parse them is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Refresh the token once and loop to rebuild and resend the request.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2290
2291 pub fn refresh_context(
2292 &mut self,
2293 project: &Entity<Project>,
2294 buffer: &Entity<language::Buffer>,
2295 cursor_position: language::Anchor,
2296 cx: &mut Context<Self>,
2297 ) {
2298 self.get_or_init_project(project, cx)
2299 .context
2300 .update(cx, |store, cx| {
2301 store.refresh(buffer.clone(), cursor_position, cx);
2302 });
2303 }
2304
2305 #[cfg(feature = "cli-support")]
2306 pub fn set_context_for_buffer(
2307 &mut self,
2308 project: &Entity<Project>,
2309 related_files: Vec<RelatedFile>,
2310 cx: &mut Context<Self>,
2311 ) {
2312 self.get_or_init_project(project, cx)
2313 .context
2314 .update(cx, |store, cx| {
2315 store.set_related_files(related_files, cx);
2316 });
2317 }
2318
2319 #[cfg(feature = "cli-support")]
2320 pub fn set_recent_paths_for_project(
2321 &mut self,
2322 project: &Entity<Project>,
2323 paths: impl IntoIterator<Item = project::ProjectPath>,
2324 cx: &mut Context<Self>,
2325 ) {
2326 let project_state = self.get_or_init_project(project, cx);
2327 project_state.recent_paths = paths.into_iter().collect();
2328 }
2329
2330 fn is_file_open_source(
2331 &self,
2332 project: &Entity<Project>,
2333 file: &Arc<dyn File>,
2334 cx: &App,
2335 ) -> bool {
2336 if !file.is_local() || file.is_private() {
2337 return false;
2338 }
2339 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2340 return false;
2341 };
2342 project_state
2343 .license_detection_watchers
2344 .get(&file.worktree_id(cx))
2345 .as_ref()
2346 .is_some_and(|watcher| watcher.is_project_open_source())
2347 }
2348
2349 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2350 self.data_collection_choice.is_enabled(cx)
2351 }
2352
2353 fn load_data_collection_choice() -> DataCollectionChoice {
2354 let choice = KEY_VALUE_STORE
2355 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2356 .log_err()
2357 .flatten();
2358
2359 match choice.as_deref() {
2360 Some("true") => DataCollectionChoice::Enabled,
2361 Some("false") => DataCollectionChoice::Disabled,
2362 Some(_) => {
2363 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2364 DataCollectionChoice::NotAnswered
2365 }
2366 None => DataCollectionChoice::NotAnswered,
2367 }
2368 }
2369
2370 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2371 self.data_collection_choice = self.data_collection_choice.toggle();
2372 let new_choice = self.data_collection_choice;
2373 let is_enabled = new_choice.is_enabled(cx);
2374 db::write_and_log(cx, move || {
2375 KEY_VALUE_STORE.write_kvp(
2376 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2377 is_enabled.to_string(),
2378 )
2379 });
2380 }
2381
2382 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2383 self.shown_predictions.iter()
2384 }
2385
2386 pub fn shown_completions_len(&self) -> usize {
2387 self.shown_predictions.len()
2388 }
2389
2390 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2391 self.rated_predictions.contains(id)
2392 }
2393
2394 pub fn rate_prediction(
2395 &mut self,
2396 prediction: &EditPrediction,
2397 rating: EditPredictionRating,
2398 feedback: String,
2399 cx: &mut Context<Self>,
2400 ) {
2401 let organization = self.user_store.read(cx).current_organization();
2402
2403 self.rated_predictions.insert(prediction.id.clone());
2404
2405 cx.background_spawn({
2406 let client = self.client.clone();
2407 let prediction_id = prediction.id.to_string();
2408 let inputs = serde_json::to_value(&prediction.inputs);
2409 let output = prediction
2410 .edit_preview
2411 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2412 async move {
2413 client
2414 .cloud_client()
2415 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2416 organization_id: organization.map(|organization| organization.id.clone()),
2417 request_id: prediction_id,
2418 rating: match rating {
2419 EditPredictionRating::Positive => "positive".to_string(),
2420 EditPredictionRating::Negative => "negative".to_string(),
2421 },
2422 inputs: inputs?,
2423 output,
2424 feedback,
2425 })
2426 .await?;
2427
2428 anyhow::Ok(())
2429 }
2430 })
2431 .detach_and_log_err(cx);
2432
2433 cx.notify();
2434 }
2435}
2436
2437fn merge_trailing_events_if_needed(
2438 events: &mut VecDeque<StoredEvent>,
2439 end_snapshot: &TextBufferSnapshot,
2440 latest_snapshot: &TextBufferSnapshot,
2441 latest_edit_range: &Range<Anchor>,
2442) {
2443 if let Some(last_event) = events.back() {
2444 if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2445 return;
2446 }
2447 }
2448
2449 let mut next_old_event = None;
2450 let mut mergeable_count = 0;
2451 for old_event in events.iter().rev() {
2452 if let Some(next_old_event) = &next_old_event
2453 && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2454 {
2455 break;
2456 }
2457 mergeable_count += 1;
2458 next_old_event = Some(old_event);
2459 }
2460
2461 if mergeable_count <= 1 {
2462 return;
2463 }
2464
2465 let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2466 let oldest_event = events_to_merge.peek().unwrap();
2467 let oldest_snapshot = oldest_event.old_snapshot.clone();
2468
2469 if let Some((diff, edited_range)) =
2470 compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2471 {
2472 let merged_event = match oldest_event.event.as_ref() {
2473 zeta_prompt::Event::BufferChange {
2474 old_path,
2475 path,
2476 in_open_source_repo,
2477 ..
2478 } => StoredEvent {
2479 event: Arc::new(zeta_prompt::Event::BufferChange {
2480 old_path: old_path.clone(),
2481 path: path.clone(),
2482 diff,
2483 in_open_source_repo: *in_open_source_repo,
2484 predicted: events_to_merge.all(|e| {
2485 matches!(
2486 e.event.as_ref(),
2487 zeta_prompt::Event::BufferChange {
2488 predicted: true,
2489 ..
2490 }
2491 )
2492 }),
2493 }),
2494 old_snapshot: oldest_snapshot.clone(),
2495 edit_range: end_snapshot.anchor_before(edited_range.start)
2496 ..end_snapshot.anchor_before(edited_range.end),
2497 },
2498 };
2499 events.truncate(events.len() - mergeable_count);
2500 events.push_back(merged_event);
2501 }
2502}
2503
2504pub(crate) fn filter_redundant_excerpts(
2505 mut related_files: Vec<RelatedFile>,
2506 cursor_path: &Path,
2507 cursor_row_range: Range<u32>,
2508) -> Vec<RelatedFile> {
2509 for file in &mut related_files {
2510 if file.path.as_ref() == cursor_path {
2511 file.excerpts.retain(|excerpt| {
2512 excerpt.row_range.start < cursor_row_range.start
2513 || excerpt.row_range.end > cursor_row_range.end
2514 });
2515 }
2516 }
2517 related_files.retain(|file| !file.excerpts.is_empty());
2518 related_files
2519}
2520
/// Error raised when a server response advertises, through the minimum-required-version
/// header, a Zed version newer than the running one.
///
/// Produced in `send_api_request`. `handle_api_response` reacts to it by flagging the
/// store as requiring an update and prompting the user to update Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The lowest Zed version the server reported it will accept.
    minimum_version: Version,
}
2528
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted in the key-value store as `"true"` or `"false"`. A missing or unrecognized
/// stored value loads as [`DataCollectionChoice::NotAnswered`].
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No recognized choice has been stored.
    NotAnswered,
    /// Data collection was turned on (stored as `"true"`).
    Enabled,
    /// Data collection was turned off (stored as `"false"`).
    Disabled,
}
2535
2536impl DataCollectionChoice {
2537 pub fn is_enabled(self, cx: &App) -> bool {
2538 if cx.is_staff() {
2539 return true;
2540 }
2541 match self {
2542 Self::Enabled => true,
2543 Self::NotAnswered | Self::Disabled => false,
2544 }
2545 }
2546
2547 #[must_use]
2548 pub fn toggle(&self) -> DataCollectionChoice {
2549 match self {
2550 Self::Enabled => Self::Disabled,
2551 Self::Disabled => Self::Enabled,
2552 Self::NotAnswered => Self::Enabled,
2553 }
2554 }
2555}
2556
2557impl From<bool> for DataCollectionChoice {
2558 fn from(value: bool) -> Self {
2559 match value {
2560 true => DataCollectionChoice::Enabled,
2561 false => DataCollectionChoice::Disabled,
2562 }
2563 }
2564}
2565
2566struct ZedPredictUpsell;
2567
2568impl Dismissable for ZedPredictUpsell {
2569 const KEY: &'static str = "dismissed-edit-predict-upsell";
2570
2571 fn dismissed() -> bool {
2572 // To make this backwards compatible with older versions of Zed, we
2573 // check if the user has seen the previous Edit Prediction Onboarding
2574 // before, by checking the data collection choice which was written to
2575 // the database once the user clicked on "Accept and Enable"
2576 if KEY_VALUE_STORE
2577 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2578 .log_err()
2579 .is_some_and(|s| s.is_some())
2580 {
2581 return true;
2582 }
2583
2584 KEY_VALUE_STORE
2585 .read_kvp(Self::KEY)
2586 .log_err()
2587 .is_some_and(|s| s.is_some())
2588 }
2589}
2590
2591pub fn should_show_upsell_modal() -> bool {
2592 !ZedPredictUpsell::dismissed()
2593}
2594
2595pub fn init(cx: &mut App) {
2596 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2597 workspace.register_action(
2598 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2599 ZedPredictModal::toggle(
2600 workspace,
2601 workspace.user_store().clone(),
2602 workspace.client().clone(),
2603 window,
2604 cx,
2605 )
2606 },
2607 );
2608
2609 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2610 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2611 settings
2612 .project
2613 .all_languages
2614 .edit_predictions
2615 .get_or_insert_default()
2616 .provider = Some(EditPredictionProvider::None)
2617 });
2618 });
2619 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2620 EditPredictionStore::try_global(cx).and_then(|store| {
2621 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2622 })
2623 }
2624
2625 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2626 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2627 copilot_ui::initiate_sign_in(copilot, window, cx);
2628 }
2629 });
2630 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2631 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2632 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2633 }
2634 });
2635 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2636 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2637 copilot_ui::initiate_sign_out(copilot, window, cx);
2638 }
2639 });
2640 })
2641 .detach();
2642}